├── run
│   └── semantic_parsing_with_constrained_lm
│       ├── domains
│       │   ├── sql
│       │   │   ├── grammar
│       │   │   │   ├── start.scfg
│       │   │   │   ├── table_columns.scfg
│       │   │   │   ├── whitespace.scfg
│       │   │   │   └── quoted.scfg
│       │   │   ├── __init__.py
│       │   │   ├── cosql
│       │   │   │   ├── __init__.py
│       │   │   │   ├── paths.py
│       │   │   │   ├── seq2seq.py
│       │   │   │   ├── grammar.py
│       │   │   │   └── schema.py
│       │   │   ├── sql_datum.py
│       │   │   ├── sequence_creator.py
│       │   │   ├── sql_metric.py
│       │   │   └── create_benchclamp_data.py
│       │   ├── calflow
│       │   │   ├── grammar
│       │   │   │   ├── start.scfg
│       │   │   │   ├── entities.scfg
│       │   │   │   ├── enum_wrappers.scfg
│       │   │   │   ├── quoted.scfg
│       │   │   │   └── fluenter.scfg
│       │   │   ├── data
│       │   │   │   └── ids_dev_100_uniform.txt
│       │   │   └── __init__.py
│       │   ├── __init__.py
│       │   ├── mtop
│       │   │   ├── __init__.py
│       │   │   └── grammar.py
│       │   ├── lispress_v2
│       │   │   ├── __init__.py
│       │   │   ├── sequence_creator.py
│       │   │   └── grammar.py
│       │   ├── ltl
│       │   │   ├── data
│       │   │   │   ├── pick-golden-cross0-split0.canonical.json
│       │   │   │   ├── pick-golden-cross0-split1.canonical.json
│       │   │   │   ├── pick-golden-cross0-split2.canonical.json
│       │   │   │   ├── pick-golden-cross0-split3.canonical.json
│       │   │   │   ├── pick-golden-cross0-split4.canonical.json
│       │   │   │   ├── pick-syn.canonical.json
│       │   │   │   ├── pick-syn-aug.canonical.json
│       │   │   │   └── pick-syn.train.jsonl
│       │   │   ├── kept
│       │   │   │   ├── pick-syn.canonical.json
│       │   │   │   └── pick-syn.train.jsonl
│       │   │   ├── create_benchclamp_data.py
│       │   │   └── __init__.py
│       │   ├── dfa_grammar_utils.py
│       │   ├── overnight
│       │   │   ├── create_benchclamp_data.py
│       │   │   └── __init__.py
│       │   ├── calflow_eval_utils.py
│       │   └── create_benchclamp_splits.py
│       ├── __init__.py
│       ├── scfg
│       │   ├── __init__.py
│       │   ├── parser
│       │   │   ├── __init__.py
│       │   │   ├── utils.py
│       │   │   ├── types.py
│       │   │   ├── scfg_grammar.lark
│       │   │   ├── macro.py
│       │   │   ├── token.py
│       │   │   ├── rule.py
│       │   │   └── parse.py
│       │   ├── string_utils.py
│       │   └── earley_grammar.py
│       ├── configs
│       │   ├── __init__.py
│       │   └── lib
│       │       ├── __init__.py
│       │       ├── benchclamp.py
│       │       └── calflow.py
│       ├── decoding
│       │   ├── __init__.py
│       │   ├── trie_partial_parse.py
│       │   ├── partial_parse.py
│       │   └── uint8_earley_partial_parse.py
│       ├── earley
│       │   ├── __init__.py
│       │   ├── cfg.lark
│       │   ├── unicode_categories_spans.py
│       │   ├── context_sensitive.py
│       │   ├── specialization.py
│       │   └── recognize.py
│       ├── finetune
│       │   ├── __init__.py
│       │   ├── configs
│       │   │   └── __init__.py
│       │   ├── download_huggingface_lms.py
│       │   └── check_for_unks.py
│       ├── scripts
│       │   ├── __init__.py
│       │   └── calflow_fit_max_steps.py
│       ├── async_tools
│       │   ├── __init__.py
│       │   └── batch_helper.py
│       ├── index
│       │   ├── __init__.py
│       │   ├── index.py
│       │   └── bm25_index.py
│       ├── util
│       │   ├── types.py
│       │   ├── unit.py
│       │   ├── keydefaultdict.py
│       │   └── logger.py
│       ├── sequence_creator.py
│       ├── cache.py
│       ├── playground.ipynb
│       ├── result.py
│       ├── paths.py
│       ├── datum.py
│       ├── train_model_setup.py
│       ├── fit_max_steps.py
│       └── eval.py
├── env_install.sh
├── docs
│   ├── _config.yml
│   ├── _layouts
│   │   └── default.html
│   └── index.md
├── datasets
│   └── pick-and-place
│       ├── canonical.json
│       └── train_seed.jsonl
└── README.md
/run/semantic_parsing_with_constrained_lm/domains/sql/grammar/start.scfg:
--------------------------------------------------------------------------------
1 | start -> parse
2 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/calflow/grammar/start.scfg:
--------------------------------------------------------------------------------
1 | start -> " " unit , unit
2 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/scfg/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/configs/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/decoding/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/earley/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/finetune/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/scripts/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/async_tools/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/configs/lib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/mtop/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/sql/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/scfg/parser/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/lispress_v2/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/sql/cosql/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/finetune/configs/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/index/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from .index import Candidate, DynamicIndex, Index, Query
5 |
--------------------------------------------------------------------------------
/env_install.sh:
--------------------------------------------------------------------------------
1 | pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
2 | pip install transformers datasets
3 | pip install jsons appdirs blobfile cached-property httpx typer whoosh more_itertools jupyter openai
4 | pip install --upgrade protobuf==3.20.0
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/scfg/parser/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | def is_skippable(string: str):
5 | """A string is skippable if it's empty or begins with a '#'"""
6 | return not string or string[0] == "#"
7 |
--------------------------------------------------------------------------------
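`is_skippable` is presumably used when reading `.scfg` grammar files to drop blank lines and `#` comments. A minimal usage sketch (the file path is only illustrative):

from semantic_parsing_with_constrained_lm.scfg.parser.utils import is_skippable

# Keep only the lines that contain actual grammar rules, skipping blank
# lines and comments that start with '#'.
with open("start.scfg") as f:
    rule_lines = [line.strip() for line in f if not is_skippable(line.strip())]
print(rule_lines)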
/docs/_config.yml:
--------------------------------------------------------------------------------
1 | remote_theme: pages-themes/cayman@v0.2.0
2 | plugins:
3 | - jekyll-remote-theme # add this line to the plugins list if you already have one
4 | title: Data-Efficient Learning of Natural Language to Linear Temporal Logic Translators for Robot Task Specification
5 | description: "Jiayi Pan, Glen Chou, Dmitry Berenson"
6 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/util/types.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import os # pylint: disable=unused-import
5 | from typing import Union
6 |
7 | # This can be used to annotate arguments that are supposed to be file paths.
8 | StrPath = Union[str, "os.PathLike[str]"]
9 |
--------------------------------------------------------------------------------
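A small sketch of how `StrPath` can annotate path arguments so that both `str` and `pathlib.Path` values type-check (the function below is illustrative, not part of the repository):

from pathlib import Path

from semantic_parsing_with_constrained_lm.util.types import StrPath


def read_text(path: StrPath) -> str:
    # Accepts "notes.txt" as well as Path("notes.txt").
    with open(path) as f:
        return f.read()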
/run/semantic_parsing_with_constrained_lm/domains/sql/sql_datum.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from dataclasses import dataclass
5 |
6 | from semantic_parsing_with_constrained_lm.datum import FullDatum
7 |
8 |
9 | @dataclass(frozen=True)
10 | class SqlDatum(FullDatum):
11 | schema_name: str
12 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/scfg/parser/types.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from typing import Tuple
5 |
6 | from semantic_parsing_with_constrained_lm.scfg.parser.token import SCFGToken
7 |
8 | Nonterminal = str
9 | # An Alias is just another name for a nonterminal.
10 | Alias = str
11 |
12 |
13 | Expansion = Tuple[SCFGToken, ...]
14 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/mtop/grammar.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import functools
5 |
6 | from semantic_parsing_with_constrained_lm.domains import dfa_grammar_utils
7 |
8 | create_partial_parse_builder = functools.partial(
9 | dfa_grammar_utils.create_partial_parse_builder,
10 | utterance_nonterm_name="any_char_star",
11 | )
12 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/util/unit.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | class Unit:
5 | """
6 |     Analogue of the Scala Unit type that has a single instance, UNIT. Can be used as a type
7 | placeholder. Similar to None but can be used where None doesn't work.
8 | """
9 |
10 | def __repr__(self) -> str:
11 | return "Unit"
12 |
13 |
14 | UNIT = Unit()
15 |
--------------------------------------------------------------------------------
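An illustrative use of `UNIT` as a placeholder value, for example as the value type of a set-like dict where `None` would be ambiguous (this snippet is not from the repository):

from typing import Dict

from semantic_parsing_with_constrained_lm.util.unit import UNIT, Unit

# A set-like dict: only membership matters, so every value is UNIT.
seen: Dict[str, Unit] = {}
seen["utterance-1"] = UNIT
print("utterance-1" in seen)  # True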
/run/semantic_parsing_with_constrained_lm/domains/calflow/grammar/entities.scfg:
--------------------------------------------------------------------------------
1 | personname -> quoted , " #(PersonName " quoted ")"
2 | string -> quoted , " #(String " quoted ")"
3 | respondcomment -> quoted , " #(RespondComment " quoted ")"
4 | locationkeyphrase -> quoted , " #(LocationKeyphrase " quoted ")"
5 | path -> quoted , " #(Path " quoted ")"
6 |
7 | list_path_ -> "(empty list)" , " #(List[Path] [])"
8 | list_recipient_ -> "(empty recipient list)" , " #(List[Recipient] [])"
9 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/calflow/grammar/enum_wrappers.scfg:
--------------------------------------------------------------------------------
1 | holiday -> holiday_entity , " #(Holiday \"" holiday_entity "\")"
2 | placefeature -> place_feature_entity , " #(PlaceFeature \"" place_feature_entity "\")"
3 | weatherquantifier -> weather_quantifier_entity , " #(WeatherQuantifier \"" weather_quantifier_entity "\")"
4 | responsestatustype -> response_entity , " #(ResponseStatusType \"" response_entity "\")"
5 | number -> number_entity , " #(Number" number_entity ")"
6 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/sql/grammar/table_columns.scfg:
--------------------------------------------------------------------------------
1 | ## these get auto-gen'd per schema
2 |
3 | # schema_name -> any_name
4 | # table_name -> any_name
5 | # table_alias -> any_name
6 | # column_name -> any_name
7 | # schema_table_dot -> schema_dot_ws? table_dot
8 | # possibly_qualified_column_name -> schema_table_dot? column_name
9 |
10 | table_alias -> T "1"
11 | table_alias -> T "2"
12 | table_alias -> T "3"
13 | table_alias -> T "4"
14 | table_alias -> T "5"
15 |
16 | table_name -> table_alias
17 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/sql/cosql/paths.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from pathlib import Path
5 | from semantic_parsing_with_constrained_lm.paths import DOMAINS_DIR
6 |
7 |
8 | # TODO: Don't hardcode
9 | COSQL_DIR = Path("data/cosql_dataset")
10 |
11 | SCHEMAS_FILE = COSQL_DIR / "tables.json"
12 |
13 | SQL_STATE_TRACKING_DIR = COSQL_DIR / "sql_state_tracking"
14 | TRAIN_DATA_FILE = SQL_STATE_TRACKING_DIR / "cosql_train.json"
15 |
16 | SQL_GRAMMAR_DIR = DOMAINS_DIR / "sql" / "grammar"
17 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/sequence_creator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from abc import ABC, abstractmethod
5 |
6 | from semantic_parsing_with_constrained_lm.datum import BenchClampDatum
7 |
8 |
9 | class SequenceCreator(ABC):
10 | @abstractmethod
11 | def create_sequence(self, datum: BenchClampDatum) -> str:
12 | pass
13 |
14 |
15 | class IdentitySequenceCreator(SequenceCreator):
16 | def create_sequence(self, datum: BenchClampDatum) -> str:
17 | return datum.utterance
18 |
--------------------------------------------------------------------------------
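A minimal sketch of `IdentitySequenceCreator` applied to a `BenchClampDatum` (defined in `datum.py` below); the field values are taken from the pick-and-place data purely for illustration:

from semantic_parsing_with_constrained_lm.datum import BenchClampDatum
from semantic_parsing_with_constrained_lm.sequence_creator import IdentitySequenceCreator

datum = BenchClampDatum(
    dialogue_id=None,
    turn_part_index=None,
    utterance="look for and pick up any cubes and put them in crate",
    plan="globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )",
)
# IdentitySequenceCreator returns the utterance unchanged.
print(IdentitySequenceCreator().create_sequence(datum))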
/run/semantic_parsing_with_constrained_lm/cache.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from abc import ABC, abstractmethod
5 | from typing import Optional
6 |
7 |
8 | class CacheClient(ABC):
9 | async def __aenter__(self):
10 | pass
11 |
12 | async def __aexit__(self, exc_type, exc_value, traceback):
13 | pass
14 |
15 | @abstractmethod
16 | async def get(self, args: dict) -> Optional[dict]:
17 | pass
18 |
19 | @abstractmethod
20 | async def upload(self, args: dict, result: dict) -> None:
21 | pass
22 |
--------------------------------------------------------------------------------
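A sketch of a trivial in-memory `CacheClient`, shown only to illustrate the interface (the key scheme is an assumption; this class is not part of the repository):

import json
from typing import Dict, Optional

from semantic_parsing_with_constrained_lm.cache import CacheClient


class InMemoryCacheClient(CacheClient):
    """Stores results in a dict keyed by the JSON-serialized request args."""

    def __init__(self) -> None:
        self._store: Dict[str, dict] = {}

    async def get(self, args: dict) -> Optional[dict]:
        return self._store.get(json.dumps(args, sort_keys=True))

    async def upload(self, args: dict, result: dict) -> None:
        self._store[json.dumps(args, sort_keys=True)] = result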
/run/semantic_parsing_with_constrained_lm/domains/calflow/grammar/quoted.scfg:
--------------------------------------------------------------------------------
1 | # Quoted strings *do* begin with a space in this grammar.
2 | # For example, `create event with " Rose"`.
3 | # The space has to be a regex, b/c it gets consumed by CopyTokens,
4 | # and it has to not be inside nonquoteplus, because it doesn't
5 | # appear on the plan side.
6 | quoted -> "\"" / / nonquoteplus "\"" , "\"" nonquoteplus "\""
7 |
8 | # matches one or more characters that are not double quotes
9 | nonquoteplus -> /[^"]/ nonquotestar
10 |
11 | # matches zero or more characters that are not double quotes
12 | nonquotestar -> /[^"]/ nonquotestar
13 | nonquotestar -> empty
14 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/scfg/string_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from typing import Iterable
5 |
6 |
7 | def detokenize(tokens: Iterable[str], with_treebank: bool = True) -> str:
8 | """
9 | Given a list of tokens, join them together into a string.
10 | with_treebank = True is typically used when rendering utterances, so we don't need to deal with things like
11 | "andrew's"
12 |     with_treebank = False is typically for rendering expressions.
13 | """
14 | if with_treebank:
15 |         return " ".join(tokens).replace("  ", " ")
16 |
17 | return "".join(tokens)
18 |
--------------------------------------------------------------------------------
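Expected behavior of `detokenize` under both settings (outputs shown as comments, assuming the double-space collapsing above):

from semantic_parsing_with_constrained_lm.scfg.string_utils import detokenize

tokens = ["create", "event", "with", '" Rose"']
print(detokenize(tokens, with_treebank=True))   # create event with " Rose"
print(detokenize(tokens, with_treebank=False))  # createeventwith" Rose"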
/run/semantic_parsing_with_constrained_lm/playground.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": []
9 | }
10 | ],
11 | "metadata": {
12 | "kernelspec": {
13 | "display_name": "Python 3.10.4 64-bit",
14 | "language": "python",
15 | "name": "python3"
16 | },
17 | "language_info": {
18 | "name": "python",
19 | "version": "3.10.4"
20 | },
21 | "orig_nbformat": 4,
22 | "vscode": {
23 | "interpreter": {
24 | "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
25 | }
26 | }
27 | },
28 | "nbformat": 4,
29 | "nbformat_minor": 2
30 | }
31 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/sql/grammar/whitespace.scfg:
--------------------------------------------------------------------------------
1 | # This subgrammar matches (non-newline) whitespace sequences on the left side,
2 | # and returns a single " " on the right side
3 | # Newlines are turned into "\n"
4 |
5 | # 1 or more whitespace chars, returns SPACE
6 | ws -> ws_char ws_star_empty
7 |
8 | # 0 or more whitespace chars, returns SPACE
9 | ws_star -> #e , SPACE
10 | ws_star -> ws
11 |
12 | # matches any whitespace sequence and returns #e
13 | ws_star_empty -> ws_char ws_star , #e
14 | ws_star_empty -> #e
15 |
16 |
17 | SPACE -> " "
18 |
19 | # matches any whitespace char and returns SPACE
20 | ws_char -> SPACE
21 | ws_char -> "\u000B", SPACE
22 | ws_char -> TAB , SPACE
23 | ws_char -> NEWLINE , SPACE
24 |
25 | NEWLINE -> "\r" , "\n"
26 | NEWLINE -> "\n"
27 |
28 | TAB -> "\t"
29 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/index/index.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from abc import ABC, abstractmethod
5 | from typing import Generic, Iterable, Tuple, TypeVar
6 |
7 | Key = TypeVar("Key")
8 | Query = TypeVar("Query")
9 | Candidate = TypeVar("Candidate")
10 |
11 |
12 | class Index(Generic[Key, Query], ABC):
13 | """
14 | Encapsulates any index that can be searched over.
15 | It can either be a sparse index (e.g. powered by Whoosh), or a dense index (e.g. powered by FAISS).
16 | """
17 |
18 | @abstractmethod
19 | def search(self, query: Query, top_k: int) -> Iterable[Tuple[Key, float]]:
20 | raise NotImplementedError
21 |
22 |
23 | class DynamicIndex(Generic[Key, Query, Candidate], Index[Key, Query]):
24 | """
25 | Any index that supports dynamic addition to the set of candidates.
26 | """
27 |
28 | @abstractmethod
29 | def add(self, candidates: Iterable[Candidate]) -> None:
30 | raise NotImplementedError
31 |
--------------------------------------------------------------------------------
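A toy `DynamicIndex` implementation, purely to illustrate the generic interface (a BM25-backed index presumably lives in `bm25_index.py`; this class is not part of the repository):

from typing import Dict, Iterable, Tuple

from semantic_parsing_with_constrained_lm.index.index import DynamicIndex


class ExactMatchIndex(DynamicIndex[str, str, str]):
    """Keys are the candidate strings themselves; score is 1.0 on exact match."""

    def __init__(self) -> None:
        self._candidates: Dict[str, None] = {}

    def add(self, candidates: Iterable[str]) -> None:
        for candidate in candidates:
            self._candidates[candidate] = None

    def search(self, query: str, top_k: int) -> Iterable[Tuple[str, float]]:
        return [(c, 1.0) for c in self._candidates if c == query][:top_k]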
/run/semantic_parsing_with_constrained_lm/decoding/trie_partial_parse.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import dataclasses
5 | from dataclasses import dataclass
6 | from typing import Optional, Tuple
7 |
8 | import torch
9 |
10 | from semantic_parsing_with_constrained_lm.util.trie import Trie
11 | from semantic_parsing_with_constrained_lm.decoding.partial_parse import PartialParse
12 |
13 |
14 | @dataclass
15 | class TriePartialParse(PartialParse):
16 | trie: Trie[int]
17 | tokens: Tuple[int, ...] = ()
18 |
19 | def allowed_next(
20 | self, ordered_ids: Optional[torch.Tensor] = None, top_k: Optional[int] = None
21 | ) -> Tuple[torch.Tensor, bool]:
22 | allowed, is_complete = self.trie.prefix_next(self.tokens)
23 | return torch.tensor(sorted(allowed), dtype=torch.long), is_complete
24 |
25 | def append(self, token: int) -> "PartialParse":
26 | """Return a new PartialParse creatoted by appending this token."""
27 | return dataclasses.replace(self, tokens=self.tokens + (token,))
28 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/ltl/data/pick-golden-cross0-split0.canonical.json:
--------------------------------------------------------------------------------
1 | {"globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )"}}
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/ltl/data/pick-golden-cross0-split1.canonical.json:
--------------------------------------------------------------------------------
1 | {"globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )"}}
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/ltl/data/pick-golden-cross0-split2.canonical.json:
--------------------------------------------------------------------------------
1 | {"globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )"}}
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/ltl/data/pick-golden-cross0-split3.canonical.json:
--------------------------------------------------------------------------------
1 | {"globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )"}}
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/ltl/data/pick-golden-cross0-split4.canonical.json:
--------------------------------------------------------------------------------
1 | {"globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )"}}
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/lispress_v2/sequence_creator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from semantic_parsing_with_constrained_lm.datum import BenchClampDatum
5 | from semantic_parsing_with_constrained_lm.sequence_creator import SequenceCreator
6 |
7 |
8 | class LastAgentUtterance(SequenceCreator):
9 | def create_sequence(self, datum: BenchClampDatum) -> str:
10 | last_agent_utterance = (
11 | datum.last_agent_utterance if datum.last_agent_utterance is not None else ""
12 | )
13 | return " | ".join([last_agent_utterance, datum.utterance])
14 |
15 |
16 | class LastUserAgentUtterance(SequenceCreator):
17 | def create_sequence(self, datum: BenchClampDatum) -> str:
18 | last_agent_utterance = (
19 | datum.last_agent_utterance if datum.last_agent_utterance is not None else ""
20 | )
21 | last_user_utterance = (
22 | datum.last_user_utterance if datum.last_user_utterance is not None else ""
23 | )
24 | return " | ".join([last_user_utterance, last_agent_utterance, datum.utterance])
25 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/ltl/data/pick-syn.canonical.json:
--------------------------------------------------------------------------------
1 | {
2 | "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )": {
3 | "formula": "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )"
4 | },
5 | "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )": {
6 | "formula": "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )"
7 | },
8 | "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )": {
9 | "formula": "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )"
10 | },
11 | "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )": {
12 | "formula": "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )"
13 | },
14 | "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )": {
15 | "formula": "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )"
16 | }
17 | }
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/ltl/kept/pick-syn.canonical.json:
--------------------------------------------------------------------------------
1 | {
2 | "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )": {
3 | "formula": "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )"
4 | },
5 | "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )": {
6 | "formula": "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )"
7 | },
8 | "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )": {
9 | "formula": "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )"
10 | },
11 | "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )": {
12 | "formula": "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )"
13 | },
14 | "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )": {
15 | "formula": "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )"
16 | }
17 | }
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/finetune/download_huggingface_lms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from pathlib import Path
5 |
6 | from transformers import (
7 | BartForConditionalGeneration,
8 | BartTokenizer,
9 | RobertaTokenizer,
10 | T5ForConditionalGeneration,
11 | T5Tokenizer,
12 | )
13 |
14 | CLAMP_PRETRAINED_MODEL_DIR = Path("huggingface_models/")
15 |
16 |
17 | def save_model_and_tokenizer(model, tokenizer, save_dir: Path) -> None:
18 | save_dir.mkdir(exist_ok=True, parents=True)
19 | model.save_pretrained(save_dir)
20 | tokenizer.save_pretrained(save_dir)
21 |
22 |
23 | def main():
24 | # T5
25 | # Bart
26 | for model_id, huggingface_model_id in [
27 | ("bart-large", "facebook/bart-large"),
28 | ]:
29 | print(f"Downloading {model_id} ...")
30 | model = BartForConditionalGeneration.from_pretrained(huggingface_model_id)
31 | tokenizer = BartTokenizer.from_pretrained(huggingface_model_id)
32 | save_model_and_tokenizer(
33 | model, tokenizer, CLAMP_PRETRAINED_MODEL_DIR / model_id
34 | )
35 |
36 |
37 | if __name__ == "__main__":
38 | main()
39 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/ltl/data/pick-syn-aug.canonical.json:
--------------------------------------------------------------------------------
1 | {
2 | "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )": {
3 | "formula": "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )"
4 | },
5 | "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )": {
6 | "formula": "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )"
7 | },
8 | "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )": {
9 | "formula": "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )"
10 | },
11 | "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )": {
12 | "formula": "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )"
13 | },
14 | "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )": {
15 | "formula": "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )"
16 | }
17 | }
--------------------------------------------------------------------------------
/datasets/pick-and-place/canonical.json:
--------------------------------------------------------------------------------
1 | {
2 | "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )": {
3 | "formula": "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )",
4 | "raw": "G & U S ! A F A"
5 | },
6 | "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )": {
7 | "formula": "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )",
8 | "raw": "G & U S ! R F R"
9 | },
10 | "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )": {
11 | "formula": "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )",
12 | "raw": "G & U S ! B F B"
13 | },
14 | "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )": {
15 | "formula": "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )",
16 | "raw": "G & U S ! Y F Y"
17 | },
18 | "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )": {
19 | "formula": "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )",
20 | "raw": "G & U S ! C F C"
21 | }
22 | }
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/result.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from dataclasses import dataclass
5 | from typing import Dict, List, Optional, Tuple
6 |
7 | from semantic_parsing_with_constrained_lm.model import ModelResult
8 |
9 |
10 | @dataclass(frozen=True, eq=True)
11 | class DatumResult:
12 | """CLAMP predictions and results for a single datum"""
13 |
14 | # Test datum utterance
15 | test_datum_natural: str
16 |
17 | # Text and cost of each sequence in the final beam
18 | results: List[ModelResult]
19 |
20 | # Text of each sequence in the final beam
21 | # (Duplicated from `results`; maintained here only
22 | # for backwards compatibility. May be removed later.)
23 | outputs: List[str]
24 |
25 | # The metrics dictionary containing the main results
26 | metrics: Dict[str, Optional[str]]
27 |
28 | # Other (optional) test datum fields
29 | test_datum_id: Optional[str] = None
30 | test_datum_turn_part_index: Optional[int] = None
31 | test_datum_agent_context: Optional[str] = None
32 | test_datum_canonical: Optional[str] = None
33 |
34 | # Token-level log probabilities for each sequence in the final beam
35 | # (Not yet implemented)
36 | token_logprobs: Optional[List[List[Tuple[str, float]]]] = None
37 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/scfg/parser/scfg_grammar.lark:
--------------------------------------------------------------------------------
1 | %import common.WS
2 | %import common.WORD
3 | %import common.ESCAPED_STRING
4 |
5 | %ignore WS
6 |
7 | start: sync_rule
8 | | mirrored_rule
9 | | macro_rule
10 | | utterance_rule
11 |
12 | start_for_test: sync_rule
13 | | macro_rule
14 | | utterance_rule
15 | | plan_expansion
16 |
17 | sync_rule: rule "->" utterance_expansions "," plan_expansion
18 | mirrored_rule: rule "->" utterance_expansion
19 | macro_rule: macro_def "2>" plan_expansion
20 | utterance_rule: rule "1>" utterance_expansions
21 |
22 | utterance_expansions: utterance_expansion ("|" utterance_expansion)*
23 |
24 | plan_expansion: (token | macro_apply)+
25 | utterance_expansion: token+
26 |
27 | token: terminal
28 | | optional_terminal
29 | | nonterminal
30 | | optional_nonterminal
31 | | empty
32 | | regex
33 |
34 | terminal: terminal_string
35 | optional_terminal: terminal_string "?"
36 |
37 | nonterminal: _name
38 | optional_nonterminal: _name "?"
39 |
40 | rule: _name
41 |
42 | macro_def: _name "(" (_name ("," _name)* ","?)? ")"
43 | | _name
44 | macro_apply: _name "(" (_macro_arg ("," _macro_arg)* ","?)? ")"
45 | _macro_arg: nonterminal | terminal | macro_apply | empty
46 |
47 | _name: /[a-zA-Z][_a-zA-Z0-9]*/
48 | regex: /\/[^\/]+\//
49 |
50 | ?terminal_string: ESCAPED_STRING
51 |
52 | empty: "#e"
53 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/earley/cfg.lark:
--------------------------------------------------------------------------------
1 | // Based on scfg_grammar.lark
2 | %import common.WS
3 | %import common.ESCAPED_STRING
4 | %import common._STRING_ESC_INNER
5 | %import common.INT
6 |
7 | %ignore WS
8 |
9 | _COMMENT: /#[^\n]*/
10 | start: (rule | _COMMENT)*
11 |
12 | rule: nonterminal_lhs "->" expansion
13 |
14 | expansion: alt ("|" alt)*
15 |
16 | ?alt: elem0+
17 | | empty
18 |
19 | ?elem0: elem1
20 | | elem1 "?" -> optional
21 | | elem1 "*" -> star
22 | | elem1 "+" -> plus
23 | | elem1 "{" count "}" -> repeat_exact
24 | | elem1 "{" count ",}" -> repeat_min
25 | | elem1 "{," count "}" -> repeat_max
26 | | elem1 "{" count "," count "}" -> repeat_min_max
27 |
28 | ?elem1: nonterminal_rhs
29 | | terminal
30 | | "(" expansion ")"
31 | | "[" /[^\]]+/ "]" -> char_class
32 | | "[[" /[^\]]+/ "]--[" /[^\]]+/+ "]]" -> char_class_subtract
33 |
34 | HYPHEN: "-"
35 | RIGHT_SQUARE_BRACKET: "]"
36 |
37 | ESCAPED_SINGLE_QUOTED_STRING: "'" _STRING_ESC_INNER "'"
38 | terminal: ESCAPED_STRING | ESCAPED_SINGLE_QUOTED_STRING
39 | // TODO: #e conflicts with syntax for comments
40 | empty: "#e"
41 | nonterminal_lhs: _name
42 | nonterminal_rhs: _name
43 | _name: /[_a-zA-Z][_a-zA-Z0-9]*/
44 | !count: INT
45 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/earley/unicode_categories_spans.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | """Provides data about which Unicode characters belong to which general category.
5 |
6 | The concept is explained here:
7 | - https://www.unicode.org/reports/tr44/#General_Category_Values
8 | - https://en.wikipedia.org/wiki/Unicode_character_property#General_Category
9 |
10 | unicode_categories.json was created by translating
11 | https://raw.githubusercontent.com/rust-lang/regex/258bdf798a14f50529c1665e84cc8a3a9e2c90fc/regex-syntax/src/unicode_tables/general_category.rs
12 | """
13 | import functools
14 | import json
15 | from pathlib import Path
16 | from typing import Dict, List
17 |
18 | from semantic_parsing_with_constrained_lm.util.span import Span, SpanSet
19 |
20 |
21 | @functools.lru_cache(maxsize=None)
22 | def raw_data() -> Dict[str, List[List[int]]]:
23 | with open(Path(__file__).absolute().parent / "unicode_categories.json") as f:
24 | return json.load(f)
25 |
26 |
27 | @functools.lru_cache(maxsize=None)
28 | def category_to_span_set(name: str) -> SpanSet:
29 | """Returns the SpanSet for the category name.
30 |
31 | The SpanSet contains the Unicode code points for the corresponding general category.
32 | Only long names with underscores (e.g. "Letter", "Cased_Letter") are accepted."""
33 |
34 | return SpanSet(Span(x, y) for x, y in raw_data()[name])
35 |
--------------------------------------------------------------------------------
/datasets/pick-and-place/train_seed.jsonl:
--------------------------------------------------------------------------------
1 | {"canonical": "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )", "formula": "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )", "natural": "Look for and pick up any cubes and put them in crate."}
2 | {"canonical": "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )", "formula": "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )", "natural": "Look for and pick up any non red cubes and put them in crate."}
3 | {"canonical": "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )", "formula": "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )", "natural": "Look for and pick up any non blue cubes and put them in crate."}
4 | {"canonical": "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )", "formula": "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )", "natural": "Look for and pick up any non yellow cubes and put them in crate."}
5 | {"canonical": "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )", "formula": "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )", "natural": "Look for and pick up any non green cubes and put them in crate."}
6 |
--------------------------------------------------------------------------------
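A short sketch for loading the pick-and-place data above (paths are relative to the repository root):

import json

# canonical.json: canonical form -> {"formula": ..., "raw": ...}
with open("datasets/pick-and-place/canonical.json") as f:
    canonical = json.load(f)

# train_seed.jsonl: one {"canonical", "formula", "natural"} object per line
with open("datasets/pick-and-place/train_seed.jsonl") as f:
    seed = [json.loads(line) for line in f if line.strip()]

for example in seed:
    print(example["natural"], "->", canonical[example["canonical"]]["raw"])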
/run/semantic_parsing_with_constrained_lm/domains/dfa_grammar_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from semantic_parsing_with_constrained_lm.earley.grammar import DFAGrammar, Nonterm
5 | from semantic_parsing_with_constrained_lm.earley.specialization import SubstringIntersectingGrammarSpecializer
6 | from semantic_parsing_with_constrained_lm.datum import Datum
7 | from semantic_parsing_with_constrained_lm.decoding.partial_parse import PartialParse
8 | from semantic_parsing_with_constrained_lm.decoding.uint8_earley_partial_parse import (
9 | UInt8EarleyPartialParse,
10 | UInt8GrammarTokenizerInfo,
11 | )
12 | from semantic_parsing_with_constrained_lm.model import PartialParseBuilder
13 | from semantic_parsing_with_constrained_lm.tokenization import ClampTokenizer
14 |
15 |
16 | def create_partial_parse_builder(
17 | grammar: DFAGrammar, tokenizer: ClampTokenizer, utterance_nonterm_name: str
18 | ) -> PartialParseBuilder[Datum]:
19 | specializer = SubstringIntersectingGrammarSpecializer(
20 | grammar, Nonterm(utterance_nonterm_name)
21 | )
22 | tokens = UInt8GrammarTokenizerInfo.prepare_tokens_from_clamp_tokenizer(tokenizer)
23 |
24 | def builder(datum: Datum) -> PartialParse:
25 | specialized_grammar = specializer.specialize(datum.natural)
26 | return UInt8EarleyPartialParse.initial(
27 | UInt8GrammarTokenizerInfo(specialized_grammar, tokens)
28 | )
29 |
30 | return builder
31 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/sql/grammar/quoted.scfg:
--------------------------------------------------------------------------------
1 | # Single Quotes
2 |
3 | single_quoted -> "'" "%"? non_single_quote_star "%"? "'"
4 |
5 | # matches zero or more strings that are either not single quotes, or are two single quotes in a row
6 | non_single_quote_star -> #e
7 | non_single_quote_star -> /[^']/ non_single_quote_star
8 |
9 | # TODO: uncomment?
10 | # non_single_quote_star -> /[']/ /[']/ non_single_quote_star, /[']/ /[']/ non_single_quote_star
11 |
12 |
13 | # Double Quotes
14 |
15 | double_quoted -> "\"" "%"? non_double_quote_star "%"? "\""
16 |
17 | # matches zero or more strings that are either not double quotes, or are two double quotes in a row
18 | non_double_quote_star -> #e
19 | non_double_quote_star -> /[^"]/ non_double_quote_star
20 |
21 | # TODO: uncomment?
22 | # non_double_quote_star -> /["]/ /["]/ non_double_quote_star, /["]/ /["]/ non_double_quote_star
23 |
24 |
25 | # Back ticks
26 |
27 | back_ticked -> "`" non_back_tick_star "`"
28 |
29 | # matches zero or more strings that are either not back ticks, or are two back ticks in a row
30 | non_back_tick_star -> #e
31 | non_back_tick_star -> /[^`]/ non_back_tick_star
32 |
33 | # TODO: uncomment?
34 | # non_back_tick_star -> /[`]/ /[`]/ non_back_tick_star, /[`]/ /[`]/ non_back_tick_star
35 |
36 |
37 | # Brackets
38 |
39 | bracketed -> "[" non_bracket_star "]"
40 |
41 | # matches zero or more strings that are not end brackets
42 | non_bracket_star -> #e
43 | non_bracket_star -> /[^\]]/ non_bracket_star
44 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/ltl/data/pick-syn.train.jsonl:
--------------------------------------------------------------------------------
1 | {"natural": "look for and pick up any cubes and put them in crate", "canonical": "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )", "formula": "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )"}
2 | {"natural": "look for and pick up any non red cubes and put them in crate", "canonical": "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )", "formula": "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )"}
3 | {"natural": "look for and pick up any non blue cubes and put them in crate", "canonical": "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )", "formula": "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )"}
4 | {"natural": "look for and pick up any non yellow cubes and put them in crate", "canonical": "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )", "formula": "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )"}
5 | {"natural": "look for and pick up any non green cubes and put them in crate", "canonical": "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )", "formula": "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )"}
6 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/ltl/kept/pick-syn.train.jsonl:
--------------------------------------------------------------------------------
1 | {"canonical": "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )", "formula": "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )", "natural": "Look for and pick up any cubes and put them in crate."}
2 | {"canonical": "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )", "formula": "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )", "natural": "Look for and pick up any non red cubes and put them in crate."}
3 | {"canonical": "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )", "formula": "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )", "natural": "Look for and pick up any non blue cubes and put them in crate."}
4 | {"canonical": "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )", "formula": "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )", "natural": "Look for and pick up any non yellow cubes and put them in crate."}
5 | {"canonical": "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )", "formula": "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )", "natural": "Look for and pick up any non green cubes and put them in crate."}
6 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/earley/context_sensitive.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | # Tools to help create context-sensitive grammars.
5 | # See tests/test_harbor/earley/context_sensitive_demos for some example uses.
6 |
7 | from abc import ABC, abstractmethod
8 | from dataclasses import dataclass
9 | from typing import Dict, Generic, Iterable, Set
10 |
11 | from semantic_parsing_with_constrained_lm.earley.grammar import DottedRule, Nonterm, RuleResult, Terminal
12 |
13 |
14 | class SelfExpandingNonterm(Generic[Terminal, RuleResult], Nonterm, ABC):
15 | @abstractmethod
16 | def get_expansions(self) -> Iterable[DottedRule[Terminal, RuleResult]]:
17 | """Get the set of expansions for this nonterminal.
18 |
19 | Child classes should contain additional fields which are used to
20 | determine the expansions created."""
21 | pass
22 |
23 |
24 | @dataclass
25 | class DynamicGrammar(Generic[Terminal, RuleResult]):
26 | """A Grammar which knows how to use SelfExpandingNonterm to create expansions."""
27 |
28 | root: Nonterm
29 | fixed_expansions: Dict[Nonterm, Set[DottedRule[Terminal, RuleResult]]]
30 |
31 | def get_expansions(
32 | self, nonterm: Nonterm
33 | ) -> Iterable[DottedRule[Terminal, RuleResult]]:
34 | result = self.fixed_expansions.get(nonterm)
35 | if result is not None:
36 | return result
37 |
38 | if isinstance(nonterm, SelfExpandingNonterm):
39 | return nonterm.get_expansions()
40 |
41 | return ()
42 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/paths.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import os
5 | from pathlib import Path
6 |
7 |
8 |
9 | DOMAINS_DIR = Path(__file__).resolve().parent / "domains"
10 |
11 | CALFLOW_EXAMPLES_DIR = DOMAINS_DIR / "calflow/data"
12 |
13 | CALFLOW_GRAMMAR_DIR = DOMAINS_DIR / "calflow/grammar"
14 |
15 |
16 | # Paths used in data preparation for BenchClamp
17 | RUN_ON_AML = "AMLT_EXPERIMENT_NAME" in os.environ
18 |
19 | CLAMP_PRETRAINED_MODEL_DIR = (
20 | Path("/mnt/default/huggingface_models/")
21 | if RUN_ON_AML
22 | else Path("huggingface_models/")
23 | )
24 |
25 | CLAMP_DATA_DIR = (
26 | Path("/mnt/default/clamp_data/") if RUN_ON_AML else Path("data")
27 | )
28 |
29 | OVERNIGHT_DATA_DIR = CLAMP_DATA_DIR / "overnight"
30 |
31 | LTL_DATA_DIR = CLAMP_DATA_DIR / "ltl"
32 |
33 | BENCH_CLAMP_DATA_DIR_ROOT = CLAMP_DATA_DIR / "benchclamp"
34 |
35 | BENCH_CLAMP_RAW_DATA_DIR = BENCH_CLAMP_DATA_DIR_ROOT / "raw"
36 |
37 | BENCH_CLAMP_PROCESSED_DATA_DIR = BENCH_CLAMP_DATA_DIR_ROOT / "processed"
38 |
39 | BENCH_CLAMP_GRAMMAR_DATA_DIR = BENCH_CLAMP_DATA_DIR_ROOT / "grammar"
40 |
41 |
42 | # Paths for users of BenchClamp. Kept as strings since Path does not work well with network paths.
43 | CLAMP_DATA_DIR_AZURE = "https://benchclamp.blob.core.windows.net/benchclamp"
44 |
45 | OVERNIGHT_DATA_DIR_AZURE = CLAMP_DATA_DIR_AZURE + "/overnight"
46 |
47 | BENCH_CLAMP_DATA_DIR_ROOT_AZURE = CLAMP_DATA_DIR_AZURE + "/benchclamp"
48 |
49 | BENCH_CLAMP_PROCESSED_DATA_DIR_AZURE = BENCH_CLAMP_DATA_DIR_ROOT_AZURE + "/processed"
50 |
51 | BENCH_CLAMP_GRAMMAR_DATA_DIR_AZURE = BENCH_CLAMP_DATA_DIR_ROOT_AZURE + "/grammar"
52 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/sql/sequence_creator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from typing_extensions import Literal
5 |
6 | from semantic_parsing_with_constrained_lm.datum import BenchClampDatum
7 | from semantic_parsing_with_constrained_lm.sequence_creator import SequenceCreator
8 | from semantic_parsing_with_constrained_lm.tokenization import GPT2ClampTokenizer
9 |
10 |
11 | class CoSqlUtterance(SequenceCreator):
12 | def __init__(
13 | self, use_db_val: bool, past_utterances: Literal["none", "one", "all"]
14 | ):
15 | self.use_db_val = use_db_val
16 | self.past_utterances = past_utterances
17 | self.gpt2_tokenizer = GPT2ClampTokenizer.from_pretrained("gpt2")
18 |
19 | def create_sequence(self, datum: BenchClampDatum) -> str:
20 | db_schema = (
21 | datum.db_schema_with_val if self.use_db_val else datum.db_schema_without_val
22 | )
23 | all_past_utterances = datum.utterance.split(" | ")
24 | current_utterance = all_past_utterances[-1]
25 | if self.past_utterances == "none" or len(all_past_utterances) == 1:
26 | past_utterances = ""
27 | elif self.past_utterances == "one":
28 | past_utterances = all_past_utterances[-2]
29 | else:
30 | past_utterances = " | ".join(all_past_utterances[:-1])
31 |
32 | sequence = " , ".join([past_utterances, db_schema, current_utterance]) # type: ignore
33 | sequence_token_ids = self.gpt2_tokenizer.encode(sequence)
34 | # start_index = max(0, len(sequence_token_ids) - 1000)
35 | start_index = 0
36 | return self.gpt2_tokenizer.decode(sequence_token_ids[start_index:])
37 |
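A hedged sketch of the sequence this class produces, using a made-up CoSQL-style datum; the schema string and utterances below are illustrative, not taken from the dataset.

```python
from semantic_parsing_with_constrained_lm.datum import BenchClampDatum
from semantic_parsing_with_constrained_lm.domains.sql.sequence_creator import CoSqlUtterance

# Illustrative values only; `utterance` holds the dialogue history joined by " | ".
datum = BenchClampDatum(
    dialogue_id="d0",
    turn_part_index=1,
    utterance="show all singers | how many are from France ?",
    plan="SELECT count(*) FROM singer WHERE country = 'France'",
    db_schema_without_val=" | concert_singer | singer : singer_id , name , country",
)

creator = CoSqlUtterance(use_db_val=False, past_utterances="all")
# Produces roughly "<past utterances> , <db schema> , <current utterance>".
print(creator.create_sequence(datum))
```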
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/datum.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from dataclasses import dataclass
5 | from typing import Optional, TypeVar
6 |
7 |
8 | @dataclass(frozen=True, eq=True)
9 | class Datum:
10 | dialogue_id: Optional[str]
11 | turn_part_index: Optional[int]
12 | agent_context: Optional[str]
13 | natural: str
14 |
15 |
16 | @dataclass(frozen=True, eq=True)
17 | class FullDatum(Datum):
18 | canonical: str
19 |
20 |
21 | # Not contravariant since it is produced in a DataRetriever.
22 | # Discussions at https://semanticmachines.slack.com/archives/CM88KH6EN/p1654553264411409
23 | FullDatumSub = TypeVar("FullDatumSub", bound=FullDatum)
24 | # Contravariant since it is ingested by either DataRetriever, DataFilter, or PromptBuilder, but never produced
25 | DatumSub = TypeVar("DatumSub", bound=Datum, contravariant=True)
26 |
27 |
28 | @dataclass(frozen=True, eq=True)
29 | class BenchClampDatum:
30 | """
31 | Class to hold all possible information for each instance in BenchCLAMP. This class is used to generate, read
32 | and write BenchCLAMP data files. We distill it to a FullDatum before using it for training or evaluation.
33 | Fields only used for CalFlow, TreeDST: last_agent_utterance, last_user_utterance, last_plan
34 | Fields only used for Spider and CoSQL: schema_name, db_schema_without_val, db_schema_with_val
35 | """
36 |
37 | dialogue_id: Optional[str]
38 | turn_part_index: Optional[int]
39 | utterance: str
40 | plan: str
41 | last_agent_utterance: Optional[str] = None
42 | last_user_utterance: Optional[str] = None
43 | last_plan: Optional[str] = None
44 | schema_name: Optional[str] = None
45 | db_schema_without_val: Optional[str] = None
46 | db_schema_with_val: Optional[str] = None
47 |
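For concreteness, a hypothetical example of distilling a `BenchClampDatum` down to the `FullDatum` fields used for training and evaluation; the utterance and plan strings below are made up.

```python
from semantic_parsing_with_constrained_lm.datum import BenchClampDatum, FullDatum

bc = BenchClampDatum(
    dialogue_id="d0",
    turn_part_index=0,
    utterance="create a meeting tomorrow at noon",  # made-up utterance
    plan="(Yield (CreateEvent ...))",               # made-up plan string
)
full = FullDatum(
    dialogue_id=bc.dialogue_id,
    turn_part_index=bc.turn_part_index,
    agent_context=None,
    natural=bc.utterance,
    canonical=bc.plan,
)
```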
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/scripts/calflow_fit_max_steps.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | """Computes an linear fit for the number of tokens in the input and output.
5 |
6 | This can be used to set a good parameters for the `max_steps` parameter in beam search.
7 |
8 | The script does 10-fold cross-validation to find a slope and intercept which:
9 | - minimizes the mean number of excess steps (i.e. predicted length - gold length)
10 | - only very rarely predicts a length which is smaller than the gold length
11 |
12 | Example invocation:
13 | python semantic_parsing_with_constrained_lm/scripts/calflow_fit_max_steps.py \
14 | --data-path semantic_parsing_with_constrained_lm/domains/calflow/data/train_300_stratified.jsonl \
15 | --tokenizer facebook/bart-large \
16 | --output-type canonicalUtterance # or "lispress"
17 | """
18 | import pathlib
19 | from typing import List, Tuple
20 |
21 | import typer
22 | from transformers import AutoTokenizer
23 |
24 | from semantic_parsing_with_constrained_lm.domains.calflow import (
25 | CalflowOutputLanguage,
26 | read_calflow_jsonl,
27 | )
28 | from semantic_parsing_with_constrained_lm.fit_max_steps import compute_and_print_fit
29 |
30 |
31 | def main(
32 | data_path: pathlib.Path = typer.Option(...),
33 | tokenizer: str = typer.Option(...),
34 | output_type: CalflowOutputLanguage = typer.Option(...),
35 | max_unreachable: int = typer.Option(1),
36 | ):
37 | t = AutoTokenizer.from_pretrained(tokenizer)
38 |
39 | pairs: List[Tuple[int, int]] = []
40 | for datum in read_calflow_jsonl(data_path, output_type):
41 | num_input_tokens = len(t.tokenize(datum.natural))
42 | if not datum.canonical:
43 | continue
44 | num_output_tokens = len(t.tokenize(datum.canonical)) + 1
45 |
46 | pairs.append((num_input_tokens, num_output_tokens))
47 |
48 | compute_and_print_fit(pairs, 10, max_unreachable)
49 |
50 |
51 | if __name__ == "__main__":
52 | typer.run(main)
53 |
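The script prints a fitted slope and intercept; below is a sketch of how such a fit would typically be turned into a per-example `max_steps` budget. The helper name and the numbers are made up, and the exact way the fit is consumed elsewhere in the repo may differ.

```python
import math

def max_steps_from_fit(num_input_tokens: int, slope: float, intercept: float) -> int:
    """Hypothetical helper: predict an output-length budget from the input length."""
    return max(1, int(math.ceil(slope * num_input_tokens + intercept)))

# With a fitted slope of 1.5 and intercept of 10 (illustrative numbers),
# a 40-token input gets a budget of ceil(1.5 * 40 + 10) = 70 decoding steps.
print(max_steps_from_fit(40, slope=1.5, intercept=10.0))
```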
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/util/keydefaultdict.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | # Adapted from https://stackoverflow.com/posts/2912455/revisions,
5 | # but with type annotations added.
6 | # pylint: disable=no-member,useless-super-delegation
7 | from typing import Any, Callable, DefaultDict, Iterable, Tuple, TypeVar
8 |
9 | K = TypeVar("K")
10 | V = TypeVar("V")
11 |
12 |
13 | class KeyDefaultDict(DefaultDict[K, V]):
14 | """
15 | A version of defaultdict whose factory function, which constructs a
16 | default value for a missing key, takes that key as an argument.
17 |
18 | >>> d: KeyDefaultDict[int, str] = KeyDefaultDict(lambda k: "/" + str(k) + "/", {0: "zero"})
19 | >>> d[3] = 'three'
20 | >>> d[0]
21 | 'zero'
22 | >>> d[3]
23 | 'three'
24 | >>> d[4]
25 | '/4/'
26 | >>> dict(d)
27 | {0: 'zero', 3: 'three', 4: '/4/'}
28 | """
29 |
30 | def __init__(self, default_factory: Callable[[K], V], *args: Any, **kwargs: Any):
31 | super().__init__(None, *args, **kwargs)
32 | # store the default_factory in an attribute of a different name, to avoid an inheritance type error
33 | self.default_key_factory = default_factory
34 |
35 | def __missing__(self, key: K) -> V:
36 | """
37 | Overrides the central method of `defaultdict` with one that calls
38 | `default_key_factory` on `key` instead of calling `default_factory`
39 | on 0 args.
40 | """
41 | if self.default_key_factory is None:
42 | raise KeyError(key)
43 | ret = self[key] = self.default_key_factory(key)
44 | return ret
45 |
46 | def __repr__(self) -> str:
47 | """Prints `default_key_factory` instead of `default_factory`."""
48 | return f"{self.__class__.__name__}({self.default_key_factory}, {dict.__repr__(self)})"
49 |
50 | # To avoid E1136 (unsubscriptable-object) pylint errors at call sites
51 | def __getitem__(self, item: K) -> V:
52 | return super().__getitem__(item)
53 |
54 | # To avoid E1136 (unsubscriptable-object) pylint errors at call sites
55 | def __setitem__(self, key: K, value: V) -> None:
56 | return super().__setitem__(key, value)
57 |
58 | def items(self) -> Iterable[Tuple[K, V]]:
59 | return super().items()
60 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/earley/specialization.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from dataclasses import dataclass
5 |
6 | from openfst_python import determinize, intersect
7 |
8 | from semantic_parsing_with_constrained_lm.earley.fsa import CompiledDFA
9 | from semantic_parsing_with_constrained_lm.earley.fsa_builders import compile_nfa, re_substring_utf8
10 | from semantic_parsing_with_constrained_lm.earley.grammar import DFADottedRule, DFAGrammar, Nonterm
11 |
12 |
13 | @dataclass
14 | class SubstringIntersectingGrammarSpecializer:
15 | """Creates grammars where the rule for string literals are modified to only allow substrings of a given string.
16 |
17 | The string is usually a user's utterance, where we are performing semantic parsing of that utterance
18 | and we know that any strings in the output program must be substrings of the utterance."""
19 |
20 | base_grammar: DFAGrammar
21 | nonterm_to_intersect: Nonterm
22 |
23 | def __post_init__(self):
24 | assert self.nonterm_to_intersect in self.base_grammar.expansions
25 |
26 | def specialize(self, s: str) -> DFAGrammar:
27 | existing_rule = self.base_grammar.expansions[self.nonterm_to_intersect]
28 | substring_nfa = compile_nfa(re_substring_utf8(s))
29 | assert substring_nfa.edge_indexer.num_indexed() == 0
30 |
31 | # Intersect the FSA which accepts any valid string literal with the FSA
32 | # for the substrings of `s`. This ensures that we don't take any
33 | # substrings of `s` which are incorrectly escaped, e.g. taking " rather
34 | # than \".
35 | # TODO: Have some checks that `s` has already been escaped correctly,
36 | # since otherwise we will end up excluding many substrings which should
37 | # have been included.
38 | new_dfa = CompiledDFA(
39 | determinize(
40 | intersect(existing_rule.dfa.fst, substring_nfa.fst).rmepsilon()
41 | ).minimize(),
42 | existing_rule.dfa.edge_indexer,
43 | )
44 | new_rule = DFADottedRule(existing_rule.lhs, new_dfa, new_dfa.start_id)
45 |
46 | return DFAGrammar(
47 | self.base_grammar.root,
48 | {**self.base_grammar.expansions, self.nonterm_to_intersect: new_rule},
49 | )
50 |
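A hedged usage sketch: `base_dfa_grammar` stands in for a `DFAGrammar` compiled elsewhere, and `"string_literal"` for the nonterminal whose DFA accepts arbitrary (escaped) string literals; neither name is defined in this file.

```python
from semantic_parsing_with_constrained_lm.earley.grammar import Nonterm
from semantic_parsing_with_constrained_lm.earley.specialization import (
    SubstringIntersectingGrammarSpecializer,
)

def specialize_for_utterance(base_dfa_grammar, utterance: str):
    # base_dfa_grammar: a DFAGrammar built elsewhere (hypothetical argument).
    specializer = SubstringIntersectingGrammarSpecializer(
        base_grammar=base_dfa_grammar,
        nonterm_to_intersect=Nonterm("string_literal"),  # hypothetical nonterminal name
    )
    # The returned grammar only accepts string literals that are substrings of
    # the (already correctly escaped) utterance.
    return specializer.specialize(utterance)
```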
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/decoding/partial_parse.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from abc import ABC, abstractmethod
5 | from typing import Optional, Tuple
6 |
7 | import torch
8 |
9 | from semantic_parsing_with_constrained_lm.tokenization import ClampTokenizer
10 |
11 |
12 | class PartialParse(ABC):
13 | @abstractmethod
14 | def allowed_next(
15 | self, ordered_ids: Optional[torch.Tensor] = None, top_k: Optional[int] = None
16 | ) -> Tuple[Optional[torch.Tensor], bool]:
17 | """Returns possible ways to extend the current prefix.
18 |
19 | The Tensor is of type long and 1-dimensional, with no duplicate values,
20 | containing the IDs of the tokens that we could append.
21 | If it is None, then any token is allowed.
22 | The bool indicates whether we are allowed to terminate here.
23 |
24 | If ordered_ids and top_k are not None, this may optionally return only
25 | the first `top_k` token IDs from ordered_ids which comport with the
26 | grammar, instead of all such token IDs.
27 | """
28 | pass
29 |
30 | @abstractmethod
31 | def append(self, token: int) -> "PartialParse":
32 | """Return a new PartialParse created by appending this token."""
33 | pass
34 |
35 |
36 | class NullPartialParse(PartialParse):
37 | """PartialParse which admits any sequence."""
38 |
39 | def allowed_next(
40 | self, ordered_ids: Optional[torch.Tensor] = None, top_k: Optional[int] = None
41 | ) -> Tuple[Optional[torch.Tensor], bool]:
42 | return None, True
43 |
44 | def append(self, token: int) -> "PartialParse":
45 | return self
46 |
47 |
48 | class StartsWithSpacePartialParse(PartialParse):
49 | def __init__(self, tokenizer: ClampTokenizer):
50 | valid_tokens = []
51 | for utf8_token, token_id in tokenizer.utf8_token_to_id_map.items():
52 | if utf8_token[0] == 32:
53 | valid_tokens.append(token_id)
54 | self.valid_tokens = torch.tensor(valid_tokens)
55 |
56 | def allowed_next(
57 | self, ordered_ids: Optional[torch.Tensor] = None, top_k: Optional[int] = None
58 | ) -> Tuple[Optional[torch.Tensor], bool]:
59 | return self.valid_tokens, False
60 |
61 | def append(self, token: int) -> "PartialParse":
62 | return NullPartialParse()
63 |
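A sketch of how a decoder might consult a `PartialParse` at every step: mask the vocabulary using `allowed_next()` and advance with `append()`. The `logits_fn` language-model call and `eos_id` are stand-ins, and the real beam search in this repo is more involved.

```python
from typing import Callable, List

import torch

from semantic_parsing_with_constrained_lm.decoding.partial_parse import PartialParse

def greedy_constrained_decode(
    partial_parse: PartialParse,
    logits_fn: Callable[[List[int]], torch.Tensor],  # stand-in for an LM forward pass
    eos_id: int,
    max_steps: int = 64,
) -> List[int]:
    tokens: List[int] = []
    for _ in range(max_steps):
        logits = logits_fn(tokens)
        allowed, can_end = partial_parse.allowed_next()
        if allowed is not None:
            keep = allowed
            if can_end:
                keep = torch.cat([keep, torch.tensor([eos_id])])
            # Disallow every token the grammar does not permit here.
            mask = torch.full_like(logits, float("-inf"))
            mask[keep] = 0.0
            logits = logits + mask
        next_id = int(torch.argmax(logits).item())
        if can_end and next_id == eos_id:
            break
        tokens.append(next_id)
        partial_parse = partial_parse.append(next_id)
    return tokens
```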
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/scfg/parser/macro.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from dataclasses import dataclass
5 | from typing import Dict, List, Optional, Tuple, cast
6 |
7 | from semantic_parsing_with_constrained_lm.scfg.parser.token import MacroToken, NonterminalToken, SCFGToken
8 | from semantic_parsing_with_constrained_lm.scfg.parser.types import Expansion
9 |
10 |
11 | @dataclass(frozen=True)
12 | class Macro:
13 | name: str
14 | args: Tuple[str, ...]
15 | expansion: Expansion
16 |
17 | def apply_expression(
18 | self, macro_rules: Dict[str, "Macro"], args_to_bind: List[Expansion]
19 | ) -> Expansion:
20 | """
21 | Apply this macro to the given arguments. Because a macro cannot use any variables outside of its own definition,
22 | we do not need to pass in an environment.
23 | """
24 | result: List[SCFGToken] = []
25 | for token in self.expansion:
26 | result += eval_expression(
27 | macro_rules, token, dict(zip(self.args, args_to_bind))
28 | )
29 | return tuple(result)
30 |
31 |
32 | def eval_expression(
33 | macros: Dict[str, Macro],
34 | token: SCFGToken,
35 | env: Optional[Dict[str, Expansion]] = None,
36 | ) -> Expansion:
37 | """
38 | Given a token, eval it.
39 |
40 | e.g. for macros defined as
41 | f(a) 2> "(" g(a) ")"
42 | g(b) 2> "(" b ")"
43 | h(c) 2> "(" c ")"
44 |
45 | given a macro call f(z),
46 | return "(" "(" z ")" ")"
47 |
48 | or a macro call g(h(z)),
49 | return "(" "(" z ")" ")"
50 |
51 | """
52 | env = env if env else {}
53 | if isinstance(token, MacroToken):
54 | args_to_bind = [eval_expression(macros, arg, env) for arg in token.args]
55 | macro = macros[token.name]
56 | return macro.apply_expression(macros, args_to_bind)
57 | elif isinstance(token, NonterminalToken) and token.value in env:
58 | return env[token.value]
59 |
60 | return (token,)
61 |
62 |
63 | def expand_macros(macros: Dict[str, Macro], expansion: Expansion) -> Expansion:
64 | """
65 | Return a new rule where the plan_rhs has been rewritten with all macros expanded.
66 | """
67 | new_tokens: List[SCFGToken] = []
68 | for token in expansion:
69 | if isinstance(token, MacroToken):
70 | new_tokens += eval_expression(macros, token)
71 | else:
72 | new_tokens.append(cast(SCFGToken, token))
73 |
74 | return tuple(new_tokens)
75 |
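The docstring example above, written out with the token classes from `scfg/parser/token.py`; this is an illustrative sketch, not a test from the repo.

```python
from semantic_parsing_with_constrained_lm.scfg.parser.macro import Macro, expand_macros
from semantic_parsing_with_constrained_lm.scfg.parser.token import (
    MacroToken,
    NonterminalToken,
    TerminalToken,
)

lparen = TerminalToken('"("', optional=False)
rparen = TerminalToken('")"', optional=False)

macros = {
    # f(a) -> "(" g(a) ")"     g(b) -> "(" b ")"
    "f": Macro(
        name="f",
        args=("a",),
        expansion=(lparen, MacroToken("g", (NonterminalToken("a", False),)), rparen),
    ),
    "g": Macro(
        name="g",
        args=("b",),
        expansion=(lparen, NonterminalToken("b", False), rparen),
    ),
}

# Expanding the macro call f(z) yields "(" "(" z ")" ")".
expanded = expand_macros(macros, (MacroToken("f", (NonterminalToken("z", False),)),))
print([t.value for t in expanded])
```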
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/overnight/create_benchclamp_data.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from semantic_parsing_with_constrained_lm.datum import BenchClampDatum
5 | from semantic_parsing_with_constrained_lm.domains.benchclamp_data_setup import (
6 | OVERNIGHT_DOMAINS,
7 | BenchClampDataset,
8 | )
9 | from semantic_parsing_with_constrained_lm.domains.create_benchclamp_splits import (
10 | create_benchclamp_splits,
11 | )
12 | from semantic_parsing_with_constrained_lm.domains.overnight import OutputType, OvernightDataPieces
13 | from semantic_parsing_with_constrained_lm.paths import (
14 | BENCH_CLAMP_PROCESSED_DATA_DIR,
15 | OVERNIGHT_DATA_DIR,
16 | )
17 |
18 |
19 | def main():
20 | for domain in OVERNIGHT_DOMAINS:
21 | overnight_pieces = OvernightDataPieces.from_dir(
22 | OVERNIGHT_DATA_DIR,
23 | is_dev=True,
24 | domain=domain,
25 | output_type=OutputType.MeaningRepresentation,
26 | simplify_logical_forms=True,
27 | )
28 | train_data = overnight_pieces.train_data
29 | dev_data = overnight_pieces.test_data
30 | overnight_pieces = OvernightDataPieces.from_dir(
31 | OVERNIGHT_DATA_DIR,
32 | is_dev=False,
33 | domain=domain,
34 | output_type=OutputType.MeaningRepresentation,
35 | simplify_logical_forms=True,
36 | )
37 | test_data = overnight_pieces.test_data
38 |
39 | train_benchclamp_data = []
40 | dev_benchclamp_data = []
41 | test_benchclamp_data = []
42 | for data, benchclamp_data in [
43 | (train_data, train_benchclamp_data),
44 | (dev_data, dev_benchclamp_data),
45 | (test_data, test_benchclamp_data),
46 | ]:
47 | for datum in data:
48 | benchclamp_data.append(
49 | BenchClampDatum(
50 | dialogue_id=datum.dialogue_id,
51 | turn_part_index=datum.turn_part_index,
52 | utterance=datum.natural,
53 | plan=datum.canonical,
54 | )
55 | )
56 |
57 | create_benchclamp_splits(
58 | train_benchclamp_data,
59 | dev_benchclamp_data,
60 | test_benchclamp_data,
61 | BENCH_CLAMP_PROCESSED_DATA_DIR / BenchClampDataset.Overnight.value / domain,
62 | )
63 |
64 |
65 | if __name__ == "__main__":
66 | main()
67 |
--------------------------------------------------------------------------------
/docs/_layouts/default.html:
--------------------------------------------------------------------------------
(The HTML markup of this Jekyll layout did not survive in this listing; what remains is the `{% seo %}` tag, the `{% include head-custom.html %}` include, the page title "Data-efficient LTL Learning", a "Skip to the content." link, and the `{{ content }}` body placeholder.)
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/earley/recognize.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | """
5 | Methods that illustrate the simplest ways to use EarleyChart:
6 | as a recognizer of a token string, or as a generator of grammatical sentences.
7 | """
8 |
9 | from typing import Iterable, Iterator, List, Sequence, cast
10 |
11 | from semantic_parsing_with_constrained_lm.earley.earley import EarleyLRChart, PackedForest
12 | from semantic_parsing_with_constrained_lm.earley.grammar import Grammar, RuleResult, Terminal
13 | from semantic_parsing_with_constrained_lm.earley.input import SequencePosition, SigmaStarTriePosition
14 |
15 |
16 | def parse(
17 | sentence: Iterable[Terminal], grammar: Grammar[Terminal, RuleResult]
18 | ) -> PackedForest[Terminal]:
19 | start_pos = SequencePosition(list(sentence))
20 | chart = EarleyLRChart(grammar=grammar, start_pos=start_pos, use_backpointers=True)
21 | return chart.parse()
22 |
23 |
24 | def is_grammatical(
25 | tokens: Sequence[Terminal], grammar: Grammar[Terminal, RuleResult]
26 | ) -> bool:
27 | """
28 | Tests whether the given input `tokens` are grammatical under `grammar`.
29 | """
30 | start_pos = SequencePosition(tokens)
31 | chart = EarleyLRChart(grammar, start_pos, use_backpointers=False)
32 | for _ in chart.accepting_positions():
33 | return True # we're grammatical if the iterator is non-empty
34 | return False
35 |
36 |
37 | def top_level_rule_results(
38 | tokens: Sequence[Terminal], grammar: Grammar[Terminal, RuleResult]
39 | ) -> Iterable[RuleResult]:
40 | """
41 | Yields the RuleResults produced by the DottedRules from the `start` nonterminal.
42 | """
43 | start_pos = SequencePosition(tokens)
44 | chart = EarleyLRChart(grammar, start_pos, use_backpointers=False)
45 | for end_pos in chart.accepting_positions():
46 | for item, _ in chart.completed_items(grammar.root, start_pos, end_pos):
47 | rule_result = item.dotted_rule.is_final()
48 | assert rule_result is not None
49 | yield rule_result
50 |
51 |
52 | def enumerate_sentences(
53 | grammar: Grammar[Terminal, RuleResult]
54 | ) -> Iterator[List[Terminal]]:
55 | """
56 | Yields grammatical sentences in length order (may not terminate).
57 | """
58 | # root of a Σ* trie with string-labeled edges (as the grammar uses Terminal=str)
59 | start_pos = SigmaStarTriePosition[Terminal]()
60 | chart = EarleyLRChart(grammar, start_pos, use_backpointers=False)
61 | for pos in chart.accepting_positions(): # enumerate nodes in the Σ* trie
62 | # necessary because current typing isn't strong enough
63 | _pos = cast(SigmaStarTriePosition[Terminal], pos)
64 | yield _pos.prefix()
65 |
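A hedged usage sketch: `grammar` stands in for a `Grammar[str, Any]` built elsewhere (for character-level grammars, `Terminal` is `str`), so the snippet shows only the calling conventions of the two helpers.

```python
import itertools

from semantic_parsing_with_constrained_lm.earley.recognize import (
    enumerate_sentences,
    is_grammatical,
)

def demo(grammar) -> None:
    # Recognition: is the character sequence "(ab)" in the language?
    print(is_grammatical(list("(ab)"), grammar))
    # Generation: print the five shortest sentences the grammar admits.
    for sentence in itertools.islice(enumerate_sentences(grammar), 5):
        print("".join(sentence))
```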
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/scfg/earley_grammar.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | # This file is deprecated.
5 | # TODO: Replace with Grammar[np.uint8] and UInt8EarleyPartialParse.
6 | from typing import Any, Dict, List, Tuple, Union
7 |
8 | from semantic_parsing_with_constrained_lm.earley.grammar import FixedGrammar, LinearDottedRule, Nonterm, Symbol
9 | from semantic_parsing_with_constrained_lm.scfg.char_grammar import START
10 | from semantic_parsing_with_constrained_lm.scfg.parser.token import (
11 | EmptyToken,
12 | NonterminalToken,
13 | RegexToken,
14 | TerminalToken,
15 | )
16 | from semantic_parsing_with_constrained_lm.scfg.parser.types import Alias, Expansion, Nonterminal
17 | from semantic_parsing_with_constrained_lm.scfg.read_grammar import GrammarRules
18 |
19 | CFTerminal = Union[str, RegexToken]
20 |
21 |
22 | class EarleyCFGrammar(FixedGrammar[CFTerminal, Any]):
23 | """A grammar for one of the two sides of an SCFG.
24 |
25 | Similar to CharGrammar, but it doesn't split up all terminals into single
26 | characters.
27 | """
28 |
29 | @staticmethod
30 | def from_preprocessed_rules(rules: GrammarRules):
31 | aliased_grammar = {
32 | lhs: [(rhs, None) for rhs in rhss] for lhs, rhss in rules.items()
33 | }
34 | return EarleyCFGrammar.from_aliased_grammar(aliased_grammar) # type: ignore
35 |
36 | @staticmethod
37 | def from_aliased_grammar(
38 | grammar: Dict[Nonterminal, List[Tuple[Expansion, Alias]]]
39 | ) -> "EarleyCFGrammar":
40 | def convert(expansion: Expansion) -> Tuple[Symbol[CFTerminal], ...]:
41 | result = []
42 | for token in expansion:
43 | if isinstance(token, TerminalToken):
44 | result.append(token.render())
45 | elif isinstance(token, RegexToken):
46 | result.append(token)
47 | elif isinstance(token, NonterminalToken):
48 | result.append(Nonterm(token.value))
49 | elif isinstance(token, EmptyToken):
50 | pass
51 | else:
52 | raise ValueError(token)
53 | return tuple(result)
54 |
55 | return EarleyCFGrammar(
56 | root=START,
57 | expansions={
58 | Nonterm(origin): {
59 | # https://github.com/microsoft/pyright/issues/2962
60 | LinearDottedRule[CFTerminal].from_rule(
61 | Nonterm(origin), rhs=convert(rhs), alias=alias
62 | )
63 | for rhs, alias in rhss
64 | }
65 | for origin, rhss in sorted(grammar.items())
66 | },
67 | )
68 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/ltl/create_benchclamp_data.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from semantic_parsing_with_constrained_lm.datum import BenchClampDatum
5 | from semantic_parsing_with_constrained_lm.domains.benchclamp_data_setup import (
6 | LTL_DOMAINS,
7 | BenchClampDataset,
8 | )
9 | from semantic_parsing_with_constrained_lm.domains.create_benchclamp_splits import create_benchclamp_splits
10 | from semantic_parsing_with_constrained_lm.domains.ltl import LTLOutputType, LTLDataPieces
11 | from semantic_parsing_with_constrained_lm.paths import (
12 | BENCH_CLAMP_PROCESSED_DATA_DIR,
13 | LTL_DATA_DIR,
14 | )
15 |
16 |
17 | def main():
18 | for domain in LTL_DOMAINS:
19 | # domains in this case are "train-evaluation splits"
20 |
21 | """
22 | Naming convention for ltl data files:
23 | - ltl_data.json: contains the canonical data
24 | - the denotation looks unusual; it is more like a program result
25 | - used for building the trie constrained decoder
26 | - either ```canonical``` or ```formula```
27 | - ltl.dev.jsonl
28 | - ltl.test.jsonl
29 | - ltl.train_with_dev.jsonl
30 | - ltl.train_without_dev.jsonl
31 | """
32 | ltl_pieces = LTLDataPieces.from_dir(
33 | LTL_DATA_DIR,
34 | is_dev=True,
35 | domain=domain,
36 | output_type=LTLOutputType.MeaningRepresentation,
37 | simplify_logical_forms=True,
38 | )
39 | train_data = ltl_pieces.train_data
40 | dev_data = ltl_pieces.test_data
41 |
42 | ltl_pieces = LTLDataPieces.from_dir(
43 | LTL_DATA_DIR,
44 | is_dev=False,
45 | domain=domain,
46 | output_type=LTLOutputType.MeaningRepresentation,
47 | simplify_logical_forms=True,
48 | )
49 | test_data = ltl_pieces.test_data
50 |
51 | train_benchclamp_data = []
52 | dev_benchclamp_data = []
53 | test_benchclamp_data = []
54 |
55 | for data, benchclamp_data in [
56 | (train_data, train_benchclamp_data),
57 | (dev_data, dev_benchclamp_data),
58 | (test_data, test_benchclamp_data),
59 | ]:
60 | for datum in data:
61 | benchclamp_data.append(
62 | BenchClampDatum(
63 | dialogue_id=datum.dialogue_id,
64 | turn_part_index=datum.turn_part_index,
65 | utterance=datum.natural,
66 | plan=datum.canonical,
67 | )
68 | )
69 |
70 | create_benchclamp_splits(
71 | train_benchclamp_data,
72 | dev_benchclamp_data,
73 | test_benchclamp_data,
74 | BENCH_CLAMP_PROCESSED_DATA_DIR / BenchClampDataset.LTL.value / domain,
75 | )
76 |
77 |
78 | if __name__ == "__main__":
79 | main()
80 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/calflow_eval_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | """ Functions to help run Bart cold monster model for Calflow. """
5 | import asyncio
6 | from pathlib import Path
7 | from typing import List
8 |
9 | import tqdm
10 |
11 | from semantic_parsing_with_constrained_lm.scfg.read_grammar import PreprocessedGrammar
12 | from semantic_parsing_with_constrained_lm.configs.lib.calflow import make_semantic_parser_for_calflow
13 | from semantic_parsing_with_constrained_lm.datum import Datum, FullDatum
14 | from semantic_parsing_with_constrained_lm.domains.calflow import CalflowOutputLanguage
15 | from semantic_parsing_with_constrained_lm.lm import ClientType
16 | from semantic_parsing_with_constrained_lm.lm_bart import Seq2SeqBart
17 | from semantic_parsing_with_constrained_lm.model import BeamSearchSemanticParser, ModelResult
18 | from semantic_parsing_with_constrained_lm.train_model_setup import BartModelConfig
19 |
20 |
21 | def instantiate_bart_eval_model(
22 | model_loc: str, grammar_dir: str
23 | ) -> BeamSearchSemanticParser:
24 | preprocessed_grammar = PreprocessedGrammar.from_folder(grammar_dir)
25 | bart_model_config = BartModelConfig(model_id="Bart", model_loc=Path(model_loc))
26 | model, tokenizer, _ = bart_model_config.setup_model()
27 | lm = Seq2SeqBart(
28 | pretrained_model_dir=model_loc, model=model, clamp_tokenizer=tokenizer
29 | )
30 | beam_size = 2
31 | return make_semantic_parser_for_calflow(
32 | [],
33 | lm,
34 | use_gpt3=False,
35 | beam_size=beam_size,
36 | output_type=CalflowOutputLanguage.Canonical,
37 | client_type=ClientType.BART,
38 | preprocessed_grammar=preprocessed_grammar,
39 | constrained=True,
40 | num_examples_per_prompt=0,
41 | )
42 |
43 |
44 | def predict(model: BeamSearchSemanticParser, user_utterance: str) -> List[str]:
45 | results: List[ModelResult] = asyncio.run(
46 | model.predict(
47 | Datum(
48 | natural=user_utterance,
49 | dialogue_id=None,
50 | turn_part_index=None,
51 | agent_context=None,
52 | )
53 | )
54 | )
55 | if len(results) == 0:
56 | return []
57 | return [result.text for result in results]
58 |
59 |
60 | def evaluate(eval_examples: List[FullDatum], model: BeamSearchSemanticParser) -> None:
61 | total = len(eval_examples)
62 | correct = 0
63 | for example in tqdm.tqdm(eval_examples):
64 | predicted = predict(model, example.natural)[0]
65 | if (
66 | example.canonical is not None
67 | and example.canonical.strip() == predicted.strip()
68 | ):
69 | correct += 1
70 | else:
71 | print(f"Utterance: {example.natural}")
72 | print(f"Canonical: {example.canonical}")
73 | print(f"Predicted: {predicted}")
74 |
75 | acc = correct * 1.0 / total
76 | for _ in range(100):
77 | print(f"Accuracy = {correct} / {total} = {acc}")
78 |
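A hedged end-to-end sketch using the helpers above: the checkpoint path is hypothetical, the grammar directory assumes the code is run from the `run/` directory, and loading the model of course requires a real fine-tuned BART checkpoint on disk.

```python
from semantic_parsing_with_constrained_lm.domains.calflow_eval_utils import (
    instantiate_bart_eval_model,
    predict,
)

model = instantiate_bart_eval_model(
    model_loc="trained_models/calflow_bart",  # hypothetical fine-tuned BART checkpoint
    grammar_dir="semantic_parsing_with_constrained_lm/domains/calflow/grammar",
)
print(predict(model, "Create a meeting with Alice tomorrow at 3 pm"))
```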
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/finetune/check_for_unks.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import importlib
5 | from typing import Callable, Dict, List, Optional, Set
6 |
7 | import typer
8 |
9 | from semantic_parsing_with_constrained_lm.lm import Surround
10 | from semantic_parsing_with_constrained_lm.run_exp import filter_exp_dict
11 | from semantic_parsing_with_constrained_lm.tokenization import ClampTokenizer
12 | from semantic_parsing_with_constrained_lm.finetune.lm_finetune import TrainExperiment
13 |
14 |
15 | def find_and_record_unk_tokens(
16 | tokenizer: ClampTokenizer,
17 | surround: Surround,
18 | ids: List[int],
19 | orig: str,
20 | unk_tokens: Set[bytes],
21 | ) -> None:
22 | if surround.starts_with_space:
23 | orig = " " + orig
24 | tokens = tokenizer.tokenize(orig)
25 | ids = ids[len(surround.bos) : len(ids) - len(surround.eos)]
26 | assert len(tokens) == len(ids)
27 |
28 | for token_bytes, token_id in zip(tokens, ids):
29 | if token_id == tokenizer.unk_token_id:
30 | unk_tokens.add(token_bytes)
31 |
32 |
33 | def main(
34 | config_name: str = typer.Option(...),
35 | exp_names: Optional[List[str]] = typer.Option(None),
36 | exp_name_pattern: Optional[List[str]] = typer.Option(None),
37 | ):
38 | config_mod = importlib.import_module(config_name)
39 | exps: Dict[str, Callable[[], TrainExperiment]] = config_mod.build_config() # type: ignore
40 | filtered_exp_dict = filter_exp_dict(exps, exp_names, exp_name_pattern)
41 |
42 | for exp_name in filtered_exp_dict:
43 | exp = filtered_exp_dict[exp_name]()
44 | unk_id = exp.tokenizer.unk_token_id
45 | assert unk_id is not None
46 |
47 | train_dataset = exp.make_clamp_dataset(exp.train_data)
48 | num_unks_in_inputs = 0
49 | num_unks_in_labels = 0
50 |
51 | unk_tokens: Set[bytes] = set()
52 |
53 | for i, _ in enumerate(train_dataset): # type: ignore
54 | datum = train_dataset[i]
55 | input_ids = datum["input_ids"]
56 | labels = datum["labels"]
57 |
58 | if any(t == unk_id for t in input_ids):
59 | num_unks_in_inputs += 1
60 | find_and_record_unk_tokens(
61 | exp.tokenizer,
62 | exp.seq2seq_settings.input_surround,
63 | input_ids,
64 | exp.train_data[i].natural,
65 | unk_tokens,
66 | )
67 |
68 | if any(t == unk_id for t in labels):
69 | num_unks_in_labels += 1
70 | find_and_record_unk_tokens(
71 | exp.tokenizer,
72 | exp.seq2seq_settings.output_surround,
73 | labels,
74 | exp.train_data[i].canonical,
75 | unk_tokens,
76 | )
77 |
78 | print(
79 | f"{exp_name}: {num_unks_in_inputs}/{len(train_dataset)} unks in inputs, "
80 | f"{num_unks_in_labels}/{len(train_dataset)} unks in labels."
81 | )
82 | print(f"unk_tokens = {unk_tokens}")
83 |
84 |
85 | if __name__ == "__main__":
86 | typer.run(main)
87 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/calflow/grammar/fluenter.scfg:
--------------------------------------------------------------------------------
1 | # non-empty results
2 | boolean -> "does there exist an " constraint_event_ empty , " (> (size (:results (FindEventWrapperWithDefaults :constraint" constraint_event_ "))) #(Number 0))" empty
3 |
4 | # singleton List[Path]
5 | list_path_ -> path , " (append #(List[Path] [])" path ")"
6 | list_path_ -> path , " (append (List.Nil)" path ")"
7 |
8 | # singleton List[Recipient]
9 | list_recipient_ -> recipient , " (append #(List[Recipient] [])" recipient ")"
10 | list_recipient_ -> recipient , " (append (List.Nil)" recipient ")"
11 |
12 | unit -> "Yes, do the " number " one" , " (Yield :output (Execute :intension (ChooseCreateEvent :index" number " :intension (refer (ActionIntensionConstraint)))))"
13 |
14 | constraint_list_attendee__ -> " with attendees" , " (negate (AlwaysFalseConstraint[List[Attendee]]))"
15 |
16 | constraint_event__constraint_event__args -> " with location unspecified" constraint_event__constraint_event__args , " :location (AlwaysFalseConstraint[LocationKeyphrase])" constraint_event__constraint_event__args
17 |
18 | # TODO: ambiguous with `(getIntraSalient (AlwaysTrueConstraint))`, does that matter?
19 | dynamic -> "that thing" , "(:item (getIntraSalient (AlwaysTrueConstraint[Dynamic])))"
20 |
21 | # comparisons that aren't "before"/"after"
22 | constraint_duration_ -> " longer than " duration , " (?>" duration ")"
23 | constraint_duration_ -> " no shorter than " duration , " (?>=" duration ")"
24 | constraint_duration_ -> " shorter than " duration , " (?<" duration ")"
25 | constraint_duration_ -> " no longer than " duration , " (?<=" duration ")"
26 |
27 | # update event without restating type
28 | constraint_event_wo_type -> constraint_event__constraint_event__args , " (Constraint[Event]" constraint_event__constraint_event__args ")"
29 | updateeventresponse -> "update " event " so it is" constraint_event_wo_type , " (UpdateWrapper" " :findArg" event " :updateArg" constraint_event_wo_type ")"
30 | updatecommitevent -> "update " eventid " so it is" constraint_event_wo_type , " (UpdatePreflightEventWrapper" " :id" eventid " :update" constraint_event_wo_type ")"
31 |
32 | # clobber event without restating type
33 | dynamic -> "Change my request so the " constraint_calflowintension_constraint_event___ " is" constraint_event_wo_type , " (ClobberWrapper" " :oldLocation" constraint_calflowintension_constraint_event___ " :new" constraint_event_wo_type ")"
34 |
35 | # update duration
36 | duration -> duration " longer" , " (addDurations (:duration (getIntraSalient (AlwaysTrueConstraint[Event])))" duration ")"
37 |
38 | # bare date constraint
39 | constraint_date_ -> "date" , " (Constraint[Date])"
40 |
41 | # nonEmptyBase
42 | non_empty_base_ -> constraint_event_ , " :nonEmptyBase" constraint_event_
43 | constraint_event__constraint_event__args_a -> constraint_event__constraint_event__args , constraint_event__constraint_event__args
44 | constraint_event__constraint_event__args_b -> constraint_event__constraint_event__args , constraint_event__constraint_event__args
45 | constraint_event_ -> non_empty_base_ " but" constraint_event__constraint_event__args_a constraint_event__constraint_event__args_b , " (Constraint[Event]" constraint_event__constraint_event__args_a non_empty_base_ constraint_event__constraint_event__args_b ")"
46 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/scfg/parser/token.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import json
5 | from abc import ABC, abstractmethod
6 | from dataclasses import dataclass
7 | from typing import List, Tuple
8 |
9 | import regex
10 | from cached_property import cached_property
11 |
12 | ASCII_CHARS: List[str] = [chr(i) for i in range(32, 127)]
13 |
14 |
15 | class SCFGToken(ABC):
16 | def render(self) -> str:
17 | """
18 | How to render this token when generating it. In most cases, you can just render the underlying value.
19 | Sometimes you want to modify it, like in TerminalToken.
20 | """
21 | return self.value
22 |
23 | @property
24 | @abstractmethod
25 | def value(self) -> str:
26 | """The underlying value of the token."""
27 | pass
28 |
29 | @property
30 | def lark_value(self) -> str:
31 | return self.value
32 |
33 |
34 | class OptionableSCFGToken(SCFGToken, ABC):
35 | optional: bool
36 |
37 |
38 | @dataclass(frozen=True)
39 | class NonterminalToken(OptionableSCFGToken):
40 | underlying: str
41 | optional: bool
42 |
43 | @property
44 | def value(self):
45 | return self.underlying
46 |
47 | def is_regex(self):
48 | return self.underlying[0] == "/"
49 |
50 |
51 | @dataclass(frozen=True)
52 | class TerminalToken(OptionableSCFGToken):
53 | underlying: str
54 | optional: bool
55 |
56 | def render(self):
57 | """
58 | Remove the outermost quotes and unescape the rest of the quotes.
59 | """
60 | return json.loads(self.underlying)
61 |
62 | @property
63 | def value(self):
64 | return self.underlying
65 |
66 | @property
67 | def lark_value(self):
68 | return self.value + "i"
69 |
70 |
71 | @dataclass(frozen=True)
72 | class MacroToken(SCFGToken):
73 | name: str
74 | args: Tuple[SCFGToken, ...]
75 |
76 | @property
77 | def value(self):
78 | return f"{self.name}({','.join([a.value for a in self.args])})"
79 |
80 |
81 | @dataclass(frozen=True)
82 | class EmptyToken(SCFGToken):
83 | @property
84 | def value(self):
85 | return ""
86 |
87 |
88 | @dataclass(frozen=True)
89 | class RegexToken(NonterminalToken):
90 | prefix: str
91 |
92 | def render_matching_value(self, value: str) -> str:
93 | return self.prefix + value
94 |
95 | @property
96 | def value(self) -> str:
97 | return self.underlying
98 |
99 | @property
100 | def lark_value(self) -> str:
101 | if (
102 | self.prefix
103 | ): # We need to have this condition because lark gets mad if you give it an empty token.
104 | return f'"{self.prefix}" ' + self.underlying
105 | else:
106 | return self.underlying
107 |
108 | @cached_property
109 | def compiled(self) -> "regex.Pattern[str]":
110 | assert self.underlying.startswith("/") and self.underlying.endswith("/")
111 | return regex.compile(self.underlying[1:-1])
112 |
113 | @cached_property
114 | def ascii_chars(self) -> List[str]:
115 | return [c for c in ASCII_CHARS if self.compiled.match(c)]
116 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/util/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import contextlib
5 | import io
6 | import pathlib
7 | import sys
8 | import traceback
9 | import warnings
10 | from dataclasses import dataclass
11 | from io import SEEK_SET, TextIOBase
12 | from typing import Iterator, List, Optional, TextIO
13 |
14 |
15 | @dataclass
16 | class Tee(TextIOBase):
17 | """ "A write-only file-like object that forwards writes to `sinks`."""
18 |
19 | sinks: List[TextIO]
20 | closed: bool = False
21 |
22 | def close(self) -> None:
23 | self.flush()
24 | self.closed = True
25 |
26 | def fileno(self) -> int:
27 | raise OSError
28 |
29 | def flush(self) -> None:
30 | for sink in self.sinks:
31 | sink.flush()
32 |
33 | def isatty(self) -> bool:
34 | return False
35 |
36 | def readable(self) -> bool:
37 | return False
38 |
39 | def readline(self, size=-1) -> str: # type: ignore
40 | raise io.UnsupportedOperation
41 |
42 | def readlines(self, hint=-1) -> List[str]: # type: ignore
43 | raise io.UnsupportedOperation
44 |
45 | def seek(self, offset, whence=SEEK_SET) -> int:
46 | raise io.UnsupportedOperation
47 |
48 | def seekable(self) -> bool:
49 | return False
50 |
51 | def tell(self) -> int:
52 | raise io.UnsupportedOperation
53 |
54 | def truncate(self, size=None):
55 | raise io.UnsupportedOperation
56 |
57 | def writable(self) -> bool:
58 | return True
59 |
60 | def writelines(self, lines: List[str]) -> None:
61 | for sink in self.sinks:
62 | sink.writelines(lines)
63 |
64 | @property
65 | def encoding(self) -> str:
66 | return self.sinks[0].encoding
67 |
68 | @property
69 | def errors(self) -> Optional[str]:
70 | return self.sinks[0].errors
71 |
72 | def detach(self) -> None:
73 | raise io.UnsupportedOperation
74 |
75 | def read(self, size=-1) -> str:
76 | raise io.UnsupportedOperation
77 |
78 | def write(self, s: str) -> int:
79 | results: List[int] = []
80 | for sink in self.sinks:
81 | results.append(sink.write(s))
82 | if not all(r == results[0] for r in results[1:]):
83 | warnings.warn("Sinks wrote different number of characters", ResourceWarning)
84 | return results[0]
85 |
86 |
87 | @contextlib.contextmanager
88 | def intercept_output(
89 | stdout_path: pathlib.Path, stderr_path: pathlib.Path
90 | ) -> Iterator[None]:
91 | """Write all stdout and stderr to both the screen and these files."""
92 |
93 | with open(stdout_path, "a") as stdout_file, open(stderr_path, "a") as stderr_file:
94 | true_stdout = sys.stdout
95 | true_stderr = sys.stderr
96 | sys.stdout = Tee([true_stdout, stdout_file]) # type: ignore
97 | sys.stderr = Tee([true_stderr, stderr_file])  # type: ignore
98 | try:
99 | yield
100 | except:
101 | traceback.print_exc(file=stderr_file)
102 | raise
103 | finally:
104 | sys.stdout = true_stdout
105 | sys.stderr = true_stderr
106 |
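Typical usage of `intercept_output`; the log directory and file names below are arbitrary.

```python
import pathlib

from semantic_parsing_with_constrained_lm.util.logger import intercept_output

log_dir = pathlib.Path("logs")
log_dir.mkdir(exist_ok=True)

with intercept_output(log_dir / "stdout.txt", log_dir / "stderr.txt"):
    # Everything printed (or raised) inside the block is mirrored to the files.
    print("this line goes to the screen and to logs/stdout.txt")
```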
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/sql/cosql/seq2seq.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | """
5 | ported from
6 | https://github.com/ElementAI/picard/blob/5ddd6cb9f74efca87d4604d5ddddc1b638459466/seq2seq/utils/cosql.py#L10
7 | Under the Apache 2 licence:
8 | https://github.com/ElementAI/picard/blob/main/LICENSE
9 | """
10 | from typing import Any, Dict, List
11 |
12 | from semantic_parsing_with_constrained_lm.domains.sql.cosql.content_encoder import (
13 | get_database_matches,
14 | )
15 |
16 |
17 | def get_input(
18 | utterances: List[str], serialized_schema: str, prefix: str, sep: str = " | "
19 | ) -> str:
20 | # "[prefix] [utterance n] [serialized schema] || [utterance n-1] | [utterance n-2] | ..."
21 | if len(utterances) > 1:
22 | reversed_utterance_head = (
23 | utterance.strip() for utterance in reversed(utterances[:-1])
24 | )
25 | serialized_reversed_utterance_head = " || " + sep.join(reversed_utterance_head)
26 | else:
27 | serialized_reversed_utterance_head = ""
28 | return (
29 | prefix
30 | + utterances[-1].strip()
31 | + " "
32 | + serialized_schema.strip()
33 | + serialized_reversed_utterance_head
34 | )
35 |
36 |
37 | def serialize_schema(
38 | question: str,
39 | db_path: str,
40 | db_id: str,
41 | db_column_names: Dict[str, Any],
42 | db_table_names: List[str],
43 | schema_serialization_with_db_content: bool = True,
44 | ) -> str:
45 | # schema_serialization_with_db_id: bool = True
46 | # schema_serialization_randomized: bool = False
47 | # normalize_query: bool = True
48 | # schema_serialization_type = "peteshaw"
49 | # see https://github.com/google-research/language/blob/master/language/nqg/tasks/spider/append_schema.py#L42
50 | db_id_str = " | {db_id}"
51 | table_sep = ""
52 | table_str = " | {table} : {columns}"
53 | column_sep = " , "
54 | column_str_with_values = "{column} ( {values} )"
55 | column_str_without_values = "{column}"
56 | value_sep = " , "
57 |
58 | def get_column_str(table_name: str, column_name: str) -> str:
59 | column_name_str = column_name.lower()
60 | if schema_serialization_with_db_content:
61 | matches = get_database_matches(
62 | question=question,
63 | table_name=table_name,
64 | column_name=column_name,
65 | db_path=(db_path + "/" + db_id + "/" + db_id + ".sqlite"),
66 | )
67 | if matches:
68 | return column_str_with_values.format(
69 | column=column_name_str, values=value_sep.join(matches)
70 | )
71 | else:
72 | return column_str_without_values.format(column=column_name_str)
73 | else:
74 | return column_str_without_values.format(column=column_name_str)
75 |
76 | tables = [
77 | table_str.format(
78 | table=table_name.lower(),
79 | columns=column_sep.join(
80 | map(
81 | lambda y: get_column_str(
82 | table_name=table_name, # pylint: disable=cell-var-from-loop
83 | column_name=y[1],
84 | ),
85 | filter(
86 | lambda y: y[0]
87 | == table_id, # pylint: disable=cell-var-from-loop
88 | zip(
89 | db_column_names["table_id"], db_column_names["column_name"]
90 | ),
91 | ),
92 | )
93 | ),
94 | )
95 | for table_id, table_name in enumerate(db_table_names)
96 | ]
97 | return db_id_str.format(db_id=db_id) + table_sep.join(tables)
98 |
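A worked example of the serialization format, with a made-up single-table schema and database-content matching turned off (so `get_database_matches` and the on-disk database are never touched).

```python
from semantic_parsing_with_constrained_lm.domains.sql.cosql.seq2seq import (
    get_input,
    serialize_schema,
)

schema = serialize_schema(
    question="how many singers are from France ?",
    db_path="database",  # unused when content matching is off
    db_id="concert_singer",
    db_column_names={"table_id": [0, 0], "column_name": ["singer_id", "country"]},
    db_table_names=["singer"],
    schema_serialization_with_db_content=False,
)
# schema == " | concert_singer | singer : singer_id , country"

text = get_input(
    utterances=["show all singers", "how many are from France ?"],
    serialized_schema=schema,
    prefix="",
)
# text == ("how many are from France ? | concert_singer | singer : "
#          "singer_id , country || show all singers")
```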
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/sql/cosql/grammar.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from typing import Iterator, List
5 |
6 | from semantic_parsing_with_constrained_lm.util.types import StrPath
7 | from semantic_parsing_with_constrained_lm.scfg.parser.rule import (
8 | Rule,
9 | SyncRule,
10 | mirrored_rule,
11 | nonterm,
12 | term,
13 | )
14 | from semantic_parsing_with_constrained_lm.scfg.read_grammar import PreprocessedGrammar
15 | from semantic_parsing_with_constrained_lm.scfg.scfg import SCFG
16 | from semantic_parsing_with_constrained_lm.domains.sql.cosql.paths import SQL_GRAMMAR_DIR
17 | from semantic_parsing_with_constrained_lm.domains.sql.cosql.schema import DbSchema
18 |
19 | # NTs in the SQL scfg
20 | SCHEMA_NAME = "schema_name"
21 | TABLE_NAME = "table_name"
22 | TABLE_ALIAS = "table_alias"
23 | COLUMN_NAME = "column_name"
24 | POSSIBLY_QUALIFIED = "possibly_qualified_column_name"
25 | DOT = "DOT"
26 | WS_STAR = "ws_star"
27 | SCHEMA_DOT_WS = "schema_dot_ws"
28 |
29 |
30 | def export_sync_rules(db: DbSchema) -> Iterator[SyncRule]:
31 | """
32 | Exports schema-specific SCFG rules for the NTs
33 | `schema`, `table_name`, `column_name`, and `possibly_qualified_column_name`.
34 | Ensures that `(({s}.)?{t}.)?{c}` is accepted iff `c` is a column in table `t` in schema `s`.
35 | """
36 | # helpers
37 | def to_lower(lhs: str, rhs: str):
38 | """case-insensitive match, normalize to lowercase"""
39 | return SyncRule(
40 | lhs=lhs,
41 | # case-insensitive
42 | utterance_rhss=(
43 | tuple(nonterm(c.upper()) if c.isalpha() else term(c) for c in rhs),
44 | ),
45 | # return lower case
46 | plan_rhs=(term(rhs.lower()),),
47 | )
48 |
49 | rhs = (term(db.name),)
50 | yield mirrored_rule(SCHEMA_NAME, rhs)
51 | for table in db.tables:
52 | tbl_name = table.name
53 | my_tbl_name_nt = f"table_called_{tbl_name}"
54 | yield mirrored_rule(TABLE_NAME, (nonterm(my_tbl_name_nt),))
55 | # case-insensitive, normalize to lowercase
56 | yield to_lower(lhs=my_tbl_name_nt, rhs=tbl_name)
57 | # allow table alias
58 | yield mirrored_rule(my_tbl_name_nt, (nonterm(TABLE_ALIAS),))
59 | qualifier = f"qualifier_for_{tbl_name}"
60 | rhs3 = (
61 | nonterm(SCHEMA_DOT_WS, optional=True),
62 | nonterm(my_tbl_name_nt),
63 | nonterm(WS_STAR),
64 | nonterm(DOT),
65 | nonterm(WS_STAR),
66 | )
67 | yield mirrored_rule(qualifier, rhs3)
68 | col_nt = f"column_for_{tbl_name}"
69 | yield mirrored_rule(COLUMN_NAME, ((nonterm(col_nt)),))
70 | # ensures that `table.column` is accepted iff column is in table
71 | poss_qual_rhs = (nonterm(qualifier, optional=True), (nonterm(col_nt)))
72 | yield mirrored_rule(POSSIBLY_QUALIFIED, poss_qual_rhs)
73 | for column in table.all_columns():
74 | col_name = column.name
75 | # case-insensitive, normalize to lowercase
76 | yield to_lower(lhs=col_nt, rhs=col_name)
77 |
78 | # This will allow copying from database values.
79 | for val in db.values:
80 | t_val = (term(val),)
81 | yield mirrored_rule("non_single_quote_star", t_val)
82 | yield mirrored_rule("non_double_quote_star", t_val)
83 |
84 |
85 | def load_base_grammar(folder_path: StrPath = SQL_GRAMMAR_DIR) -> PreprocessedGrammar:
86 | return PreprocessedGrammar.from_folder(folder_path)
87 |
88 |
89 | def preprocessed_grammar_for_schema(
90 | db: DbSchema, base_grammar: PreprocessedGrammar
91 | ) -> PreprocessedGrammar:
92 | rules: List[Rule] = list(export_sync_rules(db))
93 | return base_grammar.merge(PreprocessedGrammar.from_rules(rules))
94 |
95 |
96 | def grammar_for_schema(db: DbSchema, base_grammar: PreprocessedGrammar) -> SCFG:
97 | return SCFG(preprocessed_grammar_for_schema(db, base_grammar))
98 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/async_tools/batch_helper.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import abc
5 | import dataclasses
6 | from abc import ABC
7 | from dataclasses import dataclass
8 | from typing import Callable, Dict, Generic, List, Tuple, TypeVar, Union
9 |
10 | from semantic_parsing_with_constrained_lm.util.unit import UNIT, Unit
11 | from semantic_parsing_with_constrained_lm.async_tools.limits import TimeoutBarrier
12 |
13 | I = TypeVar("I")
14 | O = TypeVar("O")
15 |
16 |
17 | class BatchMaker(Generic[I, O], ABC):
18 | """Identifies which elements could be batched together, and specifies how to do the batched operation.
19 |
20 | This should also derive from collections.abc.Hashable, but then it is not
21 | possible to use a dataclass to create a BatchMaker.
22 | """
23 |
24 | @property
25 | @abc.abstractmethod
26 | def max_batch_size(self) -> int:
27 | """Maximum number of elements to batch together."""
28 | pass
29 |
30 | @property
31 | @abc.abstractmethod
32 | def timeout(self) -> float:
33 | """Time to wait before running `execute`."""
34 | pass
35 |
36 | @abc.abstractmethod
37 | async def execute(self, inputs: List[I]) -> O:
38 | """Batched operation on inputs."""
39 | pass
40 |
41 |
42 | @dataclass
43 | class PendingContainer(Generic[I, O]):
44 | """Used inside BatchingHelper."""
45 |
46 | batch_key: BatchMaker[I, O]
47 | batching_helper: "BatchingHelper[I, O]"
48 |
49 | inputs: List[I] = dataclasses.field(default_factory=list)
50 | barrier: TimeoutBarrier = dataclasses.field(init=False)
51 | result: Union[O, Unit] = UNIT
52 |
53 | def __post_init__(self):
54 | self.barrier = TimeoutBarrier(
55 | self.batch_key.max_batch_size, self.batch_key.timeout, self._execute
56 | )
57 |
58 | async def enqueue_and_wait(self, inp: I) -> Tuple[O, int]:
59 | i = len(self.inputs)
60 | self.inputs.append(inp)
61 | await self.barrier.arrive_and_wait()
62 | assert not isinstance(self.result, Unit)
63 | return self.result, i
64 |
65 | @property
66 | def closed(self) -> bool:
67 | return bool(self.barrier.currently_releasing or self.result is not UNIT)
68 |
69 | async def _execute(self) -> None:
70 | self.result = await self.batch_key.execute(self.inputs)
71 | self.batching_helper._del_pending_container( # pylint: disable=protected-access
72 | self
73 | )
74 |
75 |
76 | @dataclass
77 | class BatchingHelper(Generic[I, O]):
78 | """Helper for running functions on batched inputs."""
79 |
80 | # Creates a BatchMaker from the input. Inputs with BatchMakers that compare equal are eligible for batching together.
81 | input_to_batch_maker: Callable[[I], BatchMaker[I, O]]
82 |
83 | # Pending operations per BatchMaker.
84 | pending: Dict[BatchMaker[I, O], PendingContainer[I, O]] = dataclasses.field(
85 | default_factory=dict
86 | )
87 |
88 | async def execute(self, inp: I) -> Tuple[O, int]:
89 | """Given an input of type I, this class uses `batch_key_fn` to create a BatchKey.
90 | Inputs with the same BatchKey are coalesced together.
91 | After a certain number of inputs are collected, or a timeout passes,
92 | we run BatchKey.execute to produce a batched output of type O.
93 |
94 | Returns the output O with a batch index to locate the result for the input within O."""
95 |
96 | batch_maker = self.input_to_batch_maker(inp) # type: ignore[call-arg]
97 | pending_container = self.pending.get(batch_maker)
98 |
99 | if pending_container is None or pending_container.closed:
100 | self.pending[batch_maker] = pending_container = PendingContainer(
101 | batch_maker, self
102 | )
103 |
104 | return await pending_container.enqueue_and_wait(inp)
105 |
106 | def _del_pending_container(self, container: PendingContainer[I, O]):
107 | """When a PendingContainer is done executing, immediately remove it from `self.pending`"""
108 | if self.pending.get(container.batch_key) is container:
109 | del self.pending[container.batch_key]
110 |
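A self-contained sketch of the intended usage, with a hypothetical `BatchMaker` that batches integer-squaring requests (up to 8 at a time, or after a 10 ms timeout); all names below are illustrative.

```python
import asyncio
from dataclasses import dataclass
from typing import List

from semantic_parsing_with_constrained_lm.async_tools.batch_helper import (
    BatchingHelper,
    BatchMaker,
)

@dataclass(frozen=True)
class SquareBatch(BatchMaker[int, List[int]]):
    """Hypothetical BatchMaker: all squaring requests share one (equal) batch key."""

    @property
    def max_batch_size(self) -> int:
        return 8

    @property
    def timeout(self) -> float:
        return 0.01

    async def execute(self, inputs: List[int]) -> List[int]:
        return [x * x for x in inputs]

async def main() -> None:
    helper: BatchingHelper[int, List[int]] = BatchingHelper(
        input_to_batch_maker=lambda _inp: SquareBatch()
    )
    # Both requests map to equal SquareBatch keys, so they run as one batch.
    results = await asyncio.gather(helper.execute(2), helper.execute(3))
    for batched_output, index in results:
        print(batched_output[index])  # prints 4, then 9

asyncio.run(main())
```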
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | ---
2 | layout: default
3 | ---
4 |
5 | ## Intro
6 |
7 | To make robots accessible to a broad audience, it is critical to endow them with the ability to take universal modes of communication, like commands given in natural language, and extract a concrete desired task specification, defined using a formal language like linear temporal logic (LTL).
8 |
9 | In this paper, we present a learning-based approach that translates natural language commands into LTL specifications with very limited human-labeled training data by leveraging Large Language Models (LLMs). Our model translates natural language commands at 75% accuracy with only about 12 annotations, and when given full training data it achieves state-of-the-art performance. We also show how its outputs can be used to plan long-horizon, multi-stage tasks on a 12D quadrotor.
10 |
11 |
12 |
13 | ## Approach
14 |
15 | 
16 |
17 | Given a predefined set of possible LTL formulas and atomic
18 | propositions, and up to one natural language annotation for
19 | each formula, we first translate these pre-defined formulas to
20 | (structured) English and then use the paraphrasing abilities of modern LLMs to synthesize a large corpus of diverse natural language commands.
21 |
22 | Given this corpus, we use the data to fine-tune an LLM. Here, we explore two variants, where for training labels we use 1) raw LTL formulas, or 2) a canonical form of the LTL formulas (an intermediate representation that is more similar to English).
23 | At evaluation time, we enforce the LLM’s output to be syntactically correct via constrained decoding.
24 |
25 | ## Result
26 |
27 | We evaluate our methods on three datasets, each associated with a different task and environment. We show that our model can translate natural language commands at 75% accuracy with about 12 annotations, and when given full training data it consistently achieves state-of-the-art performance. In the paper, we also present comprehensive ablation studies that validate the necessity of each design decision we make.
28 |
29 | ### Results in low-data regimes
30 |
31 | In this setup, models are trained with only 12 human annotations at most, and evaluated on the entire original dataset. Our model significantly advances the state-of-the-art in this regime, achieving 75% average accuracy.
32 |
33 | | Model architecture | Training data | Test data | Drone Dataset | Cleanup Dataset | Pick Dataset |
34 | | ------------------------ | ------------- | ----------- | ------------- | --------------- | ------------ |
35 | | RNN | synthetic | full golden | 22.41 | 52.54 | 32.39 |
36 | | CopyNet | synthetic | full golden | 36.41 | 53.40 | 40.36 |
37 | | BART-FT-Raw (ours) | synthetic | full golden | **69.39** | **78.00** | **81.45** |
38 | | BART-FT-Canonical (ours) | synthetic | full golden | 68.38 | 77.90 | 78.23 |
39 |
40 | ### Results in standard data regimes
41 |
42 | In this setup, we follow the settings of previous works, where models are evaluated by five-fold cross-validation on the entire dataset. Our model consistently outperforms the state-of-the-art in this regime, with about 1% accuracy improvement on average.
43 |
44 | | Model architecture | Training data | Test data | Drone Dataset | Cleanup Dataset | Pick Dataset |
45 | | ------------------------ | ------------- | ---------- | ------------- | --------------- | ------------ |
46 | | RNN | 4/5 golden | 1/5 golden | 87.18 | 95.51 | 93.78 |
47 | | CopyNet | 4/5 golden | 1/5 golden | 88.97 | 95.47 | 93.14 |
48 | | BART-FT-Raw (ours) | 4/5 golden | 1/5 golden | 90.78 | **97.84** | **95.97** |
49 | | BART-FT-Canonical (ours) | 4/5 golden | 1/5 golden | **90.86** | 97.81 | 95.70 |
50 |
51 | ### Demo
52 |
53 | 
54 |
55 | Finally, we also show how the model's outputs can be used to plan long-horizon, multi-stage tasks on a 12D quadrotor in simulation.
56 |
57 | ## Cite
58 |
59 | If you find this work useful, please cite our paper:
60 |
61 | ```
62 | TO BE RELEASED
63 | ```
64 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/calflow/data/ids_dev_100_uniform.txt:
--------------------------------------------------------------------------------
1 | a4363ada-270f-453a-92c3-72dd95f6b8c4,1
2 | c244d5d1-9728-447c-8cb1-875fc92a6bc1,3
3 | 20e6b0e7-f1e6-46e4-bb4a-d60bd9e11100,2
4 | 775152fc-214c-4528-af2d-e70ab7d52304,0
5 | fc861e3d-9804-4942-91fd-5b292cf9aff7,1
6 | 81cf24a3-88ed-4545-b3ee-53168f2d2caf,3
7 | 0f29448a-3bd2-4271-99ce-d44428b2d638,0
8 | 0dc2c70e-35e8-4012-a170-196a6615aa70,2
9 | d746be41-3bb5-497b-a9b0-b264be569990,0
10 | 3aaa3b10-e339-4b31-b1fb-afcc42f8770f,0
11 | 67e98197-84a8-4434-84d6-28fc38a917cc,6
12 | b93552d6-b070-4e0a-9e9f-49681dc2c594,4
13 | 0dc438b5-b8be-4502-bc90-0f732589edfb,2
14 | cb13f9a2-e233-4711-8895-fb5620785880,1
15 | fcca3bce-e680-4685-af40-d979661842d8,0
16 | 9ed7b59f-f57d-4c71-8010-9b96d1aa993a,1
17 | 0771eed4-3107-4485-be63-79247f3f984f,4
18 | c4e474fb-4f6c-4182-b688-604f6d1a9fdf,2
19 | 15b0f0e5-023c-4a2d-9f66-4f47c9b95e51,4
20 | ce50f9e6-d765-438a-b4d7-bc940facd561,2
21 | df190a44-c2de-4e66-a1a4-b12c092d7e86,1
22 | 3f4520ea-ccbc-42a3-9fc2-25d948ee680d,0
23 | a1cdff70-2b42-48f5-9971-410e9888796a,3
24 | 24139643-6c8d-4dcb-827a-542ce25d7cdb,3
25 | 6c4d9f4f-79c2-4436-9de0-bcfeb3a81b4f,1
26 | 8d97df11-7b3f-46a9-b05f-25d187638537,0
27 | 822a7a20-b65c-4225-b460-528f18468e00,4
28 | 9b19eb00-516d-4df6-9be1-d281aa8628c6,4
29 | 2a20ce5f-21dd-424a-9191-79e4f077f59c,3
30 | 65e11734-3e24-437a-9aba-ac59c0971d6f,3
31 | 6eae204f-796f-4d6d-a32b-da0e894c8b74,3
32 | df6d7f07-d714-4d7f-86b8-b4e18ed09c92,3
33 | 463b48d9-5bde-4609-8775-8fc1b28b61de,1
34 | fceb4fa5-a82f-4683-a017-08abaf715294,0
35 | 8c51ae7b-4516-458c-93b6-ce39d049101e,1
36 | fee73a1c-c636-4064-aad1-e876a252bf80,3
37 | b71de4dd-38e2-48ec-a677-97cf7333d8cb,1
38 | 0b712121-8c14-4f5a-97df-51ad6db9bc82,0
39 | 65022059-0d96-4d78-b27c-5211e8183d52,2
40 | 0c219822-49b8-4c4c-b28c-be7a03f5b913,1
41 | ec9ab74c-a3df-456b-bcd8-09a35319581f,1
42 | b57c2ff5-3c6a-49c1-bf51-74e046002b05,1
43 | 71df0bc7-60d0-4fd3-9c18-3859eccb15f3,2
44 | c747df82-0e9a-4ba1-8220-4f7678ca7a55,1
45 | f20d19c2-5a46-477b-a373-0fcabab2dbda,4
46 | ca77251b-093b-48ef-85d4-e337d8a6336e,5
47 | 4c6da53c-baa1-4b95-b679-0cd1c1c190e5,2
48 | 6aec2974-84b3-4800-a1f7-01b504b7b20f,1
49 | a26af847-4626-4102-aeb4-6bc0dce34d89,0
50 | 99eea3c6-d9d0-4856-8393-ee33bb8bb489,6
51 | c6c9f5d3-b799-4cd2-84ab-57eaaca3987f,4
52 | 71d87d1f-e9d2-4bd1-b8d2-4b071d794f2e,6
53 | d69c16d4-385b-4823-8c81-79f560bac3cc,2
54 | 50baad9f-2fde-4e30-9f91-40819246995d,1
55 | dd54981c-5c4e-4c8e-9feb-85ab880bb879,1
56 | 1fd83edb-2cf2-48c1-9efe-b15fc01e4138,3
57 | 7d1cbec1-d6d2-4f0b-8734-527192f1d509,1
58 | 402cc271-4ae3-4b02-858a-0d2363ffedf0,4
59 | 9260a5e9-d72d-4131-8bac-c5ba518f83a7,3
60 | 0ba7de27-a741-443f-947f-51ffc7a5662c,2
61 | 98f2551d-a382-46c6-9c82-3d466d299d7d,0
62 | 527d5ff9-2e02-4dae-9484-b82844e87f1c,1
63 | 00f61607-9691-43d6-ab00-ba3e09e2195e,2
64 | 28e9f1f2-8dd2-416c-a0f1-101f290e5823,2
65 | 31dd1a17-7f18-445b-9e59-6b7c052adaa6,0
66 | 746c622d-9bbd-449a-bf45-d7d78dda32e8,4
67 | 80f7625b-cc09-4666-9607-1f27127adf97,2
68 | 7da05ea0-9959-49e7-b0c4-a771e64eeaed,2
69 | fadc029a-082c-4bd1-aab2-ea488afa8cc9,0
70 | 69168e02-76e1-4dc1-9410-d5d3e34eadef,3
71 | c296c911-2589-4982-8863-e06909a01080,5
72 | ebffca74-0373-421a-a422-f58406243f1e,3
73 | 9b19eb00-516d-4df6-9be1-d281aa8628c6,1
74 | 1ef4362f-0bcd-4209-b610-0864a24d2fdf,1
75 | 8613da88-0a12-4855-a420-4eea91fe6a32,3
76 | e4894296-8e2e-4759-ae55-9dc6c1339777,2
77 | de9aed72-9598-46d0-b24a-50d2e7f49440,2
78 | 3ac07aec-8c21-46c6-817d-1794b082a62f,2
79 | 3fdf3aad-a740-40c1-938d-95a449126ae5,2
80 | d14c5147-b7c9-4edb-89fb-79632ebc3ef9,12
81 | 31811a76-59ea-494f-8b8d-ec477f7bebca,2
82 | e8be360b-fd9d-4504-96c3-526a02bbe46d,3
83 | b4b34e07-f1c8-4c89-b051-e525ec79c8d1,12
84 | cbfd6f78-2f13-41f2-9a3e-3d7844fbee57,3
85 | 1ad516fd-f452-4af0-b877-b5d4f6b47806,0
86 | 2a4c7de3-a990-43da-93a6-851370ed1c3f,2
87 | 3745161d-c0d1-41a0-9baa-6bb77be2b876,3
88 | caf31ca2-8e7b-40d8-91b9-a8e2a4fe6503,3
89 | 2f56fe44-b08d-4fad-9d06-e027eb237ab0,1
90 | 5ce97ea6-a008-4e8b-97ba-164fdbdf4a8c,2
91 | 54fc80f9-efb9-4266-9cfa-7c7d0660f672,6
92 | 26841c21-d3ed-4541-aa59-907e93cc38d2,22
93 | 45d5c757-2bc3-40be-8c5b-f2ebc3cd9572,1
94 | cd010b9d-266f-4df2-aa50-65eace32d430,3
95 | 1e4a2550-e308-41ac-a247-7299e6e2db31,6
96 | 8e52d0e3-d751-4a02-9f68-45be7161be19,4
97 | d010d149-5b92-4a3f-ab20-ec9ab8bd3391,7
98 | 78437d70-5ea5-4df8-a95b-aa46733519d2,0
99 | bbc39510-f8e5-4fe2-b65b-5dc4c5fe5fd7,1
100 | 7bdc0ac1-0248-4841-8465-5e1cd44860fc,17
101 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Data-Efficient Learning of Natural Language to Linear Temporal Logic Translators for Robot Task Specification
2 |
3 | 
4 |
5 | [[Homepage](https://um-arm-lab.github.io/Efficient-Eng-2-LTL/)] [[Paper](https://arxiv.org/abs/2303.08006)] [[Video](https://drive.google.com/file/d/14Sy5y76YglZ6X3Y3ZZBZZiMGBA9gME9G/view?usp=sharing)] [[Poster](https://drive.google.com/file/d/1j0aZoROb1EKC0oRYYBSwBIx4Xp8ElowN/view?usp=sharing)]
6 |
7 | > The accompanying repository for the paper "Data-Efficient Learning of Natural Language to Linear Temporal Logic Translators for Robot Task Specification".
8 |
9 | ## Repo Structure
10 |
11 | - Root
12 | - datasets
13 | - [drone-planning](https://arxiv.org/abs/1905.12096)
14 | - [clean-up](http://www.roboticsproceedings.org/rss14/p67.html)
15 | - [pick-and-place](http://www.roboticsproceedings.org/rss14/p67.html)
16 | - augmentation
17 | - paraphrase with GPT-3
18 | - run
19 | - train the models
20 | - inference with constrained decoding
21 |
22 | The constrained decoding inference code is based on: [microsoft/semantic_parsing_with_constrained_lm](https://github.com/microsoft/semantic_parsing_with_constrained_lm)
23 |
24 | ## Reproduce the Results
25 |
26 | The following are instructions for reproducing the results of our model. For the baselines, see [this link](https://github.com/UM-ARM-Lab/Efficient-Eng-2-LTL/issues/1).
27 | ### Environment Setup
28 |
29 | Install the dependencies:
30 |
31 | ```bash
32 | pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 # make sure the version is compatible with your cuda version
33 | pip install transformers datasets
34 | pip install sentencepiece
35 | pip install jsons appdirs blobfile cached-property httpx typer whoosh more_itertools
36 | pip install --upgrade protobuf==3.20.0
37 | ```
38 |
39 | Download BART-large model:
40 |
41 | ```bash
42 | cd ./run
43 | python ./semantic_parsing_with_constrained_lm/finetune/download_huggingface_lms.py
44 | ```
45 |
46 | ### Prepare the Dataset
47 |
48 | The processed dataset (with LLM augmentation) is already included in the repo. This step is only needed if you want to reprocess the dataset.
49 |
50 | To process the raw dataset from scratch, follow the steps below:
51 |
52 | 1. Pre-process: in each of the three dataset folders, run all cells in "preprocess.ipynb" to generate the processed dataset (the annotation results are included in the notebook).
53 | 2. Augmentation: for each of the three datasets, run all cells in "augment.ipynb" to generate the augmented dataset. Note that this step requires a GPT-3 API key.
54 | 3. Move to the training folder: reformat the dataset and move it to the `run/semantic_parsing_with_constrained_lm/domains/ltl/data` folder (see the sketch after this list). A script to automate this step will be provided later.
55 |
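As a rough guide, the following hedged sketch shows one way to produce that layout. The loader in `run/semantic_parsing_with_constrained_lm/domains/ltl/__init__.py` expects `{domain}.train.jsonl` / `{domain}.test.jsonl` files whose lines contain `natural` and `canonical` fields, plus a `{domain}.canonical.json` whose keys are the valid canonical outputs; the shape of `augmented` and the helper name below are assumptions, not part of the released code.

```python
# Hypothetical reformatting helper -- adapt the loading side to your notebook's output.
import json
from pathlib import Path


def write_ltl_domain(augmented, domain: str, out_dir: str) -> None:
    """Write one domain in the layout expected under domains/ltl/data/."""
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    # {domain}.train.jsonl: one {"natural": ..., "canonical": ...} object per line
    # (a {domain}.test.jsonl in the same format is used for evaluation splits).
    with open(out / f"{domain}.train.jsonl", "w") as f:
        for ex in augmented:
            f.write(json.dumps({"natural": ex["natural"], "canonical": ex["canonical"]}) + "\n")
    # {domain}.canonical.json: keys are the valid canonical outputs, which the
    # decoder uses to build the trie for constrained decoding.
    with open(out / f"{domain}.canonical.json", "w") as f:
        json.dump({ex["canonical"]: ex.get("ltl", "") for ex in augmented}, f, indent=2)


# Example call (paths and names assumed):
# write_ltl_domain(augmented, "pick-syn-aug",
#                  "run/semantic_parsing_with_constrained_lm/domains/ltl/data")
```
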
56 | ### Train
57 |
58 | In our paper, we use the [BART-large model](https://huggingface.co/facebook/bart-large) because it is efficient to fine-tune on a single GPU. Our proposed method can be easily applied to other potentially stronger language models like [T5-XXL](https://arxiv.org/abs/1910.10683) or [GPT-3](https://arxiv.org/abs/2005.14165).
59 |
60 | ```sh
61 | export PRETRAINED_MODEL_DIR=huggingface_models/bart-large
62 | export TRAINED_MODEL_DIR=trained_models/
63 |
64 | cd ./run
65 | DOMAIN=TODO # for example, DOMAIN=pick-syn-aug
66 | python -m semantic_parsing_with_constrained_lm.finetune.lm_finetune \
67 | --config-name semantic_parsing_with_constrained_lm.finetune.configs.emnlp_train_config \
68 | --exp-names ltl_${DOMAIN}_utterance
69 | ```
70 |
71 | Here `DOMAIN` determines which experiment to run (see the example after this list).
72 | `DOMAIN` follows the pattern `{dataset_name}-{experiment_name}`:
73 |
74 | - dataset_name: {drone, cleanup, pick}
75 | - experiment_name:
76 | - syn-aug: synthetic with augmentation
77 | - syn: synthetic without augmentation
78 | - golden-cross0-split{0,1,2,3,4}: golden dataset with cross-validation
79 |
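For example, to fine-tune on the pick-and-place golden data with cross-validation split 0 (matching the `pick-golden-cross0-split0` files shipped under `domains/ltl/data`), with the environment variables and working directory set as above:

```sh
DOMAIN=pick-golden-cross0-split0
python -m semantic_parsing_with_constrained_lm.finetune.lm_finetune \
    --config-name semantic_parsing_with_constrained_lm.finetune.configs.emnlp_train_config \
    --exp-names ltl_${DOMAIN}_utterance
```
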
80 | ### Inference
81 |
82 | ```sh
83 | export PRETRAINED_MODEL_DIR=huggingface_models/bart-large
84 | export TRAINED_MODEL_DIR=trained_models/
85 |
86 | DOMAIN=TODO
87 |
88 | python -m semantic_parsing_with_constrained_lm.run_exp \
89 | --config-name semantic_parsing_with_constrained_lm.configs.ltl_config \
90 | --log-dir logs/ \
91 | --model Bart \
92 | --eval-split test-full \
93 | --exp-names "ltl_Bart_test-full_${DOMAIN}_constrained_utterance_train-0"
94 | ```
95 |
96 | The domain name is the same as in the training step.
97 |
98 | ## Cite
99 | ```bibtex
100 | @article{pan2023data,
101 | title={Data-Efficient Learning of Natural Language to Linear Temporal Logic Translators for Robot Task Specification},
102 | author={Pan, Jiayi and Chou, Glen and Berenson, Dmitry},
103 | journal={arXiv preprint arXiv:2303.08006},
104 | year={2023}
105 | }
106 | ```
107 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/index/bm25_index.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import random
5 | import tempfile
6 | from typing import Callable, Generic, Iterable, List, Sequence, Tuple
7 |
8 | import whoosh.index
9 | from whoosh.fields import STORED, TEXT, SchemaClass
10 | from whoosh.qparser import OrGroup, QueryParser
11 |
12 | from semantic_parsing_with_constrained_lm.datum import DatumSub, FullDatumSub
13 | from semantic_parsing_with_constrained_lm.index import Candidate, DynamicIndex, Query
14 | from semantic_parsing_with_constrained_lm.model import DataRetriever
15 |
16 |
17 | class PromptSchema(SchemaClass):
18 | text = TEXT()
19 | key = STORED()
20 |
21 |
22 | class BM25Index(Generic[Query, Candidate], DynamicIndex[int, Query, Candidate]):
23 | def __init__(
24 | self, get_content: Callable[[Candidate], str], get_query: Callable[[Query], str]
25 | ):
26 | # TODO: custom tokenizer from ClampTokenizer
27 | # TODO: now indexed in a temp dir
28 | tmp_indexer_loc = tempfile.mkdtemp()
29 | self.index = whoosh.index.create_in(tmp_indexer_loc, schema=PromptSchema)
30 | self.get_content = get_content
31 | self.get_query = get_query
32 |
33 | @classmethod
34 | def create(
35 | cls,
36 | candidates: Iterable[Candidate],
37 | get_content: Callable[[Candidate], str],
38 | get_query: Callable[[Query], str],
39 | ) -> "BM25Index":
40 | index = BM25Index(get_content, get_query)
41 | with index.index.writer() as writer:
42 | for i, candidate in enumerate(candidates):
43 | writer.add_document(text=get_content(candidate), key=i)
44 | return index
45 |
46 | def add(self, candidates: Iterable[Candidate]):
47 | n = self.index.doc_count()
48 | with self.index.writer() as writer:
49 | for i, candidate in enumerate(candidates):
50 | writer.add_document(
51 | text=self.get_content(candidate), key=n + i
52 | ) # auto-increment key
53 |
54 | def search(self, query: Query, top_k: int = 10) -> List[Tuple[int, float]]:
55 | with self.index.searcher() as searcher:
56 | query_parser = QueryParser("text", schema=searcher.schema, group=OrGroup)
57 | q = query_parser.parse(self.get_query(query))
58 | results = searcher.search(q, limit=top_k)
59 | return [(result["key"], result.score) for result in results]
60 |
61 |
62 | class BM25Retriever(DataRetriever[FullDatumSub, DatumSub]):
63 | def __init__(
64 | self,
65 | train_data: Sequence[FullDatumSub],
66 | top_k: int = 20,
67 | best_first: bool = True,
68 | seed: int = 12345,
69 | ):
70 | self.index: BM25Index[DatumSub, FullDatumSub] = BM25Index.create(
71 | train_data,
72 | get_content=lambda c: c.natural, # type: ignore
73 | get_query=lambda q: q.natural, # type: ignore
74 | )
75 | self.data: List[FullDatumSub] = list(train_data)
76 | self.top_k = top_k
77 | self.best_first = best_first
78 | self.prng = random.Random(
79 | seed
80 | ) # a random number generator to ensure deterministic behavior
81 |
82 | def augment_with_random_samples(
83 | self, data: Sequence[FullDatumSub], retrieved_keys: Sequence[int]
84 | ) -> Sequence[FullDatumSub]:
85 |
86 | if len(retrieved_keys) < self.top_k:
87 | print(
88 | f"Could not retrieve {self.top_k} examples, got only {len(retrieved_keys)}"
89 | )
90 | keys_to_sample = sorted(set(range(len(data))).difference(retrieved_keys))
91 | sampled_keys = self.prng.sample(
92 | keys_to_sample,
93 | k=min(self.top_k - len(retrieved_keys), len(keys_to_sample)),
94 | )
95 | augmented_keys = list(retrieved_keys) + list(sampled_keys)
96 | print(f"Added samples to make it of size {len(augmented_keys)}")
97 | items = [data[k] for k in augmented_keys[: self.top_k]]
98 | else:
99 | items = [data[k] for k in retrieved_keys[: self.top_k]]
100 |
101 | return items if self.best_first else list(reversed(items))
102 |
103 | def add(self, data: Sequence[FullDatumSub]):
104 | self.index.add(data)
105 | self.data.extend(data)
106 |
107 | async def __call__(self, test_datum: DatumSub) -> Sequence[FullDatumSub]:
108 | results = self.augment_with_random_samples(
109 | data=self.data,
110 | retrieved_keys=[
111 | key for key, _ in self.index.search(test_datum, top_k=self.top_k)
112 | ], # score discarded
113 | )
114 | return results
115 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/configs/lib/benchclamp.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import os
5 | from pathlib import Path
6 |
7 | from semantic_parsing_with_constrained_lm.earley.cfg import load_grammar_from_directory
8 | from semantic_parsing_with_constrained_lm.datum import DatumSub
9 | from semantic_parsing_with_constrained_lm.decoding.earley_partial_parse import (
10 | GrammarTokenizerInfo,
11 | UTF8EarleyPartialParse,
12 | )
13 | from semantic_parsing_with_constrained_lm.decoding.partial_parse import StartsWithSpacePartialParse
14 | from semantic_parsing_with_constrained_lm.domains.benchclamp_data_setup import (
15 | BenchClampDataset,
16 | BenchClampDatasetConfig,
17 | )
18 | from semantic_parsing_with_constrained_lm.domains.lispress_v2.grammar import (
19 | create_partial_parse_builder as create_partial_parse_builder_lispress_v2,
20 | )
21 | from semantic_parsing_with_constrained_lm.domains.mtop.grammar import (
22 | create_partial_parse_builder as create_partial_parse_builder_mtop,
23 | )
24 | from semantic_parsing_with_constrained_lm.domains.sql.cosql.grammar import (
25 | load_base_grammar,
26 | preprocessed_grammar_for_schema,
27 | )
28 | from semantic_parsing_with_constrained_lm.domains.sql.cosql.schema import load_schemas
29 | from semantic_parsing_with_constrained_lm.model import PartialParseBuilder
30 | from semantic_parsing_with_constrained_lm.paths import BENCH_CLAMP_GRAMMAR_DATA_DIR_AZURE
31 | from semantic_parsing_with_constrained_lm.tokenization import ClampTokenizer
32 |
33 | TEST_SUITE_PATH = Path("/mnt/my_input/test-suite-sql-eval")
34 | TEST_SUITE_DATABASE_PATH = Path("/mnt/my_input/test-suite-sql-eval/database/")
35 | SPIDER_DATABASE_PATH = Path("/mnt/my_input/Spider/database/")
36 | SPIDER_TABLES_FILE = Path("/mnt/my_input/Spider/tables.json")
37 | COSQL_DATABASE_PATH = Path("/mnt/my_input/CoSQL/database/")
38 | COSQL_TABLES_FILE = Path("/mnt/my_input/CoSQL/tables.json")
39 |
40 |
41 | def create_partial_parse_builder(
42 | constrained: bool, data_config: BenchClampDatasetConfig, tokenizer: ClampTokenizer
43 | ) -> PartialParseBuilder[DatumSub]:
44 | if constrained:
45 | domain_str = data_config.domain if data_config.domain is not None else ""
46 | if data_config.dataset_name in [
47 | BenchClampDataset.Spider.value,
48 | BenchClampDataset.CoSQL.value,
49 | ]:
50 | print("Loading database schemas ...")
51 | if data_config.dataset_name == BenchClampDataset.Spider.value:
52 | schemas = load_schemas(
53 | schemas_path=SPIDER_TABLES_FILE,
54 | db_path=SPIDER_DATABASE_PATH,
55 | )
56 | else:
57 | schemas = load_schemas(
58 | schemas_path=COSQL_TABLES_FILE, db_path=COSQL_DATABASE_PATH
59 | )
60 | print("Done")
61 |
62 | base_grammar = load_base_grammar()
63 | pre_grammars = {
64 | name: preprocessed_grammar_for_schema(db, base_grammar)
65 | for name, db in schemas.items()
66 | }
67 | grammar_tok_info = {
68 | name: GrammarTokenizerInfo.create(tokenizer, preprocessed_grammar, True)
69 | for name, preprocessed_grammar in pre_grammars.items()
70 | }
71 | partial_parse_builder = lambda datum: UTF8EarleyPartialParse.initial(
72 | grammar_tok_info[datum.schema_name], datum.natural # type: ignore
73 | )
74 |
75 | elif data_config.dataset_name in (
76 | BenchClampDataset.CalFlowV2.value,
77 | BenchClampDataset.TreeDST.value,
78 | ):
79 | partial_parse_builder = create_partial_parse_builder_lispress_v2(
80 | load_grammar_from_directory(
81 | os.path.join(
82 | BENCH_CLAMP_GRAMMAR_DATA_DIR_AZURE,
83 | data_config.dataset_name,
84 | domain_str,
85 | )
86 | ),
87 | tokenizer,
88 | )
89 | elif data_config.dataset_name == BenchClampDataset.MTOP.value:
90 | partial_parse_builder = create_partial_parse_builder_mtop(
91 | load_grammar_from_directory(
92 | os.path.join(
93 | BENCH_CLAMP_GRAMMAR_DATA_DIR_AZURE,
94 | data_config.dataset_name,
95 | domain_str,
96 | )
97 | ),
98 | tokenizer,
99 | )
100 | else:
101 | raise ValueError(f"{data_config.dataset_name} not supported")
102 | else:
103 | partial_parse = StartsWithSpacePartialParse(tokenizer)
104 | partial_parse_builder = lambda _: partial_parse
105 | return partial_parse_builder
106 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/train_model_setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import abc
5 | from abc import abstractmethod
6 | from dataclasses import dataclass
7 | from pathlib import Path
8 | from typing import Dict, List, Optional, Tuple
9 |
10 | import torch
11 | from transformers import (
12 | BartForConditionalGeneration,
13 | GPT2LMHeadModel,
14 | PreTrainedModel,
15 | T5ForConditionalGeneration,
16 | )
17 |
18 | from semantic_parsing_with_constrained_lm.lm import Seq2SeqSettings, Surround
19 | from semantic_parsing_with_constrained_lm.tokenization import (
20 | ClampTokenizer,
21 | GPT2ClampTokenizer,
22 | T5ClampTokenizer,
23 | )
24 |
25 |
26 | class TrainedModelNotFoundError(FileNotFoundError):
27 | pass
28 |
29 |
30 | @dataclass # type: ignore
31 | class ClampModelConfig(abc.ABC):
32 | model_id: str
33 | model_loc: Path
34 | device_map: Optional[Dict[int, List[int]]] = None
35 |
36 | @abstractmethod
37 | def setup_model(self) -> Tuple[PreTrainedModel, ClampTokenizer, Seq2SeqSettings]:
38 | pass
39 |
40 | def maybe_parallelize(self, model: PreTrainedModel) -> None:
41 | if torch.cuda.is_available():
42 | if self.device_map is not None:
43 | print(f"Parallelizing model with {self.device_map}")
44 | model.parallelize(self.device_map)
45 | else:
46 | print("Entire model to GPU 0")
47 | model.to(torch.device("cuda:0"))
48 | else:
49 | model.to(torch.device("cpu"))
50 |
51 |
52 | class BartModelConfig(ClampModelConfig):
53 | def setup_model(self) -> Tuple[PreTrainedModel, ClampTokenizer, Seq2SeqSettings]:
54 | if not self.model_loc.exists():
55 | raise TrainedModelNotFoundError(
56 | f"Model files not found in {self.model_loc}"
57 | )
58 | model = BartForConditionalGeneration.from_pretrained(self.model_loc)
59 | tokenizer = GPT2ClampTokenizer.from_pretrained(str(self.model_loc))
60 | seq2seq_settings = Seq2SeqSettings(
61 | input_surround=Surround(bos=[0], eos=[2], starts_with_space=True),
62 | output_surround=Surround(bos=[0], eos=[2], starts_with_space=True),
63 | decoder_start_token_id=2,
64 | )
65 | self.maybe_parallelize(model)
66 | model.eval()
67 | return model, tokenizer, seq2seq_settings
68 |
69 |
70 | class T5ModelConfig(ClampModelConfig):
71 | def setup_model(self) -> Tuple[PreTrainedModel, ClampTokenizer, Seq2SeqSettings]:
72 | if not self.model_loc.exists():
73 | raise TrainedModelNotFoundError(
74 | f"Model files not found in {self.model_loc}"
75 | )
76 | print(f"Loading model from {self.model_loc}")
77 | model = T5ForConditionalGeneration.from_pretrained(self.model_loc)
78 | print("Done")
79 | tokenizer = T5ClampTokenizer.from_pretrained(str(self.model_loc))
80 | seq2seq_settings = Seq2SeqSettings(
81 | input_surround=Surround(bos=[], eos=[1], starts_with_space=True),
82 | output_surround=Surround(bos=[], eos=[1], starts_with_space=True),
83 | decoder_start_token_id=tokenizer.pad_token_id,
84 | )
85 | self.maybe_parallelize(model)
86 | model.eval()
87 | return model, tokenizer, seq2seq_settings
88 |
89 |
90 | class CodeT5ModelConfig(ClampModelConfig):
91 | def setup_model(self) -> Tuple[PreTrainedModel, ClampTokenizer, Seq2SeqSettings]:
92 | if not self.model_loc.exists():
93 | raise TrainedModelNotFoundError(
94 | f"Model files not found in {self.model_loc}"
95 | )
96 | model = T5ForConditionalGeneration.from_pretrained(self.model_loc)
97 | tokenizer = GPT2ClampTokenizer.from_pretrained(str(self.model_loc))
98 | seq2seq_settings = Seq2SeqSettings(
99 | input_surround=Surround(bos=[1], eos=[2], starts_with_space=True),
100 | output_surround=Surround(bos=[1], eos=[2], starts_with_space=True),
101 | decoder_start_token_id=0,
102 | )
103 | self.maybe_parallelize(model)
104 | model.eval()
105 | return model, tokenizer, seq2seq_settings
106 |
107 |
108 | class GPT2ModelConfig(ClampModelConfig):
109 | def setup_model(self) -> Tuple[PreTrainedModel, ClampTokenizer, Seq2SeqSettings]:
110 | if not self.model_loc.exists():
111 | raise TrainedModelNotFoundError(
112 | f"Model files not found in {self.model_loc}"
113 | )
114 | model = GPT2LMHeadModel.from_pretrained(self.model_loc)
115 | tokenizer = GPT2ClampTokenizer.from_pretrained(str(self.model_loc))
116 | seq2seq_settings = Seq2SeqSettings(
117 | input_surround=Surround(
118 | bos=[20490, 25], eos=[198], starts_with_space=True
119 | ), # bos: "Human:", eos: "\n"
120 | output_surround=Surround(
121 | bos=[34556, 25], eos=[198], starts_with_space=True
122 | ), # bos: "Computer:", eos: "\n"
123 | decoder_start_token_id=None,
124 | )
125 | self.maybe_parallelize(model)
126 | model.eval()
127 | return model, tokenizer, seq2seq_settings
128 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/lispress_v2/grammar.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import functools
5 | import re
6 | from typing import Iterable, List, Optional, Set
7 |
8 | from semantic_parsing_with_constrained_lm.domains import dfa_grammar_utils
9 | from semantic_parsing_with_constrained_lm.domains.lispress_v2.lispress_exp import (
10 | BooleanExpr,
11 | CallExpr,
12 | DialogueV2,
13 | LambdaExpr,
14 | LetExpr,
15 | LispressExpr,
16 | LongExpr,
17 | NumberExpr,
18 | ReferenceExpr,
19 | StringExpr,
20 | TypeName,
21 | parse_fully_typed_lispress_v2,
22 | )
23 |
24 |
25 | def get_nt_from_type(type_name: TypeName) -> str:
26 | segments = (
27 | str(type_name)
28 | .replace(" ", " SP ")
29 | .replace("(", " LP ")
30 | .replace(")", " RP ")
31 | .replace(".", " DOT ")
32 | .split()
33 | )
34 | return "_".join(segments + ["NT"])
35 |
36 |
37 | def extract_grammar_rules(lispress_expr: LispressExpr) -> Set[str]:
38 | lhs = get_nt_from_type(lispress_expr.type) # type: ignore
39 | rules = set()
40 |
41 | if isinstance(lispress_expr, (NumberExpr, LongExpr, StringExpr, BooleanExpr)):
42 | pass
43 | elif isinstance(lispress_expr, ReferenceExpr):
44 | rules.add(f'{lhs} -> "{lispress_expr.var_name}"')
45 |
46 | elif isinstance(lispress_expr, LambdaExpr):
47 | rhs_items = [
48 | f'"(lambda (^{str(lispress_expr.var_type)} {lispress_expr.var_name}) "',
49 | get_nt_from_type(lispress_expr.main_expr.type), # type: ignore
50 | '")"',
51 | ]
52 | rhs = " ".join(rhs_items)
53 | rules.add(f"{lhs} -> {rhs}")
54 | rules.update(extract_grammar_rules(lispress_expr.main_expr))
55 |
56 | elif isinstance(lispress_expr, LetExpr):
57 | var_name_expr_nts = []
58 | for var_name, var_expr in lispress_expr.var_assignments:
59 | var_name_expr_nts.extend([f'"{var_name}"', get_nt_from_type(var_expr.type)]) # type: ignore
60 | rules.update(extract_grammar_rules(var_expr))
61 | var_name_expr_nts_str = ' " " '.join(var_name_expr_nts)
62 | rhs = f'"(let (" {var_name_expr_nts_str} ") " {get_nt_from_type(lispress_expr.main_expr.type)} ")"' # type: ignore
63 | rules.add(f"{lhs} -> {rhs}")
64 | rules.update(extract_grammar_rules(lispress_expr.main_expr))
65 |
66 | elif isinstance(lispress_expr, CallExpr):
67 | rhs_items: List[str] = []
68 | if lispress_expr.instantiation_type is not None:
69 | rhs_items.append(f'"^{lispress_expr.instantiation_type} "?')
70 |
71 | rhs_items.append(f'"{lispress_expr.name}"')
72 |
73 | for k, v in lispress_expr.args:
74 | rhs_items.extend([f'" :{k}"?', '" "', get_nt_from_type(v.type)]) # type: ignore
75 | rules.update(extract_grammar_rules(v))
76 |
77 | rhs = " ".join(rhs_items)
78 | rules.add(f'{lhs} -> "(" {rhs} ")"')
79 |
80 | return rules
81 |
82 |
83 | def extract_grammar(
84 | dataflow_dialogues: Iterable[DialogueV2],
85 | whitelisted_dialogue_ids: Optional[Set[str]] = None,
86 | ) -> Set[str]:
87 | grammar_rules = set()
88 | for dataflow_dialogue in dataflow_dialogues:
89 | if (
90 | whitelisted_dialogue_ids is not None
91 | and dataflow_dialogue.dialogue_id not in whitelisted_dialogue_ids
92 | ):
93 | continue
94 |
95 | for turn in dataflow_dialogue.turns:
96 | lispress_expr = parse_fully_typed_lispress_v2(turn.fully_typed_lispress)
97 | grammar_rules.update(extract_grammar_rules(lispress_expr))
98 | root_type_nt = get_nt_from_type(lispress_expr.type) # type: ignore
99 | grammar_rules.add(f'start -> " " {root_type_nt}')
100 |
101 | # find string literals
102 | for match in re.finditer(r'Path\.apply "([^"]*)"', turn.lispress):
103 | start = match.start(1)
104 | end = match.end(1)
105 | item = turn.lispress[start:end]
106 | # We use `repr` because the .cfg parser uses `ast.literal_eval`
107 | # to parse the strings, since that will handle backslash escape
108 | # sequences. Without `repr` the resulting grammar will have one
109 | # level of escaping removed.
110 | grammar_rules.add(f"path_literal -> {repr(item)}")
111 | grammar_rules.update(
112 | [
113 | 'Boolean_NT -> "true"',
114 | 'Boolean_NT -> "false"',
115 | r'String_NT -> "\"" (String_NT_content | path_literal | "output" | "place" | "start") "\""',
116 | # Lispress V2 string literals are JSON string literals, so we follow this grammar:
117 | # https://datatracker.ietf.org/doc/html/rfc8259#section-7
118 | r'String_NT_content -> ([[\u0020-\U0010FFFF]--[\u0022\u005C]] | "\\" (["\u005C/bfnrt] | u[0-9A-Fa-f]{4}))*',
119 | 'Number_NT -> ("0" | [1-9][0-9]*) ("." [0-9]+)?',
120 | 'Long_NT -> ("0" | [1-9][0-9]*) "L"',
121 | ]
122 | )
123 | return grammar_rules
124 |
125 |
126 | create_partial_parse_builder = functools.partial(
127 | dfa_grammar_utils.create_partial_parse_builder,
128 | utterance_nonterm_name="String_NT_content",
129 | )
130 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/sql/sql_metric.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import dataclasses
5 | import subprocess
6 | from collections import defaultdict
7 | from dataclasses import dataclass
8 | from typing import Dict, List, Optional, Sequence, Tuple
9 |
10 | from semantic_parsing_with_constrained_lm.datum import FullDatumSub
11 | from semantic_parsing_with_constrained_lm.domains.sql.sql_datum import SqlDatum
12 | from semantic_parsing_with_constrained_lm.eval import Metric
13 |
14 |
15 | @dataclass
16 | class SQLTestSuiteMatch(Metric[Sequence[str], FullDatumSub]):
17 | """
18 | Metric to evaluate SQL predictions. Uses the test-suite available here:
19 | https://github.com/taoyds/test-suite-sql-eval. To use this metric, clone this repo to a local
20 | directory and set `test_suite_path` to that directory.
21 | """
22 |
23 | db_path: str
24 | test_suite_path: str
25 | table_file: str
26 | log_dir: str
27 | schema_map: Dict[str, str] = dataclasses.field(init=False)
28 | predictions_map: Dict[Tuple[str, int], str] = dataclasses.field(init=False)
29 | gold_map: Dict[Tuple[str, int], str] = dataclasses.field(init=False)
30 | dialogue_to_turn_indices_map: Dict[str, List[int]] = dataclasses.field(init=False)
31 |
32 | def __post_init__(self):
33 | self.reset()
34 |
35 | def _is_correct(self, pred: str, target: SqlDatum) -> bool:
36 | """Can be overridden by child classes."""
37 | return pred == target.canonical
38 |
39 | def update(
40 | self, preds: Sequence[str], target: SqlDatum
41 | ) -> Dict[str, Optional[str]]:
42 | schema_name = target.schema_name
43 | self.schema_map[target.dialogue_id] = schema_name # type: ignore
44 | self.predictions_map[(target.dialogue_id, target.turn_part_index)] = ( # type: ignore
45 | preds[0] if len(preds) > 0 else "dummy"
46 | )
47 | self.gold_map[(target.dialogue_id, target.turn_part_index)] = target.canonical # type: ignore
48 | self.dialogue_to_turn_indices_map[target.dialogue_id].append( # type: ignore
49 | target.turn_part_index # type: ignore
50 | )
51 | return {}
52 |
53 | def compute(self, gold_file=None, pred_file=None) -> Dict[str, float]:
54 | # Run test suite using subprocess
55 | is_interaction = any(
56 | [
57 | len(turn_indices) > 1
58 | for _, turn_indices in self.dialogue_to_turn_indices_map.items()
59 | ]
60 | )
61 | if gold_file is None and pred_file is None:
62 | gold_file = self.log_dir + "/gold.txt"
63 | pred_file = self.log_dir + "/pred.txt"
64 | with open(gold_file, "w") as fp_gold, open(pred_file, "w") as fp_pred:
65 | for dialogue_id in self.dialogue_to_turn_indices_map:
66 | for turn_index in self.dialogue_to_turn_indices_map[dialogue_id]:
67 | gold = self.gold_map[(dialogue_id, turn_index)]
68 | if gold.count(")") == 1 and gold.count("(") == 0:
69 | gold = gold.replace(")", "")
70 | if "faculty_participates_in" in gold:
71 | gold = gold.replace(
72 | "faculty_participates_in", "Faculty_participates_in"
73 | )
74 | fp_gold.write(
75 | gold.replace(" . ", ".")
76 | + "\t"
77 | + self.schema_map[dialogue_id]
78 | + "\n"
79 | )
80 | fp_pred.write(
81 | self.predictions_map[(dialogue_id, turn_index)].replace(
82 | " . ", "."
83 | )
84 | + "\n"
85 | )
86 |
87 | if is_interaction:
88 | fp_gold.write("\n")
89 | fp_pred.write("\n")
90 |
91 | process = subprocess.run(
92 | [
93 | "python3",
94 | "evaluation.py",
95 | "--gold",
96 | gold_file,
97 | "--pred",
98 | pred_file,
99 | "--db",
100 | self.db_path,
101 | "--table",
102 | self.table_file,
103 | "--etype",
104 | "all",
105 | ],
106 | cwd=self.test_suite_path,
107 | capture_output=True,
108 | text=True,
109 | check=True,
110 | )
111 | print("stdout:", process.stdout)
112 | print("stderr:", process.stderr)
113 | execution_acc = 0.0
114 | for line in process.stdout.split("\n"):
115 | if line.startswith("execution"):
116 | execution_acc = float(line.split()[5].strip())
117 | break
118 | result = {"execution_acc": execution_acc}
119 | print(result)
120 | return result
121 |
122 | def reset(self) -> None:
123 | self.predictions_map = {}
124 | self.gold_map = {}
125 | self.schema_map = {}
126 | self.dialogue_to_turn_indices_map = defaultdict(list)
127 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/fit_max_steps.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import collections
5 | import itertools
6 | import random
7 | from dataclasses import dataclass
8 | from typing import Dict, Iterator, List, Tuple
9 |
10 | import numpy as np
11 |
12 |
13 | @dataclass(frozen=True)
14 | class Quantiles:
15 |     # The quantile used to find the intercept.
16 | # In effect, we sort the list of lengths, and take
17 | # `lengths[int(len(lengths) * intercept_quantile)]`
18 | # to use as the intercept.
19 | intercept_quantile: float
20 | # The quantile used for the slope.
21 | # For each datum, we compute the slope that would be needed so that the
22 | # predicted length matches the gold length.
23 | slope_quantile: float
24 |
25 |
26 | @dataclass
27 | class Result:
28 | # Number of test instances that had predicted lengths smaller than gold lengths.
29 | num_unreachable: int
30 | # predicted length - gold length
31 | surplus_steps: List[int]
32 |
33 |
34 | def fit(pairs: List[Tuple[int, int]]) -> Iterator[Tuple[Quantiles, float, float]]:
35 | # 0, 0.01, 0.02, ..., 0.99, 1.00
36 | intercept_quantiles = np.linspace(0, 1, 101)
37 | # 1, 0.99, 0.98, ..., 0.90
38 | slope_quantiles = np.linspace(1, 0.9, 11)
39 |
40 | intercept_by_quantile = dict(
41 | zip(
42 | intercept_quantiles, np.quantile([o for _, o in pairs], intercept_quantiles)
43 | )
44 | )
45 | intercept_by_quantile[-1] = 0.0
46 |
47 | for intercept_quantile, intercept in intercept_by_quantile.items():
48 | slopes = [(o - intercept) / i for i, o in pairs]
49 | max_slopes = np.quantile(slopes, slope_quantiles)
50 | for slope_quantile, slope in zip(slope_quantiles, max_slopes):
51 | yield Quantiles(intercept_quantile, slope_quantile), intercept, slope
52 |
53 |
54 | def cross_validation_fit(
55 | pairs: List[Tuple[int, int]], num_splits: int
56 | ) -> Dict[Quantiles, List[Result]]:
57 | all_results: Dict[Quantiles, List[Result]] = collections.defaultdict(list)
58 |
59 | # Perform n-fold cross-validation
60 | test_set_size = len(pairs) / num_splits # This is a real number
61 | random.Random(0).shuffle(pairs)
62 | for test_start, test_end in zip(
63 | np.arange(0, len(pairs), test_set_size),
64 | np.arange(test_set_size, len(pairs), test_set_size),
65 | ):
66 | # Round down first so that we get an integer size for the test set
67 | test_start = int(test_start)
68 | test_end = int(test_end)
69 |
70 | train_data = pairs[:test_start] + pairs[test_end:]
71 | test_data = pairs[test_start:test_end]
72 |
73 | for quantiles, intercept, slope in fit(train_data):
74 | predicted = [int(i * slope + intercept) for i, _ in test_data]
75 | num_missed = sum(p < o for p, (_, o) in zip(predicted, test_data))
76 | surplus_steps = [p - o for p, (_, o) in zip(predicted, test_data)]
77 | all_results[quantiles].append(Result(num_missed, surplus_steps))
78 |
79 | return all_results
80 |
81 |
82 | def filter_fit(
83 | all_results: Dict[Quantiles, List[Result]], max_unreachable: int
84 | ) -> List[Tuple[Quantiles, int, float]]:
85 | filtered_results: List[Tuple[Quantiles, int, float]] = []
86 | for q, results in all_results.items():
87 | total_unreachable = sum(r.num_unreachable for r in results)
88 | if total_unreachable <= max_unreachable:
89 | mean_surplus_steps = np.average(
90 | list(itertools.chain.from_iterable(r.surplus_steps for r in results))
91 | )
92 | filtered_results.append((q, total_unreachable, mean_surplus_steps))
93 | filtered_results.sort(key=lambda x: x[2])
94 |
95 | return filtered_results
96 |
97 |
98 | def compute_and_print_fit(
99 | pairs: List[Tuple[int, int]], num_splits: int, max_unreachable: int
100 | ) -> Tuple[float, float]:
101 | all_results: Dict[Quantiles, List[Result]] = cross_validation_fit(pairs, num_splits)
102 |
103 | filtered_results: List[Tuple[Quantiles, int, float]] = filter_fit(
104 | all_results, max_unreachable
105 | )
106 |
107 | print(f"Best params for total unreachable < {max_unreachable}:")
108 | for q, total_unreachable, mean_surplus_steps in filtered_results[:5]:
109 | print(
110 | f"Intercept quantile {q.intercept_quantile:.2f}, "
111 | f"slope quantile {q.slope_quantile:.2f}: "
112 | f"num unreachable = {total_unreachable}, "
113 | f"mean surplus steps = {mean_surplus_steps:.3g}"
114 | )
115 |
116 | min_surplus_steps = filtered_results[0][2]
117 | all_params_with_min_surplus_steps = [
118 | q
119 | for q, _, mean_surplus_steps in filtered_results
120 | if mean_surplus_steps == min_surplus_steps
121 | ]
122 | best_params = min(
123 | all_params_with_min_surplus_steps,
124 | key=lambda q: (q.intercept_quantile, q.slope_quantile),
125 | )
126 |
127 | if best_params.intercept_quantile == -1:
128 | intercept = 0
129 | else:
130 | intercept = np.quantile([o for _, o in pairs], best_params.intercept_quantile)
131 | slope = np.quantile(
132 | [(o - intercept) / i for i, o in pairs], best_params.slope_quantile
133 | )
134 |
135 | print()
136 | print(f"Final results: intercept = {intercept}, slope = {slope}")
137 | return (intercept, slope)
138 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/scfg/parser/rule.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import json
5 | from abc import ABC
6 | from dataclasses import dataclass, replace
7 | from typing import Dict, Set, Tuple, cast
8 |
9 | from semantic_parsing_with_constrained_lm.scfg.parser.token import (
10 | EmptyToken,
11 | NonterminalToken,
12 | OptionableSCFGToken,
13 | SCFGToken,
14 | TerminalToken,
15 | )
16 | from semantic_parsing_with_constrained_lm.scfg.parser.types import Expansion
17 |
18 | MAYBE_PREFIX = "maybe__"
19 |
20 |
21 | @dataclass(frozen=True)
22 | class Rule(ABC):
23 | lhs: str
24 |
25 |
26 | @dataclass(frozen=True)
27 | class PlanRule(Rule):
28 | lhs: str
29 | rhs: Expansion
30 |
31 |
32 | @dataclass(frozen=True)
33 | class SyncRule(Rule):
34 | lhs: str
35 | utterance_rhss: Tuple[Expansion, ...]
36 | plan_rhs: Expansion
37 |
38 |
39 | @dataclass(frozen=True)
40 | class UtteranceRule(Rule):
41 | lhs: str
42 | utterance_rhss: Tuple[Expansion, ...]
43 |
44 |
45 | # helpers
46 | def term(s: str, optional=False) -> TerminalToken:
47 | return TerminalToken(underlying=json.dumps(s), optional=optional)
48 |
49 |
50 | def nonterm(s: str, optional=False) -> NonterminalToken:
51 | return NonterminalToken(underlying=s, optional=optional)
52 |
53 |
54 | def mirrored_rule(lhs: str, rhs: Expansion) -> SyncRule:
55 | """Creates a SyncRule with identical utterance and plan expansions."""
56 | return SyncRule(lhs=lhs, utterance_rhss=(rhs,), plan_rhs=rhs)
57 |
58 |
59 | def expand_optionals(rule: Rule) -> Set[Rule]:
60 | """
61 | For each optional token `t` in rule, creates a fresh NT that expands to `t` or epsilon.
62 | """
63 |
64 | def clean(s: str) -> str:
65 | return "".join(c if c.isidentifier() else f"_chr{ord(c)}_" for c in s)
66 |
67 | def mk_optional_nt(s: OptionableSCFGToken) -> NonterminalToken:
68 | nt_or_t = "nt" if isinstance(s, NonterminalToken) else "t"
69 | return nonterm(f"{MAYBE_PREFIX}_{nt_or_t}_{clean(s.render())}")
70 |
71 | utterance_rhss = (
72 | rule.utterance_rhss if isinstance(rule, (SyncRule, UtteranceRule)) else ()
73 | )
74 | plan_rhss = (
75 | (rule.plan_rhs,)
76 | if isinstance(rule, SyncRule)
77 | else (rule.rhs,)
78 | if isinstance(rule, PlanRule)
79 | else ()
80 | )
81 | all_utterance_optionals: Set[OptionableSCFGToken] = {
82 | s
83 | for rhs in utterance_rhss
84 | for s in rhs
85 | if isinstance(s, OptionableSCFGToken) and s.optional
86 | }
87 | all_plan_optionals: Set[OptionableSCFGToken] = {
88 | s
89 | for rhs in plan_rhss
90 | for s in rhs
91 | if isinstance(s, OptionableSCFGToken) and s.optional
92 | }
93 | sync_optionals: Dict[SCFGToken, SCFGToken] = {
94 | s: mk_optional_nt(s)
95 | for s in all_utterance_optionals.intersection(all_plan_optionals)
96 | }
97 | utterance_only_optionals: Dict[SCFGToken, SCFGToken] = {
98 | s: mk_optional_nt(s)
99 | for s in all_utterance_optionals.difference(all_plan_optionals)
100 | }
101 | plan_only_optionals: Dict[SCFGToken, SCFGToken] = {
102 | s: mk_optional_nt(s)
103 | for s in all_plan_optionals.difference(all_utterance_optionals)
104 | }
105 |
106 | all_optionals: Dict[SCFGToken, SCFGToken] = {
107 | **utterance_only_optionals,
108 | **plan_only_optionals,
109 | **sync_optionals,
110 | }
111 | new_sync_rules: Set[Rule] = {
112 | r
113 | for s, nt in sync_optionals.items()
114 | for non_opt_s in [(replace(s, optional=False),)]
115 | for r in [
116 | mirrored_rule(nt.value, non_opt_s),
117 | mirrored_rule(nt.value, (EmptyToken(),)),
118 | ]
119 | }
120 | new_utt_rules: Set[Rule] = {
121 | r
122 | for s, nt in utterance_only_optionals.items()
123 | for non_opt_s in [(replace(s, optional=False),)]
124 | for r in [
125 | UtteranceRule(lhs=nt.value, utterance_rhss=(non_opt_s,)),
126 | UtteranceRule(lhs=nt.value, utterance_rhss=((EmptyToken(),),)),
127 | ]
128 | }
129 | new_plan_rules: Set[Rule] = {
130 | r
131 | for s, nt in plan_only_optionals.items()
132 | for non_opt_s in [(replace(s, optional=False),)]
133 | for r in [
134 | PlanRule(lhs=nt.value, rhs=non_opt_s),
135 | PlanRule(lhs=nt.value, rhs=(EmptyToken(),)),
136 | ]
137 | }
138 | new_maybe_rules: Set[Rule] = new_utt_rules | new_plan_rules | new_sync_rules
139 |
140 | def transform_rhs(rhs: Expansion) -> Expansion:
141 | return tuple(all_optionals.get(t, t) for t in rhs)
142 |
143 | new_main_rule: Rule
144 | if isinstance(rule, SyncRule):
145 | new_main_rule = SyncRule(
146 | lhs=rule.lhs,
147 | utterance_rhss=tuple(transform_rhs(rhs) for rhs in rule.utterance_rhss),
148 | plan_rhs=transform_rhs(rule.plan_rhs),
149 | )
150 | elif isinstance(rule, UtteranceRule):
151 | new_main_rule = UtteranceRule(
152 | lhs=rule.lhs,
153 | utterance_rhss=tuple(transform_rhs(rhs) for rhs in rule.utterance_rhss),
154 | )
155 | else:
156 | assert isinstance(rule, PlanRule), rule
157 | new_main_rule = PlanRule(lhs=rule.lhs, rhs=transform_rhs(rule.rhs))
158 | # pyright can't figure out that `new_main_rule` is a Rule
159 | return {cast(Rule, new_main_rule)} | new_maybe_rules
160 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/create_benchclamp_splits.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import random
5 | from collections import defaultdict
6 | from pathlib import Path
7 | from typing import Callable, Dict, List, Optional
8 |
9 | import torch
10 | from dataflow.core.io_utils import save_jsonl_file
11 |
12 | from semantic_parsing_with_constrained_lm.datum import BenchClampDatum, Datum, FullDatum
13 | from semantic_parsing_with_constrained_lm.decoding.partial_parse import PartialParse
14 | from semantic_parsing_with_constrained_lm.tokenization import GPT2ClampTokenizer
15 |
16 |
17 | def can_force_decode(
18 | clamp_tokenizer: GPT2ClampTokenizer,
19 | partial_parse_builder: Callable[[Datum], PartialParse],
20 | datum: FullDatum,
21 | ) -> bool:
22 | token_ids = clamp_tokenizer.encode(" " + datum.canonical)
23 | partial_parse = partial_parse_builder(datum) # type: ignore
24 | for i, token_id in enumerate(token_ids):
25 | next_tokens, can_end = partial_parse.allowed_next(
26 | torch.tensor([token_id]), top_k=1
27 | )
28 | if token_id not in next_tokens: # type: ignore
29 | print()
30 | print(f"Input: {repr(datum.natural)}")
31 | print(f"Output: {repr(datum.canonical)}")
32 | print(f"Prefix: {repr(clamp_tokenizer.decode(token_ids[:i]))}")
33 | print(f"Rejected: {clamp_tokenizer.decode([token_id])}")
34 | return False
35 | partial_parse = partial_parse.append(token_id)
36 | next_tokens, can_end = partial_parse.allowed_next(torch.tensor([0]), top_k=1)
37 | if not can_end:
38 | print()
39 | print(f"Input: {repr(datum.natural)}")
40 | print(f"Output: {repr(datum.canonical)}")
41 | print("Ending not allowed here")
42 | return False
43 |
44 | return True
45 |
46 |
47 | def gather_subset(data: List[BenchClampDatum], size: int) -> List[BenchClampDatum]:
48 | dialogue_id_to_turn_count: Dict[str, int] = defaultdict(int)
49 | for datum in data:
50 | dialogue_id_to_turn_count[datum.dialogue_id] += 1 # type: ignore
51 |
52 | dialogue_ids = sorted(dialogue_id_to_turn_count.keys())
53 | random.shuffle(dialogue_ids)
54 | selected_dialogue_ids = []
55 | num_turns_covered = 0
56 | for dialogue_id in dialogue_ids:
57 | selected_dialogue_ids.append(dialogue_id)
58 | num_turns_covered += dialogue_id_to_turn_count[dialogue_id]
59 | if num_turns_covered >= size:
60 | break
61 |
62 | if num_turns_covered < size:
63 | print(f"Not enough data to create subset of size {size}")
64 |
65 | selected_dialogue_ids_set = set(selected_dialogue_ids)
66 | return [datum for datum in data if datum.dialogue_id in selected_dialogue_ids_set]
67 |
68 |
69 | def create_benchclamp_splits(
70 | train_data: List[BenchClampDatum],
71 | dev_data: List[BenchClampDatum],
72 | test_data: Optional[List[BenchClampDatum]],
73 | output_dir: Path,
74 | ):
75 | """
76 | Sample splits for BenchClamp experiments.
77 | 1. 5 low data train splits of size 500, single dev set of size 50
78 | 2. 3 medium data train splits of size 5000, single dev set of size 500
79 | 3. Full data split. Reuses dev set for medium data.
80 | 4. Single test split of size 2000.
81 | """
82 | random.seed(0)
83 | if test_data is None:
84 | train_dialogue_ids = sorted({datum.dialogue_id for datum in train_data}) # type: ignore
85 | random.shuffle(train_dialogue_ids)
86 | test_data = dev_data
87 | num_train_dialogues = len(train_dialogue_ids)
88 | dev_dialogue_ids = set(train_dialogue_ids[: int(0.1 * num_train_dialogues)])
89 | dev_data = [
90 | datum for datum in train_data if datum.dialogue_id in dev_dialogue_ids
91 | ]
92 | train_data = [
93 | datum for datum in train_data if datum.dialogue_id not in dev_dialogue_ids
94 | ]
95 |
96 | print(
97 | f"Input sizes for creating benchclamp splits: "
98 | f"Train {len(train_data)}, Dev {len(dev_data)}, Test {len(test_data)}"
99 | )
100 | train_turn_ids = [
101 | (datum.dialogue_id, datum.turn_part_index) for datum in train_data
102 | ]
103 | dev_turn_ids = [(datum.dialogue_id, datum.turn_part_index) for datum in dev_data]
104 | test_turn_ids = [(datum.dialogue_id, datum.turn_part_index) for datum in test_data]
105 | assert (
106 | len(set(train_turn_ids)) == len(train_turn_ids)
107 | and len(set(dev_turn_ids)) == len(dev_turn_ids)
108 | and len(set(test_turn_ids)) == len(test_turn_ids)
109 | ), "Multiple data points have same data id, make sure all input data have unique data ids."
110 |
111 | output_dir.mkdir(parents=True, exist_ok=True)
112 | save_jsonl_file(test_data, str(output_dir / "test_all.jsonl"))
113 | save_jsonl_file(gather_subset(test_data, 2000), str(output_dir / "test.jsonl"))
114 |
115 | save_jsonl_file(dev_data, str(output_dir / "dev_all.jsonl"))
116 | save_jsonl_file(gather_subset(dev_data, 500), str(output_dir / "dev_medium.jsonl"))
117 | save_jsonl_file(gather_subset(dev_data, 50), str(output_dir / "dev_low.jsonl"))
118 |
119 | save_jsonl_file(train_data, str(output_dir / "train_all.jsonl"))
120 | for split_size_category, num_splits, train_split_size in [
121 | ("low", 5, 500),
122 | ("medium", 3, 5000),
123 | ]:
124 | for split_id in range(num_splits):
125 | train_subset_file = (
126 | output_dir / f"train_{split_size_category}_{split_id}.jsonl"
127 | )
128 | train_subset = gather_subset(train_data, train_split_size)
129 | save_jsonl_file(train_subset, str(train_subset_file))
130 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/domains/ltl/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import json
5 | from dataclasses import dataclass
6 | from enum import Enum
7 | from typing import Callable, Dict, List
8 |
9 | from blobfile import BlobFile
10 |
11 | from semantic_parsing_with_constrained_lm.util.trie import Trie
12 | from semantic_parsing_with_constrained_lm.util.types import StrPath
13 | from semantic_parsing_with_constrained_lm.datum import Datum, FullDatum
14 | from semantic_parsing_with_constrained_lm.decoding.trie_partial_parse import TriePartialParse
15 |
16 | from semantic_parsing_with_constrained_lm.eval import TopKExactMatch
17 | from semantic_parsing_with_constrained_lm.tokenization import ClampTokenizer
18 |
19 | # NOTE: get rid of the calflow/dataflow dependency
20 | from appdirs import user_cache_dir
21 | CACHE_DIR = user_cache_dir("semantic_parsing_as_constrained_lm")
22 |
23 |
24 | class LTLOutputType(str, Enum):
25 | Utterance = "utterance"
26 | MeaningRepresentation = "meaningRepresentation"
27 |
28 |
29 | @dataclass
30 | class TopKDenotationMatch(TopKExactMatch[FullDatum]):
31 | canonical_to_denotation: Dict[str, str]
32 |
33 | def _is_correct(self, pred: str, datum: FullDatum) -> bool:
34 | target = datum.canonical
35 | pred_denotation = self.canonical_to_denotation.get(pred)
36 | target_denotation = self.canonical_to_denotation.get(target, None)
37 | if pred_denotation is None and target_denotation is None:
38 | return pred == target
39 | else:
40 | return pred_denotation == target_denotation
41 |
42 |
43 | @dataclass
44 | class LTLPieces:
45 | train_data: List[FullDatum]
46 | test_data: List[FullDatum]
47 | partial_parse_builder: Callable[[Datum], TriePartialParse]
48 | denotation_metric: TopKDenotationMatch
49 | max_length: int
50 |
51 | @staticmethod
52 | def from_dir(
53 | tokenizer: ClampTokenizer,
54 | root_dir: StrPath,
55 | domain: str,
56 | is_dev: bool,
57 | k: int,
58 | output_type: LTLOutputType = LTLOutputType.Utterance,
59 | simplify_logical_forms=False,
60 | prefix_with_space=False,
61 | ) -> "LTLPieces":
62 | data_pieces = LTLDataPieces.from_dir(
63 | root_dir, domain, is_dev, output_type, simplify_logical_forms
64 | )
65 | decoder_pieces = LTLDecoderPieces.create(
66 | data_pieces, tokenizer, k, prefix_with_space
67 | )
68 |
69 | return LTLPieces(
70 | data_pieces.train_data,
71 | data_pieces.test_data,
72 | # https://github.com/python/mypy/issues/5485
73 | decoder_pieces.partial_parse_builder, # type: ignore
74 | decoder_pieces.denotation_metric,
75 | decoder_pieces.max_length,
76 | )
77 |
78 |
79 | @dataclass
80 | class LTLDataPieces:
81 | train_data: List[FullDatum]
82 | test_data: List[FullDatum]
83 | target_output_to_denotation: Dict[str, str]
84 |
85 | @staticmethod
86 | def from_dir(
87 | root_dir: StrPath,
88 | domain: str,
89 | is_dev: bool,
90 | output_type: LTLOutputType = LTLOutputType.MeaningRepresentation,
91 | simplify_logical_forms: bool = False,
92 | ) -> "LTLDataPieces":
93 | with BlobFile(str(root_dir) + f"/{domain}.canonical.json") as bf:
94 | canonical_data = json.load(bf)
95 |
96 | if output_type == LTLOutputType.Utterance:
97 | target_output_to_denotation = {
98 | k: "DO NOT NEED" for k, v in canonical_data.items()
99 | }
100 | datum_key = "canonical"
101 | else:
102 | raise ValueError(output_type)
103 |
104 | train_data, test_data = [
105 | [
106 | FullDatum(
107 | dialogue_id=f"{dataset_name}-{i}",
108 | turn_part_index=None,
109 | agent_context=None,
110 | natural=d["natural"],
111 | canonical=d[datum_key]
112 | if simplify_logical_forms
113 | else d[datum_key],
114 | )
115 | for i, line in enumerate(
116 | BlobFile(path, streaming=False, cache_dir=CACHE_DIR)
117 | )
118 | for d in [json.loads(line)]
119 | ]
120 | for dataset_name, path in (
121 | (
122 | "train",
123 | f"{root_dir}/{domain}.train.jsonl",
124 | ),
125 | ("eval", f"{root_dir}/{domain}.test.jsonl"),
126 | )
127 | ]
128 |
129 | return LTLDataPieces(train_data, test_data, target_output_to_denotation)
130 |
131 |
132 | @dataclass
133 | class LTLDecoderPieces:
134 | data_pieces: LTLDataPieces
135 | partial_parse_builder: Callable[[Datum], TriePartialParse]
136 | denotation_metric: TopKDenotationMatch
137 | max_length: int
138 |
139 | @staticmethod
140 | def create(
141 | data_pieces: LTLDataPieces,
142 | tokenizer: ClampTokenizer,
143 | k: int,
144 | prefix_with_space: bool = False,
145 | ) -> "LTLDecoderPieces":
146 | if prefix_with_space:
147 | canonical_trie = Trie(
148 | tokenizer.encode(" " + canon)
149 | for canon in data_pieces.target_output_to_denotation
150 | )
151 | else:
152 | canonical_trie = Trie(
153 | tokenizer.encode(canon)
154 | for canon in data_pieces.target_output_to_denotation
155 | )
156 |
157 | def partial_parse_builder(_): return TriePartialParse(canonical_trie)
158 |
159 | denotation_metric = TopKDenotationMatch(
160 | k, data_pieces.target_output_to_denotation
161 | )
162 | max_length = max(len(x) for x in canonical_trie)
163 |
164 | return LTLDecoderPieces(
165 | data_pieces, partial_parse_builder, denotation_metric, max_length
166 | )
167 |
--------------------------------------------------------------------------------
/run/semantic_parsing_with_constrained_lm/scfg/parser/parse.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from pathlib import Path
5 | from typing import Any, List, Tuple, Union, cast
6 |
7 | from lark import Lark, Transformer, v_args
8 | from lark.exceptions import UnexpectedEOF # type: ignore[attr-defined]
9 |
10 | from semantic_parsing_with_constrained_lm.scfg.parser.macro import Macro
11 | from semantic_parsing_with_constrained_lm.scfg.parser.rule import (
12 | PlanRule,
13 | Rule,
14 | SyncRule,
15 | UtteranceRule,
16 | mirrored_rule,
17 | )
18 | from semantic_parsing_with_constrained_lm.scfg.parser.token import (
19 | EmptyToken,
20 | MacroToken,
21 | NonterminalToken,
22 | OptionableSCFGToken,
23 | RegexToken,
24 | SCFGToken,
25 | TerminalToken,
26 | )
27 | from semantic_parsing_with_constrained_lm.scfg.parser.types import Expansion
28 |
29 |
30 | def parse_string(parser: Lark, string: str) -> Any: # -> Union[Rule, Macro]:
31 | """
32 | Parse a string into a Rule or an Expansion.
33 |
34 | We annotate the return type as Any because this can return a Rule, Macro, or Expansion, and
35 | it seems silly to do a cast at each call site.
36 | """
37 | try:
38 | return TreeToRule().transform(parser.parse(string))
39 | except UnexpectedEOF as e:
40 | raise Exception(f"Line could not be parsed: {string}") from e
41 |
42 |
43 | class TreeToRule(Transformer):
44 | @v_args(inline=True)
45 | def start(self, arg) -> Union[Macro, Rule]:
46 | return arg
47 |
48 | @v_args(inline=True)
49 | def start_for_test(self, arg) -> Union[Macro, Expansion, Rule]:
50 | return arg
51 |
52 | @v_args(inline=True)
53 | def terminal(self, underlying) -> TerminalToken:
54 | return TerminalToken(underlying.value, optional=False)
55 |
56 | @v_args(inline=True)
57 | def optional_terminal(self, underlying) -> TerminalToken:
58 | return TerminalToken(underlying.value, optional=True)
59 |
60 | @v_args(inline=True)
61 | def nonterminal(self, name_token) -> NonterminalToken:
62 | return NonterminalToken(name_token.value, optional=False)
63 |
64 | @v_args(inline=True)
65 | def optional_nonterminal(self, name_token) -> NonterminalToken:
66 | return NonterminalToken(name_token.value, optional=True)
67 |
68 | @v_args(inline=True)
69 | def empty(self) -> EmptyToken:
70 | return EmptyToken()
71 |
72 | @v_args(inline=True)
73 | def regex(self, arg) -> RegexToken:
74 | return RegexToken(arg.value, optional=False, prefix="")
75 |
76 | def plan_expansion(self, args) -> Expansion:
77 | return tuple(args)
78 |
79 | def utterance_expansion(self, args) -> Expansion:
80 | return tuple(args)
81 |
82 | def utterance_expansions(self, no_macro_expansions) -> Tuple[Expansion, ...]:
83 | return tuple(no_macro_expansions)
84 |
85 | @v_args(inline=True)
86 | def token(self, arg: SCFGToken) -> SCFGToken:
87 | return arg
88 |
89 | @v_args(inline=True)
90 | def rule(self, name_token) -> str:
91 | return name_token.value
92 |
93 | @v_args(inline=True)
94 | def sync_rule(
95 | self, lhs: str, expansions: List[Expansion], expansion: Expansion
96 | ) -> SyncRule:
97 | return SyncRule(lhs=lhs, utterance_rhss=tuple(expansions), plan_rhs=expansion)
98 |
99 | @v_args(inline=True)
100 | def mirrored_rule(self, lhs: str, rhs: Expansion) -> SyncRule:
101 | return mirrored_rule(lhs, rhs)
102 |
103 | @v_args(inline=True)
104 | def utterance_rule(self, rule, expansions) -> UtteranceRule:
105 | return UtteranceRule(rule, tuple(expansions))
106 |
107 | @v_args(inline=True)
108 | def macro_rule(self, macro_def, expansion) -> Macro:
109 | return Macro(macro_def[0], macro_def[1], expansion)
110 |
111 | def macro_def(self, args) -> Tuple[str, Tuple[str, ...]]:
112 | return cast(str, args[0].value), tuple(cast(str, a.value) for a in args[1:])
113 |
114 | def macro_apply(self, args) -> MacroToken:
115 | return MacroToken(args[0].value, tuple(args[1:]))
116 |
117 |
118 | def get_scfg_parser(start_symbol: str = "start") -> Lark:
119 | """
120 | Get a parser based on the SCFG grammar. The start rule that gets appended to the grammar
121 | at the end depends on whether we are testing or not. If we are testing, then we want to be
122 | able to parse expansions outside of rules so that in our tests, we don't have to write
123 | lists of tokens.
124 | """
125 | scfg_grammar_path = Path(__file__).parent / "scfg_grammar.lark"
126 | scfg_grammar: str
127 | with open(scfg_grammar_path, "r") as cf_grammar_file:
128 | scfg_grammar = cf_grammar_file.read()
129 |
130 | # Type ignoring because mypy doesn't play well with Lark.
131 | return Lark(scfg_grammar, ambiguity="explicit", start=start_symbol) # type: ignore
132 |
133 |
134 | # RENDERING
135 |
136 |
137 | def render_token(token: SCFGToken) -> str:
138 | if isinstance(token, EmptyToken):
139 | return "#e"
140 | elif isinstance(token, OptionableSCFGToken):
141 | optional_str = "?" if token.optional else ""
142 | value = token.lark_value if isinstance(token, RegexToken) else token.value
143 | return value + optional_str
144 | else:
145 | assert isinstance(token, MacroToken)
146 | return token.value
147 |
148 |
149 | def render_expansion(rhs: Expansion) -> str:
150 | return " ".join(render_token(t) for t in rhs)
151 |
152 |
153 | def render_expansions(expansions: Tuple[Expansion, ...]) -> str:
154 | return " | ".join(render_expansion(rhs) for rhs in expansions)
155 |
156 |
157 | def render_rule(rule: Union[Rule, Macro]) -> str:
158 | if isinstance(rule, Macro):
159 | arg_str = f"({', '.join(rule.args)})" if rule.args else ""
160 | return f"{rule.name}{arg_str} 2> {render_expansion(rule.expansion)}"
161 | elif isinstance(rule, PlanRule):
162 | return f"{rule.lhs} 2> {render_expansion(rule.rhs)}"
163 | elif isinstance(rule, UtteranceRule):
164 | return f"{rule.lhs} 1> {render_expansions(rule.utterance_rhss)}"
165 | else:
166 | assert isinstance(rule, SyncRule)
167 | return f"{rule.lhs} -> {render_expansions(rule.utterance_rhss)} , {render_expansion(rule.plan_rhs)}"
168 |
--------------------------------------------------------------------------------
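A minimal sketch of how these pieces fit together, assuming a rule written in the same `lhs -> utterance , plan` syntax that render_rule emits (the exact terminal syntax lives in scfg_grammar.lark, so the literal rule string below is only illustrative):

    from semantic_parsing_with_constrained_lm.scfg.parser.parse import (
        get_scfg_parser,
        parse_string,
        render_rule,
    )

    parser = get_scfg_parser("start")
    # A hypothetical sync rule: the utterance expansion left of the comma,
    # the plan expansion to its right.
    rule = parse_string(parser, 'greeting -> "hello" , "hi there"')
    print(render_rule(rule))  # renders back in the same `lhs -> ... , ...` form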
/run/semantic_parsing_with_constrained_lm/domains/calflow/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import dataclasses
5 | import json
6 | import signal
7 | from dataclasses import dataclass
8 | from enum import Enum
9 | from itertools import islice
10 | from operator import itemgetter
11 | from typing import Dict, Iterable, List, Optional, Set
12 |
13 | from dataflow.core.lispress import (
14 | lispress_to_program,
15 | parse_lispress,
16 | program_to_lispress,
17 | render_compact,
18 | try_round_trip,
19 | )
20 | from lark import GrammarError, ParseError, UnexpectedCharacters
21 |
22 | from semantic_parsing_with_constrained_lm.util.types import StrPath
23 | from semantic_parsing_with_constrained_lm.scfg.generate import parse_and_render
24 | from semantic_parsing_with_constrained_lm.scfg.scfg import SCFG
25 | from semantic_parsing_with_constrained_lm.datum import FullDatum
26 | from semantic_parsing_with_constrained_lm.domains.calflow.disambiguate import score_auto_grammar_plan
27 | from semantic_parsing_with_constrained_lm.eval import TopKExactMatch
28 |
29 |
30 | class CalflowOutputLanguage(str, Enum):
31 | Canonical = "canonicalUtterance"
32 | Lispress = "lispress"
33 |
34 |
35 | @dataclass(frozen=True)
36 | class CalflowDatum(FullDatum):
37 | lispress: str
38 |
39 |
40 | def read_calflow_jsonl(
41 | filename: StrPath,
42 | model_output_type: CalflowOutputLanguage,
43 | whitelisted_dialogue_ids: Optional[Set[str]] = None,
44 | ) -> List[CalflowDatum]:
45 | """
46 |     Reads a list of CalflowDatum from `filename`, with `canonical` populated according to `model_output_type`.
47 |     Keeps only dialogues whose ids are in `whitelisted_dialogue_ids` when it is set; otherwise reads all data.
48 | """
49 | with open(filename) as test_file:
50 | return [
51 | CalflowDatum(
52 | agent_context=json_line.get("context", ""),
53 | natural=json_line["utterance"],
54 | canonical=json_line[model_output_type],
55 | dialogue_id=json_line["dialogueId"],
56 | turn_part_index=json_line["turnIndex"],
57 | lispress=json_line["lispress"],
58 | )
59 | for line in test_file
60 | for json_line in [json.loads(line)]
61 | if whitelisted_dialogue_ids is None
62 | or json_line["dialogueId"] in whitelisted_dialogue_ids
63 | ]
64 |
65 |
66 | def predict_plan_from_canonical(
67 | scfg: SCFG,
68 | utterance: str,
69 | k: int = 1000,
70 | max_depth: int = 15,
71 | fallback_plan: str = " (FenceScope)",
72 | ) -> str:
73 | """
74 | Predicts a single Lispress surface string from the given canonical
75 | `utterance`.
76 | Finds possible parses using `scfg`, truncates to the top `k`, then picks
77 | the highest scoring possible parse under `score_auto_grammar_plan`.
78 | """
79 | unscored_plans: Iterable[str]
80 | try:
81 | unscored_plans = parse_and_render(
82 | scfg, utterance, source_is_plan=False, max_depth=max_depth
83 | )
84 | except (GrammarError, UnexpectedCharacters, ParseError, AttributeError):
85 | unscored_plans = []
86 | try:
87 | scored_plans = (
88 | (plan, score_auto_grammar_plan(plan)) for plan in unscored_plans
89 | )
90 | filtered_plans = (
91 | (plan, score) for plan, score in scored_plans if score != -float("inf")
92 | )
93 | truncated_plans = islice(filtered_plans, k)
94 | except AttributeError:
95 | truncated_plans = iter([])
96 | try:
97 | best_plan, _best_score = max(truncated_plans, key=itemgetter(1))
98 | except ValueError:
99 | # no candidates
100 | best_plan = fallback_plan
101 |
102 | # Remove leading space
103 | # TODO: Remove need for this by removing the space from the grammar
104 | assert len(best_plan) > 0 and best_plan[0] == " "
105 | best_plan = best_plan[1:]
106 |
107 | return best_plan
108 |
109 |
110 | class ParseTimeout(Exception):
111 | pass
112 |
113 |
114 | @dataclass
115 | class CalflowMetrics(TopKExactMatch[CalflowDatum]):
116 | scfg: SCFG
117 | data_type: CalflowOutputLanguage
118 |
119 | cached_parses: Dict[str, str] = dataclasses.field(default_factory=dict)
120 |
121 | # Only attempt to convert predictions with the same length as the gold,
122 | # which saves a lot of time with parsing.
123 | require_exact_length: bool = False
124 |
125 | @staticmethod
126 | def parse_timeout_handler(sig, frame):
127 | raise ParseTimeout
128 |
129 | def cached_parse(self, pred: str, gold: Optional[str]) -> str:
130 | """
131 | Given a canonical utterance, convert it into a lispress plan.
132 | """
133 | if pred in self.cached_parses:
134 | return self.cached_parses[pred]
135 | if self.require_exact_length and (gold is None or len(pred) != len(gold)):
136 | return "(FenceScope)"
137 |
138 | prev_signal = signal.getsignal(signal.SIGALRM)
139 | if prev_signal == signal.SIG_DFL:
140 | signal.signal(signal.SIGALRM, self.parse_timeout_handler)
141 | signal.alarm(300)
142 |
143 |         try:
144 |             predicted_plan = predict_plan_from_canonical(self.scfg, " " + pred)
145 |         except Exception as e:  # pylint: disable=broad-except
146 |             print(e)
147 |             predicted_plan = "(FenceScope)"
148 |         finally:
149 |             signal.alarm(0)  # cancel any pending alarm before restoring the handler
150 |             signal.signal(signal.SIGALRM, prev_signal)
151 |
152 | self.cached_parses[pred] = predicted_plan
153 | return predicted_plan
154 |
155 | def _is_correct(self, pred: str, target: CalflowDatum) -> bool:
156 | if self.data_type == CalflowOutputLanguage.Canonical:
157 | predicted_plan = self.cached_parse(pred, target.canonical)
158 | else:
159 | try:
160 | # round-trip to canonicalize
161 | predicted_plan = render_compact(
162 | program_to_lispress(lispress_to_program(parse_lispress(pred), 0)[0])
163 | )
164 | except Exception: # pylint: disable=W0703
165 | predicted_plan = "(FenceScope)"
166 |
167 | return try_round_trip(target.lispress) == try_round_trip(predicted_plan)
168 |
--------------------------------------------------------------------------------
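CalflowMetrics.cached_parse above guards the potentially slow canonical-to-Lispress conversion with a SIGALRM alarm. The same pattern, isolated as a self-contained sketch (Unix-only; `slow_parse` is a hypothetical stand-in for predict_plan_from_canonical):

    import signal

    class ParseTimeout(Exception):
        pass

    def _timeout_handler(sig, frame):
        raise ParseTimeout

    def parse_with_timeout(slow_parse, text, seconds=300, fallback="(FenceScope)"):
        prev_handler = signal.signal(signal.SIGALRM, _timeout_handler)
        signal.alarm(seconds)  # deliver SIGALRM after `seconds` if still running
        try:
            return slow_parse(text)
        except ParseTimeout:
            return fallback
        finally:
            signal.alarm(0)  # cancel any pending alarm
            signal.signal(signal.SIGALRM, prev_handler)  # restore the previous handler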
/run/semantic_parsing_with_constrained_lm/domains/sql/create_benchclamp_data.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import dataclasses
5 | from pathlib import Path
6 | from typing import Dict, List
7 |
8 | import tqdm
9 | from transformers import GPT2Tokenizer
10 |
11 | from semantic_parsing_with_constrained_lm.scfg.scfg import SCFG
12 | from semantic_parsing_with_constrained_lm.datum import BenchClampDatum
13 | from semantic_parsing_with_constrained_lm.decoding.earley_partial_parse import (
14 | GrammarTokenizerInfo,
15 | UTF8EarleyPartialParse,
16 | )
17 | from semantic_parsing_with_constrained_lm.domains.benchclamp_data_setup import BenchClampDataset
18 | from semantic_parsing_with_constrained_lm.domains.create_benchclamp_splits import (
19 | can_force_decode,
20 | create_benchclamp_splits,
21 | )
22 | from semantic_parsing_with_constrained_lm.domains.sql.cosql.dialogue import (
23 | canonicalize_sql_with_grammar,
24 | convert_cosql_to_datum_format,
25 | convert_spider_to_datum_format,
26 | load_cosql_data,
27 | load_spider_data,
28 | )
29 | from semantic_parsing_with_constrained_lm.domains.sql.cosql.grammar import (
30 | load_base_grammar,
31 | preprocessed_grammar_for_schema,
32 | )
33 | from semantic_parsing_with_constrained_lm.domains.sql.cosql.schema import DbSchema, load_schemas
34 | from semantic_parsing_with_constrained_lm.domains.sql.sql_datum import SqlDatum
35 | from semantic_parsing_with_constrained_lm.paths import (
36 | BENCH_CLAMP_PROCESSED_DATA_DIR,
37 | BENCH_CLAMP_RAW_DATA_DIR,
38 | )
39 | from semantic_parsing_with_constrained_lm.tokenization import GPT2ClampTokenizer
40 |
41 |
42 | def write_data_and_test_grammar(
43 | train_data: List[BenchClampDatum],
44 | dev_data: List[BenchClampDatum],
45 | schemas: Dict[str, DbSchema],
46 | datum_output_dir: Path,
47 | ) -> None:
48 | base_grammar = load_base_grammar()
49 | pre_grammars = {
50 | name: preprocessed_grammar_for_schema(db, base_grammar)
51 | for name, db in schemas.items()
52 | }
53 | grammars = {name: SCFG(pg) for name, pg in pre_grammars.items()}
54 | train_data_with_canonical_sql: List[BenchClampDatum] = []
55 | dev_data_with_canonical_sql: List[BenchClampDatum] = []
56 | print("Canonicalizing SQL ...")
57 | for data, data_with_canonical_sql in [
58 | (train_data, train_data_with_canonical_sql),
59 | (dev_data, dev_data_with_canonical_sql),
60 | ]:
61 | for datum in tqdm.tqdm(data):
62 | grammar = grammars[datum.schema_name]
63 | data_with_canonical_sql.append(
64 | dataclasses.replace(
65 | datum, plan=canonicalize_sql_with_grammar(datum.plan, grammar)
66 | )
67 | )
68 |
69 | # Create data splits
70 | print("Creating data splits ...")
71 | create_benchclamp_splits(
72 | train_data_with_canonical_sql,
73 | dev_data_with_canonical_sql,
74 | None,
75 | datum_output_dir,
76 | )
77 |
78 | print("Testing ...")
79 | clamp_tokenizer = GPT2ClampTokenizer(GPT2Tokenizer.from_pretrained("gpt2"))
80 | grammar_tok_info = {
81 | name: GrammarTokenizerInfo.create(clamp_tokenizer, preprocessed_grammar, True)
82 | for name, preprocessed_grammar in pre_grammars.items()
83 | }
84 | partial_parse_builder = lambda datum: UTF8EarleyPartialParse.initial(
85 | grammar_tok_info[datum.schema_name], datum.natural
86 | )
87 | total = 0
88 | wrong = 0
89 | print("Testing if force decoding possible for first 100 examples")
90 | for datum in train_data_with_canonical_sql:
91 | if total >= 100:
92 | break
93 | total += 1
94 | if not can_force_decode(
95 | clamp_tokenizer,
96 | partial_parse_builder,
97 | SqlDatum(
98 | natural=datum.utterance,
99 | canonical=datum.plan,
100 | dialogue_id="",
101 | turn_part_index=0,
102 | agent_context="",
103 | schema_name=datum.schema_name, # type: ignore
104 | ),
105 | ):
106 | print(f"Error: {datum.plan}")
107 | print(f"Schema: {datum.schema_name}")
108 | print(f"Utterance: {datum.utterance}")
109 | print()
110 | wrong += 1
111 |
112 | print(f"Force Decode Errors %: {wrong} / {total}")
113 |
114 |
115 | def main():
116 | raw_spider_dir = BENCH_CLAMP_RAW_DATA_DIR / BenchClampDataset.Spider.value
117 | spider_schemas = load_schemas(
118 | schemas_path=raw_spider_dir / "tables.json", db_path=raw_spider_dir / "database"
119 | )
120 | spider_train, train_others, spider_dev = [
121 | convert_spider_to_datum_format(
122 | load_spider_data(raw_spider_dir / fn),
123 | db_map=spider_schemas,
124 | db_path=str(raw_spider_dir / "database"),
125 | )
126 | for fn in ["train_spider.json", "train_others.json", "dev.json"]
127 | ]
128 | spider_train.extend(
129 | [
130 | dataclasses.replace(datum, dialogue_id=f"other-{datum.dialogue_id}")
131 | for datum in train_others
132 | ]
133 | )
134 | write_data_and_test_grammar(
135 | train_data=spider_train,
136 | dev_data=spider_dev,
137 | schemas=spider_schemas,
138 | datum_output_dir=BENCH_CLAMP_PROCESSED_DATA_DIR
139 | / BenchClampDataset.Spider.value,
140 | )
141 |
142 | raw_cosql_dir = BENCH_CLAMP_RAW_DATA_DIR / BenchClampDataset.CoSQL.value
143 | cosql_schemas = load_schemas(
144 | schemas_path=raw_cosql_dir / "tables.json", db_path=raw_cosql_dir / "database"
145 | )
146 | cosql_train, cosql_dev = [
147 | convert_cosql_to_datum_format(
148 | load_cosql_data(raw_cosql_dir / "sql_state_tracking" / fn),
149 | db_map=cosql_schemas,
150 | db_path=str(raw_cosql_dir / "database"),
151 | )
152 | for fn in ["cosql_train.json", "cosql_dev.json"]
153 | ]
154 | write_data_and_test_grammar(
155 | train_data=cosql_train,
156 | dev_data=cosql_dev,
157 | schemas=cosql_schemas,
158 | datum_output_dir=BENCH_CLAMP_PROCESSED_DATA_DIR / BenchClampDataset.CoSQL.value,
159 | )
160 |
161 |
162 | if __name__ == "__main__":
163 | main()
164 |
--------------------------------------------------------------------------------
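can_force_decode (imported from create_benchclamp_splits and not shown in this file) checks whether the gold SQL can be produced at all under the grammar constraint. A hedged sketch of that kind of check, written against the PartialParse interface used above; the helper name and exact call pattern are illustrative, not the repo's implementation:

    import torch

    def sketch_can_force_decode(tokenizer, partial_parse_builder, datum) -> bool:
        """Walk the gold plan token by token, verifying each token stays in-grammar."""
        partial_parse = partial_parse_builder(datum)
        for token_id in tokenizer.encode(datum.canonical):
            allowed, _ = partial_parse.allowed_next(
                torch.tensor([token_id], dtype=torch.long), top_k=1
            )
            if allowed is None or token_id not in allowed.tolist():
                return False
            partial_parse = partial_parse.append(token_id)
        # After consuming every gold token, the parse must be in an accepting state.
        _, can_end = partial_parse.allowed_next(torch.tensor([], dtype=torch.long), top_k=0)
        return can_end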
/run/semantic_parsing_with_constrained_lm/eval.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import csv
5 | import dataclasses
6 | from abc import ABC, abstractmethod
7 | from contextlib import AbstractContextManager
8 | from dataclasses import dataclass
9 | from pathlib import Path
10 | from typing import Dict, Generic, List, Optional, Sequence, TextIO, TypeVar
11 |
12 | from semantic_parsing_with_constrained_lm.datum import FullDatum, FullDatumSub
13 | from semantic_parsing_with_constrained_lm.model import ModelResult
14 |
15 | Pred = TypeVar("Pred")
16 | Target = TypeVar("Target")
17 |
18 |
19 | # TODO: Replace this with a more flexible function suited to each domain
20 | def exact_match_with_logging(
21 | test_datum: FullDatum, kbest: Sequence[ModelResult]
22 | ) -> bool:
23 | gold = (
24 | test_datum.canonical.strip(" ")
25 | if test_datum.canonical is not None
26 | else "UNREACHABLE"
27 | )
28 | pred = kbest[0].text.strip(" ") if kbest else ""
29 | print()
30 | print(f"context: {test_datum.agent_context}")
31 | print(f"natural: {test_datum.natural}")
32 | print(f"predicted: {pred}")
33 | print(f"gold: {gold}")
34 | result = gold == pred
35 | print(f"is correct: {result}")
36 | beam_result = False
37 | for i, pred_i in enumerate(kbest):
38 | stripped = pred_i.text.strip(" ")
39 | beam_result = beam_result or gold == stripped
40 | print(f"Beam {i} [{pred_i.cost:.3f}]: {stripped}")
41 | print(f"is correct@{i}: {beam_result}")
42 | print()
43 | return result
44 |
45 |
46 | class Metric(Generic[Pred, Target], ABC):
47 | """Used to measure goodness of model results compared to the ground truth.
48 |
49 | Stateful over the duration of an experiment run."""
50 |
51 | @abstractmethod
52 | def update(self, pred: Pred, target: Target) -> Dict[str, Optional[str]]:
53 | """Uses `target` and the model predictions `pred` to update the state."""
54 | pass
55 |
56 | @abstractmethod
57 | def compute(self) -> Dict[str, float]:
58 | """Uses the state to compute the final results."""
59 | pass
60 |
61 | @abstractmethod
62 | def reset(self) -> None:
63 | """Reinitializes the state."""
64 | pass
65 |
66 |
67 | @dataclass
68 | class TopKExactMatch(Metric[Sequence[str], FullDatumSub]):
69 | k: int
70 | correct: List[int] = dataclasses.field(init=False)
71 | total: int = dataclasses.field(init=False)
72 |
73 | def __post_init__(self):
74 | self.reset()
75 |
76 | def _is_correct(self, pred: str, target: FullDatumSub) -> bool:
77 | """Can be overridden by child classes."""
78 | return pred == target.canonical
79 |
80 | def update(
81 | self, preds: Sequence[str], target: FullDatumSub
82 | ) -> Dict[str, Optional[str]]:
83 | self.total += 1
84 | found_correct = False
85 | result: Dict[str, Optional[str]] = {}
86 | for i, pred in enumerate(preds[: self.k]):
87 | correct = self._is_correct(pred, target)
88 | found_correct |= correct
89 | self.correct[i] += found_correct
90 | result[f"rank{i + 1}"] = "correct" if correct else "incorrect"
91 | result[f"top{i + 1}"] = "correct" if found_correct else "incorrect"
92 |
93 | # Handle when we have fewer predictions than self.k
94 | for i in range(len(preds), self.k):
95 | self.correct[i] += found_correct
96 | result[f"rank{i + 1}"] = "incorrect"
97 | result[f"top{i + 1}"] = "correct" if found_correct else "incorrect"
98 |
99 | return result
100 |
101 | def compute(self) -> Dict[str, float]:
102 | result = {}
103 | for i in range(self.k):
104 | result[f"top{i + 1}"] = self.correct[i] / self.total
105 | return result
106 |
107 | def reset(self) -> None:
108 | self.correct = [0] * self.k
109 | self.total = 0
110 |
111 |
112 | class Logger(Generic[Pred, Target], AbstractContextManager):
113 | """Experiment logger interface for capturing model outputs for a
114 | given target and logging them in some form. Useful for things like
115 | producing error tables.
116 |
117 |     The logger implements the context manager protocol and must be
118 |     wrapped in a `with` statement to be used.
119 | """
120 |
121 | @abstractmethod
122 | def log(self, predictions: Pred, target: Target, metrics: Dict[str, Optional[str]]):
123 | pass
124 |
125 |
126 | class ExactMatchErrorLogger(Logger[Sequence[ModelResult], FullDatumSub]):
127 | """A logger that logs in 2 files:
128 | - base_path / errors.csv - CSV with model error using exact match metric.
129 | - base_path / correct.scv - CSV with correct model predictions.
130 | """
131 |
132 | def __init__(self, base_path: Path):
133 | self._base_path = base_path
134 | self._err_file: Optional[TextIO] = None
135 | self._err_writer = None
136 | self._corr_file: Optional[TextIO] = None
137 | self._corr_writer = None
138 |
139 | def __enter__(self):
140 | assert not self._err_file and not self._corr_file
141 | self._err_file = open(self._base_path / "errors.csv", "w")
142 | self._err_writer = csv.writer(
143 | self._err_file, delimiter=",", quoting=csv.QUOTE_MINIMAL
144 | )
145 | self._corr_file = open(self._base_path / "correct.csv", "w")
146 | self._corr_writer = csv.writer(
147 | self._corr_file, delimiter=",", quoting=csv.QUOTE_MINIMAL
148 | )
149 |
150 | self._err_writer.writerow(["Natural", "Gold", "Predicted"]) # type: ignore
151 | self._corr_writer.writerow(["Natural", "Gold"]) # type: ignore
152 |
153 | def __exit__(self, *_):
154 | assert self._err_file and self._corr_file
155 | self._err_file.close()
156 | self._corr_file.close()
157 |
158 | def log(
159 | self,
160 | predictions: Sequence[ModelResult],
161 | target: FullDatumSub,
162 | metrics: Dict[str, Optional[str]],
163 | ):
164 | pred = predictions[0].text if len(predictions) else ""
165 | if pred != target.canonical:
166 | self._err_writer.writerow([target.natural, target.canonical, pred]) # type: ignore
167 | self._err_file.flush() # type: ignore
168 | else:
169 | self._corr_writer.writerow([target.natural, target.canonical]) # type: ignore
170 | self._corr_file.flush() # type: ignore
171 |
--------------------------------------------------------------------------------
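A minimal usage sketch of TopKExactMatch; the FullDatum fields follow the constructor used elsewhere in this repo, and the predictions are made up for illustration:

    from semantic_parsing_with_constrained_lm.datum import FullDatum
    from semantic_parsing_with_constrained_lm.eval import TopKExactMatch

    metric = TopKExactMatch(k=3)
    datum = FullDatum(
        dialogue_id="d0",
        turn_part_index=0,
        agent_context="",
        natural="remind me tomorrow",
        canonical="create reminder tomorrow",
    )
    per_datum = metric.update(["create reminder today", "create reminder tomorrow"], datum)
    print(per_datum["top2"])  # "correct": the gold string appears at rank 2
    print(metric.compute())   # {"top1": 0.0, "top2": 1.0, "top3": 1.0}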
/run/semantic_parsing_with_constrained_lm/domains/sql/cosql/schema.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import os
5 | from collections import defaultdict
6 | from dataclasses import dataclass
7 | from enum import Enum
8 | from json import load
9 | from typing import Any, Dict, List, Tuple
10 |
11 | import jsons
12 |
13 | from semantic_parsing_with_constrained_lm.util.types import StrPath
14 | from semantic_parsing_with_constrained_lm.domains.sql.cosql.content_encoder import get_column_picklist
15 | from semantic_parsing_with_constrained_lm.domains.sql.cosql.paths import COSQL_DIR, SCHEMAS_FILE
16 |
17 |
18 | class ColumnType(Enum):
19 | Text = "text"
20 | Number = "number"
21 | Time = "time"
22 | Boolean = "boolean"
23 | Others = "others"
24 |
25 |
26 | @dataclass(frozen=True)
27 | class Column:
28 | name: str
29 | tpe: ColumnType
30 | # name in more natural language
31 | nl_name: str = ""
32 |
33 | @staticmethod
34 | def star() -> "Column":
35 | return Column(name="*", tpe=ColumnType.Text, nl_name="*")
36 |
37 |
38 | @dataclass(frozen=True)
39 | class ForeignKey:
40 | column_id: int
41 | other_column_id: int
42 |
43 |
44 | @dataclass(frozen=True)
45 | class Table:
46 | name: str
47 | columns: List[Column]
48 | # name in more natural language
49 | nl_name: str = ""
50 |
51 | def all_columns(self) -> List[Column]:
52 | return [Column.star()] + self.columns
53 |
54 |
55 | @dataclass(frozen=True)
56 | class DbSchema:
57 | name: str
58 | tables: List[Table]
59 | # indexes into self.tables
60 | columns: List[Tuple[int, Column]] = () # type: ignore
61 | # indexes into self.columns
62 | primary_keys: List[int] = () # type: ignore
63 | # indexes into self.columns
64 | foreign_keys: List[ForeignKey] = () # type: ignore
65 | # values in the database
66 | values: List[str] = () # type: ignore
67 |
68 | @staticmethod
69 | def from_json(schema_json: Dict[str, Any], db_path: str) -> "DbSchema":
70 | columns: List[Tuple[int, Column]] = [
71 | (t_id, Column(name=orig, tpe=ColumnType(tpe), nl_name=name))
72 | for (t_id, orig), (_, name), tpe in zip(
73 | schema_json["column_names_original"],
74 | schema_json["column_names"],
75 | schema_json["column_types"],
76 | )
77 | ]
78 | columns_by_table: Dict[int, List[Column]] = defaultdict(list)
79 | for t_id, col in columns:
80 | columns_by_table[t_id].append(col)
81 | # TODO:
82 | # tables.json is corrupted in the CoSQL dataset for schema formula_1.
83 | # "table_names" and "table_names_original" are not collated:
84 | # "table_names": [ "races", "drivers", "status", "seasons", "constructors",
85 | # "constructor standings", "results", "driver standings", "constructor results", "qualifying",
86 | # "circuits", "pitstops", "laptimes" ],
87 | # "table_names_original": [ "circuits", "races", "drivers", "status", "seasons", "constructors",
88 | # "constructorStandings", "results", "driverStandings", "constructorResults", "qualifying",
89 | # "pitStops", "lapTimes" ]
90 | # in that case, nl_name will be messed up for the table
91 | tables = [
92 | Table(name=orig, columns=columns_by_table[t_id], nl_name=name)
93 | for t_id, (orig, name) in enumerate(
94 | zip(schema_json["table_names_original"], schema_json["table_names"])
95 | )
96 | ]
97 | foreign_keys = [
98 | ForeignKey(col, other) for col, other in schema_json["foreign_keys"]
99 | ]
100 | db_id = schema_json["db_id"]
101 | values = []
102 |         # Some tables mentioned in tables.json are not present in the download;
103 |         # this check avoids errors when trying to read them.
104 | if os.path.exists(db_path + "/" + db_id):
105 | db_path = db_path + "/" + db_id + "/" + db_id + ".sqlite"
106 | for table in tables:
107 | for column in table.columns:
108 | picklist = get_column_picklist(table.name, column.name, db_path)
109 | values.extend(
110 | [
111 | val
112 | for val in picklist
113 | # this condition removes times from the list
114 | if isinstance(val, str) and not val.count(":") == 2
115 | ]
116 | )
117 |
118 | return DbSchema(
119 | name=schema_json["db_id"],
120 | columns=columns,
121 | foreign_keys=foreign_keys,
122 | primary_keys=schema_json["primary_keys"],
123 | tables=tables,
124 | values=values,
125 | )
126 |
127 |
128 | def load_schemas(
129 | schemas_path: StrPath = SCHEMAS_FILE, db_path: StrPath = COSQL_DIR / "database"
130 | ) -> Dict[str, DbSchema]:
131 | db_schema_details_file = str(db_path) + "/" + "db_schema_details.json"
132 | if os.path.exists(db_schema_details_file):
133 | with open(db_schema_details_file, "r") as db_schema_details_fp:
134 | return jsons.loads(db_schema_details_fp.read(), cls=Dict[str, DbSchema])
135 |
136 | with open(schemas_path) as tables_file:
137 | schemas_json = load(tables_file)
138 | schemas = [
139 | DbSchema.from_json(schema_json, str(db_path)) for schema_json in schemas_json
140 | ]
141 | return {schema.name: schema for schema in schemas}
142 |
143 |
144 | if __name__ == "__main__":
145 | cosql_schemas = load_schemas(
146 | "/Users/subhrroy/workspace/pyharbor/data/benchclamp/raw/CoSQL/tables.json",
147 | "/Users/subhrroy/workspace/pyharbor/data/benchclamp/raw/CoSQL/database/",
148 | )
149 | with open(
150 | "/Users/subhrroy/workspace/pyharbor/data/benchclamp/raw/CoSQL/database/db_schema_details.json",
151 | "w",
152 | ) as fp:
153 | fp.write(jsons.dumps(cosql_schemas, cls=Dict[str, DbSchema]))
154 |
155 | spider_schemas = load_schemas(
156 | "/Users/subhrroy/workspace/pyharbor/data/benchclamp/raw/Spider/tables.json",
157 | "/Users/subhrroy/workspace/pyharbor/data/benchclamp/raw/Spider/database/",
158 | )
159 | with open(
160 | "/Users/subhrroy/workspace/pyharbor/data/benchclamp/raw/Spider/database/db_schema_details.json",
161 | "w",
162 | ) as fp:
163 |         fp.write(jsons.dumps(spider_schemas, cls=Dict[str, DbSchema]))
164 |
--------------------------------------------------------------------------------
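A hypothetical usage sketch of load_schemas against a local Spider download; the paths and db_id below are placeholders:

    from semantic_parsing_with_constrained_lm.domains.sql.cosql.schema import load_schemas

    schemas = load_schemas(
        schemas_path="/path/to/Spider/tables.json",
        db_path="/path/to/Spider/database",
    )
    db = schemas["concert_singer"]  # any db_id present in tables.json
    for table in db.tables:
        print(table.name, [column.name for column in table.all_columns()])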
/run/semantic_parsing_with_constrained_lm/configs/lib/calflow.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import functools
5 | from typing import Callable, Dict, List, Optional, Tuple
6 |
7 | from transformers.tokenization_utils import PreTrainedTokenizer
8 |
9 | from semantic_parsing_with_constrained_lm.scfg.read_grammar import PreprocessedGrammar
10 | from semantic_parsing_with_constrained_lm.configs.lib.common import (
11 | PromptOrder,
12 | SeparateLM,
13 | make_semantic_parser,
14 | )
15 | from semantic_parsing_with_constrained_lm.datum import Datum
16 | from semantic_parsing_with_constrained_lm.decoding.earley_partial_parse import (
17 | GrammarTokenizerInfo,
18 | UTF8EarleyPartialParse,
19 | )
20 | from semantic_parsing_with_constrained_lm.decoding.partial_parse import StartsWithSpacePartialParse
21 | from semantic_parsing_with_constrained_lm.domains.calflow import (
22 | CalflowDatum,
23 | CalflowOutputLanguage,
24 | read_calflow_jsonl,
25 | )
26 | from semantic_parsing_with_constrained_lm.fewshot import PromptBuilder
27 | from semantic_parsing_with_constrained_lm.lm import (
28 | AutoregressiveModel,
29 | ClientType,
30 | IncrementalLanguageModel,
31 | )
32 | from semantic_parsing_with_constrained_lm.model import (
33 | BeamSearchSemanticParser,
34 | DecodingSetup,
35 | ProblemFactory,
36 | )
37 |
38 | # This is a magic number computed by Chris, the origins of which are no
39 | # longer exactly known. It should correspond to the maximum number of GPT-2
40 | # tokens we expect any plan would take up, when written as Lispress.
41 | #
42 | # We truncate the number of training examples put inside the prompt so that
43 | # we can add this many more tokens and still stay under the 2048 limit.
44 | #
45 | # When using canonical utterances, this number doesn't need to be as big
46 | # since canonical utterances are not as long as Lispress. However, when
47 | # using 20 training examples per prompt, we are never at risk of reaching
48 | # the 2048 limit, so this point is moot.
49 | MAX_STEPS_FOR_COMPLETION = 313
50 |
51 |
52 | calflow_max_steps_fn_params: Dict[
53 | Tuple[CalflowOutputLanguage, ClientType], Tuple[float, float]
54 | ] = {
55 | # python semantic_parsing_with_constrained_lm/scripts/calflow_fit_max_steps.py \
56 | # --data-path \
57 | # semantic_parsing_with_constrained_lm/domains/calflow/data/train_300_stratified.jsonl \
58 | # --tokenizer facebook/bart-large --output-type canonicalUtterance \
59 | # --max-unreachable 3
60 | (CalflowOutputLanguage.Canonical, ClientType.BART): (8, 1.7233333333),
61 | # python semantic_parsing_with_constrained_lm/scripts/calflow_fit_max_steps.py \
62 | # --data-path \
63 | # semantic_parsing_with_constrained_lm/domains/calflow/data/train_300_stratified.jsonl \
64 | # --tokenizer gpt2-xl --output-type canonicalUtterance \
65 | # --max-unreachable 3
66 | (CalflowOutputLanguage.Canonical, ClientType.GPT3): (8, 1.7233333333),
67 | (CalflowOutputLanguage.Canonical, ClientType.SMGPT3): (8, 1.7233333333),
68 | # python semantic_parsing_with_constrained_lm/scripts/calflow_fit_max_steps.py \
69 | # --data-path \
70 | # semantic_parsing_with_constrained_lm/domains/calflow/data/train_300_stratified.jsonl \
71 | # --tokenizer facebook/bart-large --output-type lispress \
72 | # --max-unreachable 3
73 | (CalflowOutputLanguage.Lispress, ClientType.BART): (65, 7.084487179487172),
74 | # python semantic_parsing_with_constrained_lm/scripts/calflow_fit_max_steps.py \
75 | # --data-path \
76 | # semantic_parsing_with_constrained_lm/domains/calflow/data/train_300_stratified.jsonl \
77 | # --tokenizer gpt2-xl --output-type lispress --max-unreachable 3
78 | (CalflowOutputLanguage.Lispress, ClientType.GPT3): (65, 7.084487179487172),
79 | (CalflowOutputLanguage.Lispress, ClientType.SMGPT3): (65, 7.084487179487172),
80 | }
81 |
82 |
83 | def get_calflow_max_steps_fn(
84 | output_type: CalflowOutputLanguage,
85 | client_type: ClientType,
86 | tokenizer: PreTrainedTokenizer,
87 | ) -> Callable[[Datum], Optional[int]]:
88 | max_steps_intercept, max_steps_slope = calflow_max_steps_fn_params[
89 | output_type, client_type
90 | ]
91 |
92 | def fn(datum: Datum) -> Optional[int]:
93 | return min(
94 | int(
95 | len(tokenizer.tokenize(datum.natural)) * max_steps_slope
96 | + max_steps_intercept
97 | ),
98 | MAX_STEPS_FOR_COMPLETION,
99 | )
100 |
101 | return fn
102 |
103 |
104 | def make_semantic_parser_for_calflow(
105 | train_data: List[CalflowDatum],
106 | lm: AutoregressiveModel,
107 | use_gpt3: bool,
108 | beam_size: int,
109 | output_type: CalflowOutputLanguage,
110 | client_type: ClientType,
111 | preprocessed_grammar: PreprocessedGrammar,
112 | constrained: bool,
113 | prompt_order: PromptOrder = PromptOrder.BestLast,
114 | # Settings for using autoregressive models in a few-shot in-context setting
115 | similarity_lm: Optional[IncrementalLanguageModel] = None,
116 | prompt_builder: Optional[PromptBuilder] = None,
117 | num_examples_per_prompt: int = 20,
118 | problem_factory_builder: Optional[Callable[[DecodingSetup], ProblemFactory]] = None,
119 | ) -> BeamSearchSemanticParser:
120 | if constrained:
121 | grammar_tokenizer_info = GrammarTokenizerInfo.create(
122 | lm.tokenizer,
123 | preprocessed_grammar,
124 | output_type == CalflowOutputLanguage.Lispress,
125 | )
126 | # TODO: Refer to `lm` to decide whether to use UTF8EarleyPartialParse or a different variant
127 | partial_parse_builder = lambda datum: UTF8EarleyPartialParse.initial(
128 | grammar_tokenizer_info, datum.natural
129 | )
130 | else:
131 | # TODO: Only impose this if we are using a GPT-2-style tokenizer
132 | partial_parse = StartsWithSpacePartialParse(lm.tokenizer)
133 | partial_parse_builder = lambda _: partial_parse
134 |
135 | max_steps_fn = get_calflow_max_steps_fn(output_type, client_type, lm.tokenizer)
136 |
137 | return make_semantic_parser(
138 | train_data,
139 | lm,
140 | use_gpt3,
141 | MAX_STEPS_FOR_COMPLETION,
142 | beam_size,
143 | partial_parse_builder,
144 | max_steps_fn=max_steps_fn,
145 | prompt_order=prompt_order,
146 | similarity_method=SeparateLM(similarity_lm=similarity_lm), # type: ignore
147 | prompt_builder=prompt_builder,
148 | num_examples_per_prompt=num_examples_per_prompt,
149 | problem_factory_builder=problem_factory_builder,
150 | )
151 |
152 |
153 | cached_read_calflow_jsonl = functools.lru_cache(maxsize=None)(read_calflow_jsonl)
154 |
--------------------------------------------------------------------------------
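A worked example of the fitted max-steps budget computed by get_calflow_max_steps_fn: with Lispress outputs and the BART tokenizer, (intercept, slope) = (65, 7.084487179487172), so a 20-token utterance gets a budget of 206 decoding steps, well under MAX_STEPS_FOR_COMPLETION:

    intercept, slope = 65, 7.084487179487172
    n_utterance_tokens = 20
    max_steps = min(int(n_utterance_tokens * slope + intercept), 313)
    assert max_steps == 206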
/run/semantic_parsing_with_constrained_lm/decoding/uint8_earley_partial_parse.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | """Constrained decoding for Earley grammars where terminals are np.uint8.
5 |
6 | UInt8EarleyPartialParse is similar to EarleyPartialParse, except that it only
7 | works for grammars where all terminals are np.uint8.
8 | It also currently lacks support for input utterance copying constraints.
9 |
10 | TODO: We can generalize UInt8EarleyPartialParse to "Atomic"EarleyPartialParse by
11 | replacing np.uint8 with a type parameter.
12 | """
13 |
14 | import dataclasses
15 | import itertools
16 | from dataclasses import dataclass
17 | from typing import Mapping # pylint: disable=unused-import
18 | from typing import (
19 | Any,
20 | Callable,
21 | Dict,
22 | Iterable,
23 | Iterator,
24 | Optional,
25 | Sequence,
26 | Tuple,
27 | TypeVar,
28 | )
29 |
30 | import numpy as np
31 | import torch
32 | from cached_property import cached_property
33 |
34 | from semantic_parsing_with_constrained_lm.earley.earley import EarleyChart
35 | from semantic_parsing_with_constrained_lm.earley.grammar import Grammar
36 | from semantic_parsing_with_constrained_lm.earley.input import Position, SigmaStarTriePosition
37 | from semantic_parsing_with_constrained_lm.decoding.partial_parse import PartialParse
38 | from semantic_parsing_with_constrained_lm.tokenization import ClampTokenizer
39 |
40 | T = TypeVar("T")
41 |
42 |
43 | def get_only(items: Iterable[T]) -> T:
44 | """Returns the single value in `items`.
45 |
46 | This function raises an exception if items contains 0 or 2+ items.
47 | """
48 | [item] = items
49 | return item
50 |
51 |
52 | @dataclass
53 | class UInt8GrammarNode:
54 | chart: EarleyChart[np.uint8, Any]
55 | lazy_pos: Callable[[], Position[np.uint8]]
56 |
57 | @cached_property
58 | def pos(self) -> Position[np.uint8]:
59 | return self.lazy_pos()
60 |
61 | @cached_property
62 | def children(self) -> "Mapping[np.uint8, UInt8GrammarNode]":
63 | # TODO: Consider using a list instead.
64 | return {
65 | terminal: UInt8GrammarNode(
66 | self.chart,
67 | lambda terminal=terminal, items=items: get_only(
68 | self.chart.advance_with_terminal(self.pos, terminal, items)
69 | ),
70 | )
71 | for terminal, items in self.chart.advance_only_nonterminals(
72 | self.pos, unpop_terminals=False
73 | ).items()
74 | }
75 |
76 | def advance(self, seq: Sequence[np.uint8]) -> "Optional[UInt8GrammarNode]":
77 | result = self
78 | for byte in seq:
79 | result = result.children.get(byte)
80 | if result is None:
81 | break
82 | return result
83 |
84 |
85 | @dataclass
86 | class UInt8GrammarTokenizerInfo:
87 | grammar: Grammar[np.uint8, Any]
88 | tokens: Sequence[Sequence[np.uint8]]
89 |
90 | @cached_property
91 | def vocab_size(self) -> int:
92 | return len(self.tokens)
93 |
94 | @staticmethod
95 | def from_clamp_tokenizer(
96 | grammar: Grammar[np.uint8, Any], tokenizer: ClampTokenizer
97 | ) -> "UInt8GrammarTokenizerInfo":
98 | encoded_tokens = UInt8GrammarTokenizerInfo.prepare_tokens_from_clamp_tokenizer(
99 | tokenizer
100 | )
101 | return UInt8GrammarTokenizerInfo(
102 | grammar,
103 | encoded_tokens,
104 | )
105 |
106 | @staticmethod
107 | def prepare_tokens_from_clamp_tokenizer(
108 | tokenizer: ClampTokenizer,
109 | ) -> Sequence[Sequence[np.uint8]]:
110 | return [
111 | np.frombuffer(tokenizer.id_to_utf8_token_map[i], dtype=np.uint8)
112 | for i in range(len(tokenizer.id_to_utf8_token_map))
113 | ]
114 |
115 |
116 | @dataclass
117 | class UInt8EarleyPartialParse(PartialParse):
118 | grammar_node: UInt8GrammarNode
119 | info: UInt8GrammarTokenizerInfo
120 | start_pos: Position[np.uint8]
121 | _next_node_cache: Dict[int, Optional[UInt8GrammarNode]] = dataclasses.field(
122 | default_factory=dict
123 | )
124 |
125 | def allowed_next(
126 | self, ordered_ids: Optional[torch.Tensor] = None, top_k: Optional[int] = None
127 | ) -> Tuple[Optional[torch.Tensor], bool]:
128 | # TODO: Use optimizations already in EarleyPartialParse, and others identified but not implemented:
129 | # - Only check the first N tokens from ordered_ids with `token_id_is_valid`;
130 | # for the rest, intersect grammar_node with the vocab trie
131 | # - Cross-beam pruning: https://semanticmachines.slack.com/archives/C0310DTKR6J/p1644449621986019
132 | assert ordered_ids is not None
133 | ordered_ids_list = ordered_ids.tolist()
134 | all_tokens = self.info.tokens
135 | vocab_size = self.info.vocab_size
136 | node = self.grammar_node
137 |
138 | def token_id_is_valid(i: int) -> bool:
139 | if not 0 <= i < vocab_size:
140 | return False
141 | next_node = node.advance(all_tokens[i])
142 | self._next_node_cache[i] = next_node
143 | return next_node is not None
144 |
145 | def produce_valid_tokens() -> Iterator[int]:
146 | for i in ordered_ids_list:
147 | if token_id_is_valid(i):
148 | yield i
149 |
150 | # TODO: Add special case where grammar_node.children has no elements
151 | # (i.e. tokens_list will be empty)
152 | tokens_list = list(itertools.islice(produce_valid_tokens(), top_k))
153 | can_end = self.grammar_node.chart.was_found(
154 | self.grammar_node.chart.grammar.root, self.start_pos, self.grammar_node.pos
155 | )
156 | return torch.tensor(tokens_list, dtype=torch.long), can_end
157 |
158 | def append(self, token: int) -> "UInt8EarleyPartialParse":
159 | """Return a new PartialParse created by appending this token."""
160 | if token in self._next_node_cache:
161 | node_for_result = self._next_node_cache[token]
162 | else:
163 | if not 0 <= token < self.info.vocab_size:
164 | raise ValueError("token was not in the vocabulary")
165 | node_for_result = self.grammar_node.advance(self.info.tokens[token])
166 |
167 | if node_for_result is None:
168 | raise ValueError("invalid token to continue with")
169 |
170 | return UInt8EarleyPartialParse(node_for_result, self.info, self.start_pos)
171 |
172 | @staticmethod
173 | def initial(info: UInt8GrammarTokenizerInfo) -> "UInt8EarleyPartialParse":
174 | chart = EarleyChart(info.grammar, use_backpointers=False)
175 | start_pos = SigmaStarTriePosition[np.uint8]()
176 | chart.seek(info.grammar.root, start_pos)
177 | grammar_node = UInt8GrammarNode(chart, lambda: start_pos)
178 | return UInt8EarleyPartialParse(grammar_node, info, start_pos)
179 |
--------------------------------------------------------------------------------
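A hedged sketch (not repo code) of how a greedy constrained decoder could drive the PartialParse interface implemented above: rank token ids by model score, keep only grammar-allowed continuations via allowed_next, and advance with append. `score_fn` is a hypothetical stand-in for a language-model scoring step:

    import torch

    def greedy_constrained_decode(score_fn, partial_parse, eos_id, max_steps=100):
        output_ids = []
        for _ in range(max_steps):
            scores = score_fn(output_ids)  # hypothetical: returns (vocab_size,) scores
            ordered_ids = torch.argsort(scores, descending=True)
            allowed, can_end = partial_parse.allowed_next(ordered_ids, top_k=1)
            no_continuation = allowed is None or len(allowed) == 0
            if no_continuation or (can_end and scores[eos_id] > scores[allowed[0]]):
                break  # dead end, or ending here scores better than continuing
            next_id = int(allowed[0])
            output_ids.append(next_id)
            partial_parse = partial_parse.append(next_id)
        return output_ids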
/run/semantic_parsing_with_constrained_lm/domains/overnight/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import json
5 | from dataclasses import dataclass
6 | from enum import Enum
7 | from typing import Callable, Dict, List
8 |
9 | from blobfile import BlobFile
10 |
11 | from semantic_parsing_with_constrained_lm.util.trie import Trie
12 | from semantic_parsing_with_constrained_lm.util.types import StrPath
13 | from semantic_parsing_with_constrained_lm.datum import Datum, FullDatum
14 | from semantic_parsing_with_constrained_lm.decoding.trie_partial_parse import TriePartialParse
15 | # from semantic_parsing_with_constrained_lm.domains.calflow.write_data import CACHE_DIR
16 | from semantic_parsing_with_constrained_lm.eval import TopKExactMatch
17 | from semantic_parsing_with_constrained_lm.tokenization import ClampTokenizer
18 |
19 | # NOTE: get rid of the calflow/dataflow dependency
20 | from appdirs import user_cache_dir
21 | CACHE_DIR = user_cache_dir("semantic_parsing_as_constrained_lm")
22 |
23 |
24 | class OutputType(str, Enum):
25 | Utterance = "utterance"
26 | MeaningRepresentation = "meaningRepresentation"
27 |
28 |
29 | @dataclass
30 | class TopKDenotationMatch(TopKExactMatch[FullDatum]):
31 | canonical_to_denotation: Dict[str, str]
32 |
33 | def _is_correct(self, pred: str, datum: FullDatum) -> bool:
34 | target = datum.canonical
35 | pred_denotation = self.canonical_to_denotation.get(pred)
36 | target_denotation = self.canonical_to_denotation.get(target, None)
37 | if pred_denotation is None and target_denotation is None:
38 | return pred == target
39 | else:
40 | return pred_denotation == target_denotation
41 |
42 |
43 | @dataclass
44 | class OvernightPieces:
45 | train_data: List[FullDatum]
46 | test_data: List[FullDatum]
47 | partial_parse_builder: Callable[[Datum], TriePartialParse]
48 | denotation_metric: TopKDenotationMatch
49 | max_length: int
50 |
51 | @staticmethod
52 | def from_dir(
53 | tokenizer: ClampTokenizer,
54 | root_dir: StrPath,
55 | domain: str,
56 | is_dev: bool,
57 | k: int,
58 | output_type: OutputType = OutputType.Utterance,
59 | simplify_logical_forms=False,
60 | prefix_with_space=False,
61 | ) -> "OvernightPieces":
62 | data_pieces = OvernightDataPieces.from_dir(
63 | root_dir, domain, is_dev, output_type, simplify_logical_forms
64 | )
65 | decoder_pieces = OvernightDecoderPieces.create(
66 | data_pieces, tokenizer, k, prefix_with_space
67 | )
68 |
69 | return OvernightPieces(
70 | data_pieces.train_data,
71 | data_pieces.test_data,
72 | # https://github.com/python/mypy/issues/5485
73 | decoder_pieces.partial_parse_builder, # type: ignore
74 | decoder_pieces.denotation_metric,
75 | decoder_pieces.max_length,
76 | )
77 |
78 |
79 | @dataclass
80 | class OvernightDataPieces:
81 | train_data: List[FullDatum]
82 | test_data: List[FullDatum]
83 | target_output_to_denotation: Dict[str, str]
84 |
85 | @staticmethod
86 | def from_dir(
87 | root_dir: StrPath,
88 | domain: str,
89 | is_dev: bool,
90 | output_type: OutputType = OutputType.MeaningRepresentation,
91 | simplify_logical_forms: bool = False,
92 | ) -> "OvernightDataPieces":
93 | # TODO make this configurable?
94 | with BlobFile(str(root_dir) + f"/{domain}.canonical.json") as bf:
95 | canonical_data = json.load(bf)
96 |
97 | if output_type == OutputType.Utterance:
98 | target_output_to_denotation = {
99 | k: v["denotation"] for k, v in canonical_data.items()
100 | }
101 | datum_key = "canonical"
102 | elif output_type == OutputType.MeaningRepresentation:
103 | target_output_to_denotation = {}
104 | for program_info in canonical_data.values():
105 | formula = program_info["formula"]
106 | if formula is None:
107 | continue
108 | if simplify_logical_forms:
109 | formula = OvernightDataPieces.simplify_lf(formula)
110 | assert formula not in target_output_to_denotation
111 | target_output_to_denotation[formula] = program_info["denotation"]
112 | datum_key = "formula"
113 | else:
114 | raise ValueError(output_type)
115 |
116 | train_data, test_data = [
117 | [
118 | FullDatum(
119 | dialogue_id=f"{dataset_name}-{i}",
120 | turn_part_index=None,
121 | agent_context=None,
122 | natural=d["natural"],
123 | canonical=OvernightDataPieces.simplify_lf(d[datum_key])
124 | if simplify_logical_forms
125 | else d[datum_key],
126 | )
127 | for i, line in enumerate(
128 | BlobFile(path, streaming=False, cache_dir=CACHE_DIR)
129 | )
130 | for d in [json.loads(line)]
131 | ]
132 | for dataset_name, path in (
133 | (
134 | "train",
135 | f"{root_dir}/{domain}.train_with{'out' if is_dev else ''}_dev.jsonl",
136 | ),
137 | ("eval", f"{root_dir}/{domain}.{'dev' if is_dev else 'test'}.jsonl"),
138 | )
139 | ]
140 |
141 | return OvernightDataPieces(train_data, test_data, target_output_to_denotation)
142 |
143 | @staticmethod
144 | def simplify_lf(lf: str) -> str:
145 | return lf.replace("edu.stanford.nlp.sempre.overnight.SimpleWorld.", "")
146 |
147 |
148 | @dataclass
149 | class OvernightDecoderPieces:
150 | data_pieces: OvernightDataPieces
151 | partial_parse_builder: Callable[[Datum], TriePartialParse]
152 | denotation_metric: TopKDenotationMatch
153 | max_length: int
154 |
155 | @staticmethod
156 | def create(
157 | data_pieces: OvernightDataPieces,
158 | tokenizer: ClampTokenizer,
159 | k: int,
160 | prefix_with_space: bool = False,
161 | ) -> "OvernightDecoderPieces":
162 | if prefix_with_space:
163 | canonical_trie = Trie(
164 | tokenizer.encode(" " + canon)
165 | for canon in data_pieces.target_output_to_denotation
166 | )
167 | else:
168 | canonical_trie = Trie(
169 | tokenizer.encode(canon)
170 | for canon in data_pieces.target_output_to_denotation
171 | )
172 | partial_parse_builder = lambda _: TriePartialParse(canonical_trie)
173 |
174 | denotation_metric = TopKDenotationMatch(
175 | k, data_pieces.target_output_to_denotation
176 | )
177 | max_length = max(len(x) for x in canonical_trie)
178 |
179 | return OvernightDecoderPieces(
180 | data_pieces, partial_parse_builder, denotation_metric, max_length
181 | )
182 |
--------------------------------------------------------------------------------
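A hypothetical usage sketch of OvernightPieces.from_dir; the GPT-2 tokenizer wrapping mirrors what create_benchclamp_data.py does above, while the data directory and domain name are placeholders:

    from transformers import GPT2Tokenizer

    from semantic_parsing_with_constrained_lm.tokenization import GPT2ClampTokenizer
    from semantic_parsing_with_constrained_lm.domains.overnight import OutputType, OvernightPieces

    clamp_tokenizer = GPT2ClampTokenizer(GPT2Tokenizer.from_pretrained("gpt2"))
    pieces = OvernightPieces.from_dir(
        clamp_tokenizer,
        root_dir="/path/to/overnight/data",
        domain="calendar",
        is_dev=True,
        k=5,
        output_type=OutputType.MeaningRepresentation,
        simplify_logical_forms=True,
    )
    print(len(pieces.train_data), len(pieces.test_data), pieces.max_length)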