├── run └── semantic_parsing_with_constrained_lm │ ├── domains │ ├── sql │ │ ├── grammar │ │ │ ├── start.scfg │ │ │ ├── table_columns.scfg │ │ │ ├── whitespace.scfg │ │ │ └── quoted.scfg │ │ ├── __init__.py │ │ ├── cosql │ │ │ ├── __init__.py │ │ │ ├── paths.py │ │ │ ├── seq2seq.py │ │ │ ├── grammar.py │ │ │ └── schema.py │ │ ├── sql_datum.py │ │ ├── sequence_creator.py │ │ ├── sql_metric.py │ │ └── create_benchclamp_data.py │ ├── calflow │ │ ├── grammar │ │ │ ├── start.scfg │ │ │ ├── entities.scfg │ │ │ ├── enum_wrappers.scfg │ │ │ ├── quoted.scfg │ │ │ └── fluenter.scfg │ │ ├── data │ │ │ └── ids_dev_100_uniform.txt │ │ └── __init__.py │ ├── __init__.py │ ├── mtop │ │ ├── __init__.py │ │ └── grammar.py │ ├── lispress_v2 │ │ ├── __init__.py │ │ ├── sequence_creator.py │ │ └── grammar.py │ ├── ltl │ │ ├── data │ │ │ ├── pick-golden-cross0-split0.canonical.json │ │ │ ├── pick-golden-cross0-split1.canonical.json │ │ │ ├── pick-golden-cross0-split2.canonical.json │ │ │ ├── pick-golden-cross0-split3.canonical.json │ │ │ ├── pick-golden-cross0-split4.canonical.json │ │ │ ├── pick-syn.canonical.json │ │ │ ├── pick-syn-aug.canonical.json │ │ │ └── pick-syn.train.jsonl │ │ ├── kept │ │ │ ├── pick-syn.canonical.json │ │ │ └── pick-syn.train.jsonl │ │ ├── create_benchclamp_data.py │ │ └── __init__.py │ ├── dfa_grammar_utils.py │ ├── overnight │ │ ├── create_benchclamp_data.py │ │ └── __init__.py │ ├── calflow_eval_utils.py │ └── create_benchclamp_splits.py │ ├── __init__.py │ ├── scfg │ ├── __init__.py │ ├── parser │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── types.py │ │ ├── scfg_grammar.lark │ │ ├── macro.py │ │ ├── token.py │ │ ├── rule.py │ │ └── parse.py │ ├── string_utils.py │ └── earley_grammar.py │ ├── configs │ ├── __init__.py │ └── lib │ │ ├── __init__.py │ │ ├── benchclamp.py │ │ └── calflow.py │ ├── decoding │ ├── __init__.py │ ├── trie_partial_parse.py │ ├── partial_parse.py │ └── uint8_earley_partial_parse.py │ ├── earley │ ├── __init__.py │ ├── cfg.lark │ ├── unicode_categories_spans.py │ ├── context_sensitive.py │ ├── specialization.py │ └── recognize.py │ ├── finetune │ ├── __init__.py │ ├── configs │ │ └── __init__.py │ ├── download_huggingface_lms.py │ └── check_for_unks.py │ ├── scripts │ ├── __init__.py │ └── calflow_fit_max_steps.py │ ├── async_tools │ ├── __init__.py │ └── batch_helper.py │ ├── index │ ├── __init__.py │ ├── index.py │ └── bm25_index.py │ ├── util │ ├── types.py │ ├── unit.py │ ├── keydefaultdict.py │ └── logger.py │ ├── sequence_creator.py │ ├── cache.py │ ├── playground.ipynb │ ├── result.py │ ├── paths.py │ ├── datum.py │ ├── train_model_setup.py │ ├── fit_max_steps.py │ └── eval.py ├── env_install.sh ├── docs ├── _config.yml ├── _layouts │ └── default.html └── index.md ├── datasets └── pick-and-place │ ├── canonical.json │ └── train_seed.jsonl └── README.md /run/semantic_parsing_with_constrained_lm/domains/sql/grammar/start.scfg: -------------------------------------------------------------------------------- 1 | start -> parse 2 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/calflow/grammar/start.scfg: -------------------------------------------------------------------------------- 1 | start -> " " unit , unit 2 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 
2 | # Licensed under the MIT License. 3 | 4 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/scfg/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/decoding/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/earley/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/finetune/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/async_tools/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/configs/lib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/mtop/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/sql/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/scfg/parser/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/lispress_v2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/sql/cosql/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/finetune/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/index/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from .index import Candidate, DynamicIndex, Index, Query 5 | -------------------------------------------------------------------------------- /env_install.sh: -------------------------------------------------------------------------------- 1 | pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 2 | pip install transformers datasets 3 | pip install jsons appdirs blobfile cached-property httpx typer whoosh more_itertools jupyter openai 4 | pip install --upgrade protobuf==3.20.0 -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/scfg/parser/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | def is_skippable(string: str): 5 | """A string is skippable if it's empty or begins with a '#'""" 6 | return not string or string[0] == "#" 7 | -------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | remote_theme: pages-themes/cayman@v0.2.0 2 | plugins: 3 | - jekyll-remote-theme # add this line to the plugins list if you already have one 4 | title: Data-Efficient Learning of Natural Language to Linear Temporal Logic Translators for Robot Task Specification 5 | description: "Jiayi Pan, Glen Chou, Dmitry Berenson" 6 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/util/types.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | import os # pylint: disable=unused-import 5 | from typing import Union 6 | 7 | # This can be used to annotate arguments that are supposed to be file paths. 8 | StrPath = Union[str, "os.PathLike[str]"] 9 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/sql/sql_datum.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from dataclasses import dataclass 5 | 6 | from semantic_parsing_with_constrained_lm.datum import FullDatum 7 | 8 | 9 | @dataclass(frozen=True) 10 | class SqlDatum(FullDatum): 11 | schema_name: str 12 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/scfg/parser/types.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from typing import Tuple 5 | 6 | from semantic_parsing_with_constrained_lm.scfg.parser.token import SCFGToken 7 | 8 | Nonterminal = str 9 | # An Alias is just another name for a nonterminal. 10 | Alias = str 11 | 12 | 13 | Expansion = Tuple[SCFGToken, ...] 14 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/mtop/grammar.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import functools 5 | 6 | from semantic_parsing_with_constrained_lm.domains import dfa_grammar_utils 7 | 8 | create_partial_parse_builder = functools.partial( 9 | dfa_grammar_utils.create_partial_parse_builder, 10 | utterance_nonterm_name="any_char_star", 11 | ) 12 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/util/unit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | class Unit: 5 | """ 6 | Analogue of Scala Unit type that has a single instance UNIT. Can be used as type 7 | placeholder. Similar to None but can be used where None doesn't work. 
8 | """ 9 | 10 | def __repr__(self) -> str: 11 | return "Unit" 12 | 13 | 14 | UNIT = Unit() 15 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/calflow/grammar/entities.scfg: -------------------------------------------------------------------------------- 1 | personname -> quoted , " #(PersonName " quoted ")" 2 | string -> quoted , " #(String " quoted ")" 3 | respondcomment -> quoted , " #(RespondComment " quoted ")" 4 | locationkeyphrase -> quoted , " #(LocationKeyphrase " quoted ")" 5 | path -> quoted , " #(Path " quoted ")" 6 | 7 | list_path_ -> "(empty list)" , " #(List[Path] [])" 8 | list_recipient_ -> "(empty recipient list)" , " #(List[Recipient] [])" 9 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/calflow/grammar/enum_wrappers.scfg: -------------------------------------------------------------------------------- 1 | holiday -> holiday_entity , " #(Holiday \"" holiday_entity "\")" 2 | placefeature -> place_feature_entity , " #(PlaceFeature \"" place_feature_entity "\")" 3 | weatherquantifier -> weather_quantifier_entity , " #(WeatherQuantifier \"" weather_quantifier_entity "\")" 4 | responsestatustype -> response_entity , " #(ResponseStatusType \"" response_entity "\")" 5 | number -> number_entity , " #(Number" number_entity ")" 6 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/sql/grammar/table_columns.scfg: -------------------------------------------------------------------------------- 1 | ## these get auto-gen'd per schema 2 | 3 | # schema_name -> any_name 4 | # table_name -> any_name 5 | # table_alias -> any_name 6 | # column_name -> any_name 7 | # schema_table_dot -> schema_dot_ws? table_dot 8 | # possibly_qualified_column_name -> schema_table_dot? column_name 9 | 10 | table_alias -> T "1" 11 | table_alias -> T "2" 12 | table_alias -> T "3" 13 | table_alias -> T "4" 14 | table_alias -> T "5" 15 | 16 | table_name -> table_alias 17 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/sql/cosql/paths.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from pathlib import Path 5 | from semantic_parsing_with_constrained_lm.paths import DOMAINS_DIR 6 | 7 | 8 | # TODO: Don't hardcode 9 | COSQL_DIR = Path("data/cosql_dataset") 10 | 11 | SCHEMAS_FILE = COSQL_DIR / "tables.json" 12 | 13 | SQL_STATE_TRACKING_DIR = COSQL_DIR / "sql_state_tracking" 14 | TRAIN_DATA_FILE = SQL_STATE_TRACKING_DIR / "cosql_train.json" 15 | 16 | SQL_GRAMMAR_DIR = DOMAINS_DIR / "sql" / "grammar" 17 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/sequence_creator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | from abc import ABC, abstractmethod 5 | 6 | from semantic_parsing_with_constrained_lm.datum import BenchClampDatum 7 | 8 | 9 | class SequenceCreator(ABC): 10 | @abstractmethod 11 | def create_sequence(self, datum: BenchClampDatum) -> str: 12 | pass 13 | 14 | 15 | class IdentitySequenceCreator(SequenceCreator): 16 | def create_sequence(self, datum: BenchClampDatum) -> str: 17 | return datum.utterance 18 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/cache.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from abc import ABC, abstractmethod 5 | from typing import Optional 6 | 7 | 8 | class CacheClient(ABC): 9 | async def __aenter__(self): 10 | pass 11 | 12 | async def __aexit__(self, exc_type, exc_value, traceback): 13 | pass 14 | 15 | @abstractmethod 16 | async def get(self, args: dict) -> Optional[dict]: 17 | pass 18 | 19 | @abstractmethod 20 | async def upload(self, args: dict, result: dict) -> None: 21 | pass 22 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/calflow/grammar/quoted.scfg: -------------------------------------------------------------------------------- 1 | # Quoted strings *do* begin with a space in this grammar. 2 | # For example, `create event with " Rose"`. 3 | # The space has to be a regex, b/c it gets consumed by CopyTokens, 4 | # and it has to not be inside nonquoteplus, because it doesn't 5 | # appear on the plan side. 6 | quoted -> "\"" / / nonquoteplus "\"" , "\"" nonquoteplus "\"" 7 | 8 | # matches one or more characters that are not double quotes 9 | nonquoteplus -> /[^"]/ nonquotestar 10 | 11 | # matches zero or more characters that are not double quotes 12 | nonquotestar -> /[^"]/ nonquotestar 13 | nonquotestar -> empty 14 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/scfg/string_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from typing import Iterable 5 | 6 | 7 | def detokenize(tokens: Iterable[str], with_treebank: bool = True) -> str: 8 | """ 9 | Given a list of tokens, join them together into a string. 10 | with_treebank = True is typically used when rendering utterances, so we don't need to deal with things like 11 | "andrew's" 12 | with_treebank = False is typically for rendering express. 
13 | """ 14 | if with_treebank: 15 | return " ".join(tokens).replace(" ", " ") 16 | 17 | return "".join(tokens) 18 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/playground.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [] 9 | } 10 | ], 11 | "metadata": { 12 | "kernelspec": { 13 | "display_name": "Python 3.10.4 64-bit", 14 | "language": "python", 15 | "name": "python3" 16 | }, 17 | "language_info": { 18 | "name": "python", 19 | "version": "3.10.4" 20 | }, 21 | "orig_nbformat": 4, 22 | "vscode": { 23 | "interpreter": { 24 | "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" 25 | } 26 | } 27 | }, 28 | "nbformat": 4, 29 | "nbformat_minor": 2 30 | } 31 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/sql/grammar/whitespace.scfg: -------------------------------------------------------------------------------- 1 | # This subgrammar matches (non-newline) whitespace sequences on the left side, 2 | # and returns a single " " on the right side 3 | # Newlines are turned into "\n" 4 | 5 | # 1 or more whitespace chars, returns SPACE 6 | ws -> ws_char ws_star_empty 7 | 8 | # 0 or more whitespace chars, returns SPACE 9 | ws_star -> #e , SPACE 10 | ws_star -> ws 11 | 12 | # matches any whitespace sequence and returns #e 13 | ws_star_empty -> ws_char ws_star , #e 14 | ws_star_empty -> #e 15 | 16 | 17 | SPACE -> " " 18 | 19 | # matches any whitespace char and returns SPACE 20 | ws_char -> SPACE 21 | ws_char -> "\u000B", SPACE 22 | ws_char -> TAB , SPACE 23 | ws_char -> NEWLINE , SPACE 24 | 25 | NEWLINE -> "\r" , "\n" 26 | NEWLINE -> "\n" 27 | 28 | TAB -> "\t" 29 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/index/index.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from abc import ABC, abstractmethod 5 | from typing import Generic, Iterable, Tuple, TypeVar 6 | 7 | Key = TypeVar("Key") 8 | Query = TypeVar("Query") 9 | Candidate = TypeVar("Candidate") 10 | 11 | 12 | class Index(Generic[Key, Query], ABC): 13 | """ 14 | Encapsulates any index that can be searched over. 15 | It can either be a sparse index (e.g. powered by Whoosh), or a dense index (e.g. powered by FAISS). 16 | """ 17 | 18 | @abstractmethod 19 | def search(self, query: Query, top_k: int) -> Iterable[Tuple[Key, float]]: 20 | raise NotImplementedError 21 | 22 | 23 | class DynamicIndex(Generic[Key, Query, Candidate], Index[Key, Query]): 24 | """ 25 | Any index that supports dynamic addition to the set of candidates. 26 | """ 27 | 28 | @abstractmethod 29 | def add(self, candidates: Iterable[Candidate]) -> None: 30 | raise NotImplementedError 31 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/decoding/trie_partial_parse.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | import dataclasses 5 | from dataclasses import dataclass 6 | from typing import Optional, Tuple 7 | 8 | import torch 9 | 10 | from semantic_parsing_with_constrained_lm.util.trie import Trie 11 | from semantic_parsing_with_constrained_lm.decoding.partial_parse import PartialParse 12 | 13 | 14 | @dataclass 15 | class TriePartialParse(PartialParse): 16 | trie: Trie[int] 17 | tokens: Tuple[int, ...] = () 18 | 19 | def allowed_next( 20 | self, ordered_ids: Optional[torch.Tensor] = None, top_k: Optional[int] = None 21 | ) -> Tuple[torch.Tensor, bool]: 22 | allowed, is_complete = self.trie.prefix_next(self.tokens) 23 | return torch.tensor(sorted(allowed), dtype=torch.long), is_complete 24 | 25 | def append(self, token: int) -> "PartialParse": 26 | """Return a new PartialParse created by appending this token.""" 27 | return dataclasses.replace(self, tokens=self.tokens + (token,)) 28 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/ltl/data/pick-golden-cross0-split0.canonical.json: -------------------------------------------------------------------------------- 1 | {"globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )"}} -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/ltl/data/pick-golden-cross0-split1.canonical.json: -------------------------------------------------------------------------------- 1 | {"globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )"}}
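Note: the pick-golden-cross0-split*.canonical.json files above all share the same layout, a JSON object whose keys are canonical utterances and whose values carry the target LTL formula under a "formula" field. The following is only an illustrative sketch of reading such a file (it is not part of the repository, and the helper name load_canonical_file is invented for the example; only the file layout is taken from the data shown here):

import json
from pathlib import Path
from typing import Dict, List, Tuple


def load_canonical_file(path: Path) -> List[Tuple[str, str]]:
    """Read a *.canonical.json file and return (canonical utterance, formula) pairs."""
    with open(path, encoding="utf-8") as f:
        data: Dict[str, Dict[str, str]] = json.load(f)
    # Each top-level key is a canonical utterance; its value stores the LTL formula.
    return [(canonical, entry["formula"]) for canonical, entry in data.items()]


pairs = load_canonical_file(
    Path("run/semantic_parsing_with_constrained_lm/domains/ltl/data/pick-golden-cross0-split0.canonical.json")
)
for canonical, formula in pairs:
    print(canonical, "=>", formula)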
-------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/ltl/data/pick-golden-cross0-split2.canonical.json: -------------------------------------------------------------------------------- 1 | {"globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )"}} -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/ltl/data/pick-golden-cross0-split3.canonical.json: -------------------------------------------------------------------------------- 1 | {"globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )"}} -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/ltl/data/pick-golden-cross0-split4.canonical.json: -------------------------------------------------------------------------------- 1 | {"globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non 
yellow cubes ) ) , finally ( any non yellow cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )"}, "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )": {"formula": "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )"}} -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/lispress_v2/sequence_creator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from semantic_parsing_with_constrained_lm.datum import BenchClampDatum 5 | from semantic_parsing_with_constrained_lm.sequence_creator import SequenceCreator 6 | 7 | 8 | class LastAgentUtterance(SequenceCreator): 9 | def create_sequence(self, datum: BenchClampDatum) -> str: 10 | last_agent_utterance = ( 11 | datum.last_agent_utterance if datum.last_agent_utterance is not None else "" 12 | ) 13 | return " | ".join([last_agent_utterance, datum.utterance]) 14 | 15 | 16 | class LastUserAgentUtterance(SequenceCreator): 17 | def create_sequence(self, datum: BenchClampDatum) -> str: 18 | last_agent_utterance = ( 19 | datum.last_agent_utterance if datum.last_agent_utterance is not None else "" 20 | ) 21 | last_user_utterance = ( 22 | datum.last_user_utterance if datum.last_user_utterance is not None else "" 23 | ) 24 | return " | ".join([last_user_utterance, last_agent_utterance, datum.utterance]) 25 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/ltl/data/pick-syn.canonical.json: -------------------------------------------------------------------------------- 1 | { 2 | "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )": { 3 | "formula": "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )" 4 | }, 5 | "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )": { 6 | "formula": "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )" 7 | }, 8 | "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )": { 9 | "formula": "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )" 10 | }, 11 | "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )": { 12 | "formula": "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )" 13 | }, 14 | "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )": { 15 | "formula": "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )" 16 | } 17 | } -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/ltl/kept/pick-syn.canonical.json: -------------------------------------------------------------------------------- 1 | { 2 | "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )": { 3 | "formula": "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )" 4 | }, 5 | "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red 
cubes ) ) )": { 6 | "formula": "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )" 7 | }, 8 | "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )": { 9 | "formula": "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )" 10 | }, 11 | "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )": { 12 | "formula": "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )" 13 | }, 14 | "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )": { 15 | "formula": "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )" 16 | } 17 | } -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/finetune/download_huggingface_lms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from pathlib import Path 5 | 6 | from transformers import ( 7 | BartForConditionalGeneration, 8 | BartTokenizer, 9 | RobertaTokenizer, 10 | T5ForConditionalGeneration, 11 | T5Tokenizer, 12 | ) 13 | 14 | CLAMP_PRETRAINED_MODEL_DIR = Path("huggingface_models/") 15 | 16 | 17 | def save_model_and_tokenizer(model, tokenizer, save_dir: Path) -> None: 18 | save_dir.mkdir(exist_ok=True, parents=True) 19 | model.save_pretrained(save_dir) 20 | tokenizer.save_pretrained(save_dir) 21 | 22 | 23 | def main(): 24 | # T5 25 | # Bart 26 | for model_id, huggingface_model_id in [ 27 | ("bart-large", "facebook/bart-large"), 28 | ]: 29 | print(f"Downloading {model_id} ...") 30 | model = BartForConditionalGeneration.from_pretrained(huggingface_model_id) 31 | tokenizer = BartTokenizer.from_pretrained(huggingface_model_id) 32 | save_model_and_tokenizer( 33 | model, tokenizer, CLAMP_PRETRAINED_MODEL_DIR / model_id 34 | ) 35 | 36 | 37 | if __name__ == "__main__": 38 | main() 39 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/ltl/data/pick-syn-aug.canonical.json: -------------------------------------------------------------------------------- 1 | { 2 | "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )": { 3 | "formula": "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )" 4 | }, 5 | "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )": { 6 | "formula": "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )" 7 | }, 8 | "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )": { 9 | "formula": "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )" 10 | }, 11 | "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )": { 12 | "formula": "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )" 13 | }, 14 | "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )": { 15 | "formula": "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )" 16 | } 17 | } 
-------------------------------------------------------------------------------- /datasets/pick-and-place/canonical.json: -------------------------------------------------------------------------------- 1 | { 2 | "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )": { 3 | "formula": "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )", 4 | "raw": "G & U S ! A F A" 5 | }, 6 | "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )": { 7 | "formula": "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )", 8 | "raw": "G & U S ! R F R" 9 | }, 10 | "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )": { 11 | "formula": "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )", 12 | "raw": "G & U S ! B F B" 13 | }, 14 | "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )": { 15 | "formula": "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )", 16 | "raw": "G & U S ! Y F Y" 17 | }, 18 | "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )": { 19 | "formula": "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )", 20 | "raw": "G & U S ! C F C" 21 | } 22 | } -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/result.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from dataclasses import dataclass 5 | from typing import Dict, List, Optional, Tuple 6 | 7 | from semantic_parsing_with_constrained_lm.model import ModelResult 8 | 9 | 10 | @dataclass(frozen=True, eq=True) 11 | class DatumResult: 12 | """CLAMP predictions and results for a single datum""" 13 | 14 | # Test datum utterance 15 | test_datum_natural: str 16 | 17 | # Text and cost of each sequence in the final beam 18 | results: List[ModelResult] 19 | 20 | # Text of each sequence in the final beam 21 | # (Duplicated from `results`; maintained here only 22 | # for backwards compatibility. May be removed later.) 
23 | outputs: List[str] 24 | 25 | # The metrics dictionary containing the main results 26 | metrics: Dict[str, Optional[str]] 27 | 28 | # Other (optional) test datum fields 29 | test_datum_id: Optional[str] = None 30 | test_datum_turn_part_index: Optional[int] = None 31 | test_datum_agent_context: Optional[str] = None 32 | test_datum_canonical: Optional[str] = None 33 | 34 | # Token-level log probabilities for each sequence in the final beam 35 | # (Not yet implemented) 36 | token_logprobs: Optional[List[List[Tuple[str, float]]]] = None 37 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/scfg/parser/scfg_grammar.lark: -------------------------------------------------------------------------------- 1 | %import common.WS 2 | %import common.WORD 3 | %import common.ESCAPED_STRING 4 | 5 | %ignore WS 6 | 7 | start: sync_rule 8 | | mirrored_rule 9 | | macro_rule 10 | | utterance_rule 11 | 12 | start_for_test: sync_rule 13 | | macro_rule 14 | | utterance_rule 15 | | plan_expansion 16 | 17 | sync_rule: rule "->" utterance_expansions "," plan_expansion 18 | mirrored_rule: rule "->" utterance_expansion 19 | macro_rule: macro_def "2>" plan_expansion 20 | utterance_rule: rule "1>" utterance_expansions 21 | 22 | utterance_expansions: utterance_expansion ("|" utterance_expansion)* 23 | 24 | plan_expansion: (token | macro_apply)+ 25 | utterance_expansion: token+ 26 | 27 | token: terminal 28 | | optional_terminal 29 | | nonterminal 30 | | optional_nonterminal 31 | | empty 32 | | regex 33 | 34 | terminal: terminal_string 35 | optional_terminal: terminal_string "?" 36 | 37 | nonterminal: _name 38 | optional_nonterminal: _name "?" 39 | 40 | rule: _name 41 | 42 | macro_def: _name "(" (_name ("," _name)* ","?)? ")" 43 | | _name 44 | macro_apply: _name "(" (_macro_arg ("," _macro_arg)* ","?)? ")" 45 | _macro_arg: nonterminal | terminal | macro_apply | empty 46 | 47 | _name: /[a-zA-Z][_a-zA-Z0-9]*/ 48 | regex: /\/[^\/]+\// 49 | 50 | ?terminal_string: ESCAPED_STRING 51 | 52 | empty: "#e" 53 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/earley/cfg.lark: -------------------------------------------------------------------------------- 1 | // Based on scfg_grammar.lark 2 | %import common.WS 3 | %import common.ESCAPED_STRING 4 | %import common._STRING_ESC_INNER 5 | %import common.INT 6 | 7 | %ignore WS 8 | 9 | _COMMENT: /#[^\n]*/ 10 | start: (rule | _COMMENT)* 11 | 12 | rule: nonterminal_lhs "->" expansion 13 | 14 | expansion: alt ("|" alt)* 15 | 16 | ?alt: elem0+ 17 | | empty 18 | 19 | ?elem0: elem1 20 | | elem1 "?" 
-> optional 21 | | elem1 "*" -> star 22 | | elem1 "+" -> plus 23 | | elem1 "{" count "}" -> repeat_exact 24 | | elem1 "{" count ",}" -> repeat_min 25 | | elem1 "{," count "}" -> repeat_max 26 | | elem1 "{" count "," count "}" -> repeat_min_max 27 | 28 | ?elem1: nonterminal_rhs 29 | | terminal 30 | | "(" expansion ")" 31 | | "[" /[^\]]+/ "]" -> char_class 32 | | "[[" /[^\]]+/ "]--[" /[^\]]+/+ "]]" -> char_class_subtract 33 | 34 | HYPHEN: "-" 35 | RIGHT_SQUARE_BRACKET: "]" 36 | 37 | ESCAPED_SINGLE_QUOTED_STRING: "'" _STRING_ESC_INNER "'" 38 | terminal: ESCAPED_STRING | ESCAPED_SINGLE_QUOTED_STRING 39 | // TODO: #e conflicts with syntax for comments 40 | empty: "#e" 41 | nonterminal_lhs: _name 42 | nonterminal_rhs: _name 43 | _name: /[_a-zA-Z][_a-zA-Z0-9]*/ 44 | !count: INT 45 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/earley/unicode_categories_spans.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | """Provides data about which Unicode characters belong to which general category. 5 | 6 | The concept is explained here: 7 | - https://www.unicode.org/reports/tr44/#General_Category_Values 8 | - https://en.wikipedia.org/wiki/Unicode_character_property#General_Category 9 | 10 | unicode_categories.json was created by translating 11 | https://raw.githubusercontent.com/rust-lang/regex/258bdf798a14f50529c1665e84cc8a3a9e2c90fc/regex-syntax/src/unicode_tables/general_category.rs 12 | """ 13 | import functools 14 | import json 15 | from pathlib import Path 16 | from typing import Dict, List 17 | 18 | from semantic_parsing_with_constrained_lm.util.span import Span, SpanSet 19 | 20 | 21 | @functools.lru_cache(maxsize=None) 22 | def raw_data() -> Dict[str, List[List[int]]]: 23 | with open(Path(__file__).absolute().parent / "unicode_categories.json") as f: 24 | return json.load(f) 25 | 26 | 27 | @functools.lru_cache(maxsize=None) 28 | def category_to_span_set(name: str) -> SpanSet: 29 | """Returns the SpanSet for the category name. 30 | 31 | The SpanSet contains the Unicode code points for the corresponding general category. 32 | Only long names with underscores (e.g. 
"Letter", "Cased_Letter") are accepted.""" 33 | 34 | return SpanSet(Span(x, y) for x, y in raw_data()[name]) 35 | -------------------------------------------------------------------------------- /datasets/pick-and-place/train_seed.jsonl: -------------------------------------------------------------------------------- 1 | {"canonical": "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )", "formula": "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )", "natural": "Look for and pick up any cubes and put them in crate."} 2 | {"canonical": "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )", "formula": "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )", "natural": "Look for and pick up any non red cubes and put them in crate."} 3 | {"canonical": "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )", "formula": "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )", "natural": "Look for and pick up any non blue cubes and put them in crate."} 4 | {"canonical": "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )", "formula": "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )", "natural": "Look for and pick up any non yellow cubes and put them in crate."} 5 | {"canonical": "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )", "formula": "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )", "natural": "Look for and pick up any non green cubes and put them in crate."} 6 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/dfa_grammar_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | from semantic_parsing_with_constrained_lm.earley.grammar import DFAGrammar, Nonterm 5 | from semantic_parsing_with_constrained_lm.earley.specialization import SubstringIntersectingGrammarSpecializer 6 | from semantic_parsing_with_constrained_lm.datum import Datum 7 | from semantic_parsing_with_constrained_lm.decoding.partial_parse import PartialParse 8 | from semantic_parsing_with_constrained_lm.decoding.uint8_earley_partial_parse import ( 9 | UInt8EarleyPartialParse, 10 | UInt8GrammarTokenizerInfo, 11 | ) 12 | from semantic_parsing_with_constrained_lm.model import PartialParseBuilder 13 | from semantic_parsing_with_constrained_lm.tokenization import ClampTokenizer 14 | 15 | 16 | def create_partial_parse_builder( 17 | grammar: DFAGrammar, tokenizer: ClampTokenizer, utterance_nonterm_name: str 18 | ) -> PartialParseBuilder[Datum]: 19 | specializer = SubstringIntersectingGrammarSpecializer( 20 | grammar, Nonterm(utterance_nonterm_name) 21 | ) 22 | tokens = UInt8GrammarTokenizerInfo.prepare_tokens_from_clamp_tokenizer(tokenizer) 23 | 24 | def builder(datum: Datum) -> PartialParse: 25 | specialized_grammar = specializer.specialize(datum.natural) 26 | return UInt8EarleyPartialParse.initial( 27 | UInt8GrammarTokenizerInfo(specialized_grammar, tokens) 28 | ) 29 | 30 | return builder 31 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/sql/grammar/quoted.scfg: -------------------------------------------------------------------------------- 1 | # Single Quotes 2 | 3 | single_quoted -> "'" "%"? non_single_quote_star "%"? "'" 4 | 5 | # matches zero or more strings that are either not single quotes, or are two single quotes in a row 6 | non_single_quote_star -> #e 7 | non_single_quote_star -> /[^']/ non_single_quote_star 8 | 9 | # TODO: uncomment? 10 | # non_single_quote_star -> /[']/ /[']/ non_single_quote_star, /[']/ /[']/ non_single_quote_star 11 | 12 | 13 | # Double Quotes 14 | 15 | double_quoted -> "\"" "%"? non_double_quote_star "%"? "\"" 16 | 17 | # matches zero or more strings that are either not double quotes, or are two double quotes in a row 18 | non_double_quote_star -> #e 19 | non_double_quote_star -> /[^"]/ non_double_quote_star 20 | 21 | # TODO: uncomment? 22 | # non_double_quote_star -> /["]/ /["]/ non_double_quote_star, /["]/ /["]/ non_double_quote_star 23 | 24 | 25 | # Back ticks 26 | 27 | back_ticked -> "`" non_back_tick_star "`" 28 | 29 | # matches zero or more strings that are either not back ticks, or are two back ticks in a row 30 | non_back_tick_star -> #e 31 | non_back_tick_star -> /[^`]/ non_back_tick_star 32 | 33 | # TODO: uncomment? 
34 | # non_back_tick_star -> /[`]/ /[`]/ non_back_tick_star, /[`]/ /[`]/ non_back_tick_star 35 | 36 | 37 | # Brackets 38 | 39 | bracketed -> "[" non_bracket_star "]" 40 | 41 | # matches zero or more strings that are not end brackets 42 | non_bracket_star -> #e 43 | non_bracket_star -> /[^\]]/ non_bracket_star 44 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/ltl/data/pick-syn.train.jsonl: -------------------------------------------------------------------------------- 1 | {"natural": "look for and pick up any cubes and put them in crate", "canonical": "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )", "formula": "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )"} 2 | {"natural": "look for and pick up any non red cubes and put them in crate", "canonical": "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )", "formula": "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )"} 3 | {"natural": "look for and pick up any non blue cubes and put them in crate", "canonical": "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )", "formula": "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )"} 4 | {"natural": "look for and pick up any non yellow cubes and put them in crate", "canonical": "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )", "formula": "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )"} 5 | {"natural": "look for and pick up any non green cubes and put them in crate", "canonical": "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )", "formula": "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )"} 6 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/ltl/kept/pick-syn.train.jsonl: -------------------------------------------------------------------------------- 1 | {"canonical": "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )", "formula": "globally ( and ( until ( scan , not ( any cubes ) ) , finally ( any cubes ) ) )", "natural": "Look for and pick up any cubes and put them in crate."} 2 | {"canonical": "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )", "formula": "globally ( and ( until ( scan , not ( any non red cubes ) ) , finally ( any non red cubes ) ) )", "natural": "Look for and pick up any non red cubes and put them in crate."} 3 | {"canonical": "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )", "formula": "globally ( and ( until ( scan , not ( any non blue cubes ) ) , finally ( any non blue cubes ) ) )", "natural": "Look for and pick up any non blue cubes and put them in crate."} 4 | {"canonical": "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )", "formula": "globally ( and ( until ( scan , not ( any non yellow cubes ) ) , finally ( any non yellow cubes ) ) )", "natural": "Look for and pick up any non yellow cubes and put them in crate."} 5 | {"canonical": "globally ( and ( until ( scan , not ( 
any non green cubes ) ) , finally ( any non green cubes ) ) )", "formula": "globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )", "natural": "Look for and pick up any non green cubes and put them in crate."} 6 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/earley/context_sensitive.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | # Tools to help create context-sensitive grammars. 5 | # See tests/test_harbor/earley/context_sensitive_demos for some example uses. 6 | 7 | from abc import ABC, abstractmethod 8 | from dataclasses import dataclass 9 | from typing import Dict, Generic, Iterable, Set 10 | 11 | from semantic_parsing_with_constrained_lm.earley.grammar import DottedRule, Nonterm, RuleResult, Terminal 12 | 13 | 14 | class SelfExpandingNonterm(Generic[Terminal, RuleResult], Nonterm, ABC): 15 | @abstractmethod 16 | def get_expansions(self) -> Iterable[DottedRule[Terminal, RuleResult]]: 17 | """Get the set of expansions for this nonterminal. 18 | 19 | Child classes should contain additional fields which are used to 20 | determine the expansions created.""" 21 | pass 22 | 23 | 24 | @dataclass 25 | class DynamicGrammar(Generic[Terminal, RuleResult]): 26 | """A Grammar which knows how to use SelfExpandingNonterm to create expansions.""" 27 | 28 | root: Nonterm 29 | fixed_expansions: Dict[Nonterm, Set[DottedRule[Terminal, RuleResult]]] 30 | 31 | def get_expansions( 32 | self, nonterm: Nonterm 33 | ) -> Iterable[DottedRule[Terminal, RuleResult]]: 34 | result = self.fixed_expansions.get(nonterm) 35 | if result is not None: 36 | return result 37 | 38 | if isinstance(nonterm, SelfExpandingNonterm): 39 | return nonterm.get_expansions() 40 | 41 | return () 42 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/paths.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import os 5 | from pathlib import Path 6 | 7 | 8 | 9 | DOMAINS_DIR = Path(__file__).resolve().parent / "domains" 10 | 11 | CALFLOW_EXAMPLES_DIR = DOMAINS_DIR / "calflow/data" 12 | 13 | CALFLOW_GRAMMAR_DIR = DOMAINS_DIR / "calflow/grammar" 14 | 15 | 16 | # Paths used in data preparation for BenchClamp 17 | RUN_ON_AML = "AMLT_EXPERIMENT_NAME" in os.environ 18 | 19 | CLAMP_PRETRAINED_MODEL_DIR = ( 20 | Path("/mnt/default/huggingface_models/") 21 | if RUN_ON_AML 22 | else Path("huggingface_models/") 23 | ) 24 | 25 | CLAMP_DATA_DIR = ( 26 | Path("/mnt/default/clamp_data/") if RUN_ON_AML else Path("data") 27 | ) 28 | 29 | OVERNIGHT_DATA_DIR = CLAMP_DATA_DIR / "overnight" 30 | 31 | LTL_DATA_DIR = CLAMP_DATA_DIR / "ltl" 32 | 33 | BENCH_CLAMP_DATA_DIR_ROOT = CLAMP_DATA_DIR / "benchclamp" 34 | 35 | BENCH_CLAMP_RAW_DATA_DIR = BENCH_CLAMP_DATA_DIR_ROOT / "raw" 36 | 37 | BENCH_CLAMP_PROCESSED_DATA_DIR = BENCH_CLAMP_DATA_DIR_ROOT / "processed" 38 | 39 | BENCH_CLAMP_GRAMMAR_DATA_DIR = BENCH_CLAMP_DATA_DIR_ROOT / "grammar" 40 | 41 | 42 | # Paths for users of BenchClamp. Kept as strings since Path does not work well with network paths. 
43 | CLAMP_DATA_DIR_AZURE = "https://benchclamp.blob.core.windows.net/benchclamp" 44 | 45 | OVERNIGHT_DATA_DIR_AZURE = CLAMP_DATA_DIR_AZURE + "/overnight" 46 | 47 | BENCH_CLAMP_DATA_DIR_ROOT_AZURE = CLAMP_DATA_DIR_AZURE + "/benchclamp" 48 | 49 | BENCH_CLAMP_PROCESSED_DATA_DIR_AZURE = BENCH_CLAMP_DATA_DIR_ROOT_AZURE + "/processed" 50 | 51 | BENCH_CLAMP_GRAMMAR_DATA_DIR_AZURE = BENCH_CLAMP_DATA_DIR_ROOT_AZURE + "/grammar" 52 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/sql/sequence_creator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from typing_extensions import Literal 5 | 6 | from semantic_parsing_with_constrained_lm.datum import BenchClampDatum 7 | from semantic_parsing_with_constrained_lm.sequence_creator import SequenceCreator 8 | from semantic_parsing_with_constrained_lm.tokenization import GPT2ClampTokenizer 9 | 10 | 11 | class CoSqlUtterance(SequenceCreator): 12 | def __init__( 13 | self, use_db_val: bool, past_utterances: Literal["none", "one", "all"] 14 | ): 15 | self.use_db_val = use_db_val 16 | self.past_utterances = past_utterances 17 | self.gpt2_tokenizer = GPT2ClampTokenizer.from_pretrained("gpt2") 18 | 19 | def create_sequence(self, datum: BenchClampDatum) -> str: 20 | db_schema = ( 21 | datum.db_schema_with_val if self.use_db_val else datum.db_schema_without_val 22 | ) 23 | all_past_utterances = datum.utterance.split(" | ") 24 | current_utterance = all_past_utterances[-1] 25 | if self.past_utterances == "none" or len(all_past_utterances) == 1: 26 | past_utterances = "" 27 | elif self.past_utterances == "one": 28 | past_utterances = all_past_utterances[-2] 29 | else: 30 | past_utterances = " | ".join(all_past_utterances[:-1]) 31 | 32 | sequence = " , ".join([past_utterances, db_schema, current_utterance]) # type: ignore 33 | sequence_token_ids = self.gpt2_tokenizer.encode(sequence) 34 | # start_index = max(0, len(sequence_token_ids) - 1000) 35 | start_index = 0 36 | return self.gpt2_tokenizer.decode(sequence_token_ids[start_index:]) 37 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/datum.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from dataclasses import dataclass 5 | from typing import Optional, TypeVar 6 | 7 | 8 | @dataclass(frozen=True, eq=True) 9 | class Datum: 10 | dialogue_id: Optional[str] 11 | turn_part_index: Optional[int] 12 | agent_context: Optional[str] 13 | natural: str 14 | 15 | 16 | @dataclass(frozen=True, eq=True) 17 | class FullDatum(Datum): 18 | canonical: str 19 | 20 | 21 | # Not contravariant since it is produced in a DataRetriever. 22 | # Discussions at https://semanticmachines.slack.com/archives/CM88KH6EN/p1654553264411409 23 | FullDatumSub = TypeVar("FullDatumSub", bound=FullDatum) 24 | # Contravariant since it is ingested by either DataRetriever, DataFilter, or PromptBuilder, but never produced 25 | DatumSub = TypeVar("DatumSub", bound=Datum, contravariant=True) 26 | 27 | 28 | @dataclass(frozen=True, eq=True) 29 | class BenchClampDatum: 30 | """ 31 | Class to hold all possible information for each instance in BenchCLAMP. This class is used to generate, read 32 | and write BenchCLAMP data files. 
We distill it to FullDatum before using it for training or evaluation. 33 | Fields only used for CalFlow, TreeDST: last_agent_utterance, last_user_utterance, last_plan 34 | Fields only used for Spider and CoSQL: schema_name, db_schema_without_val, db_schema_with_val 35 | """ 36 | 37 | dialogue_id: Optional[str] 38 | turn_part_index: Optional[int] 39 | utterance: str 40 | plan: str 41 | last_agent_utterance: Optional[str] = None 42 | last_user_utterance: Optional[str] = None 43 | last_plan: Optional[str] = None 44 | schema_name: Optional[str] = None 45 | db_schema_without_val: Optional[str] = None 46 | db_schema_with_val: Optional[str] = None 47 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/scripts/calflow_fit_max_steps.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | """Computes a linear fit for the number of tokens in the input and output. 5 | 6 | This can be used to set a good value for the `max_steps` parameter in beam search. 7 | 8 | The script does 10-fold cross-validation to find a slope and intercept which: 9 | - minimizes the mean number of excess steps (i.e. predicted length - gold length) 10 | - only very rarely predicts a length which is smaller than the gold length 11 | 12 | Example invocation: 13 | python semantic_parsing_with_constrained_lm/scripts/calflow_fit_max_steps.py \ 14 | --data-path semantic_parsing_with_constrained_lm/domains/calflow/data/train_300_stratified.jsonl \ 15 | --tokenizer facebook/bart-large \ 16 | --output-type canonicalUtterance # or "lispress" 17 | """ 18 | import pathlib 19 | from typing import List, Tuple 20 | 21 | import typer 22 | from transformers import AutoTokenizer 23 | 24 | from semantic_parsing_with_constrained_lm.domains.calflow import ( 25 | CalflowOutputLanguage, 26 | read_calflow_jsonl, 27 | ) 28 | from semantic_parsing_with_constrained_lm.fit_max_steps import compute_and_print_fit 29 | 30 | 31 | def main( 32 | data_path: pathlib.Path = typer.Option(...), 33 | tokenizer: str = typer.Option(...), 34 | output_type: CalflowOutputLanguage = typer.Option(...), 35 | max_unreachable: int = typer.Option(1), 36 | ): 37 | t = AutoTokenizer.from_pretrained(tokenizer) 38 | 39 | pairs: List[Tuple[int, int]] = [] 40 | for datum in read_calflow_jsonl(data_path, output_type): 41 | num_input_tokens = len(t.tokenize(datum.natural)) 42 | if not datum.canonical: 43 | continue 44 | num_output_tokens = len(t.tokenize(datum.canonical)) + 1 45 | 46 | pairs.append((num_input_tokens, num_output_tokens)) 47 | 48 | compute_and_print_fit(pairs, 10, max_unreachable) 49 | 50 | 51 | if __name__ == "__main__": 52 | typer.run(main) 53 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/util/keydefaultdict.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | # Adapted from https://stackoverflow.com/posts/2912455/revisions, 5 | # but with type annotations added.
6 | # pylint: disable=no-member,useless-super-delegation 7 | from typing import Any, Callable, DefaultDict, Iterable, Tuple, TypeVar 8 | 9 | K = TypeVar("K") 10 | V = TypeVar("V") 11 | 12 | 13 | class KeyDefaultDict(DefaultDict[K, V]): 14 | """ 15 | A version of defaultdict whose factory function that constructs a 16 | default value for a missing key takes that key as an argument. 17 | 18 | >>> d: KeyDefaultDict[int, str] = KeyDefaultDict(lambda k: "/" + str(k) + "/", {0: "zero"}) 19 | >>> d[3] = 'three' 20 | >>> d[0] 21 | 'zero' 22 | >>> d[3] 23 | 'three' 24 | >>> d[4] 25 | '/4/' 26 | >>> dict(d) 27 | {0: 'zero', 3: 'three', 4: '/4/'} 28 | """ 29 | 30 | def __init__(self, default_factory: Callable[[K], V], *args: Any, **kwargs: Any): 31 | super().__init__(None, *args, **kwargs) 32 | # store the default_factory in an attribute of a different name, to avoid an inheritance type error 33 | self.default_key_factory = default_factory 34 | 35 | def __missing__(self, key: K) -> V: 36 | """ 37 | Overrides the central method of `defaultdict` with one that calls 38 | `default_key_factory` on `key` instead of calling `default_factory` 39 | on 0 args. 40 | """ 41 | if self.default_key_factory is None: 42 | raise KeyError(key) 43 | ret = self[key] = self.default_key_factory(key) 44 | return ret 45 | 46 | def __repr__(self) -> str: 47 | """Prints `default_key_factory` instead of `default_factory`.""" 48 | return f"{self.__class__.__name__}({self.default_key_factory}, {dict.__repr__(self)})" 49 | 50 | # To avoid E1136 (unsubscriptable-object) pylint errors at call sites 51 | def __getitem__(self, item: K) -> V: 52 | return super().__getitem__(item) 53 | 54 | # To avoid E1136 (unsubscriptable-object) pylint errors at call sites 55 | def __setitem__(self, key: K, value: V) -> None: 56 | return super().__setitem__(key, value) 57 | 58 | def items(self) -> Iterable[Tuple[K, V]]: 59 | return super().items() 60 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/earley/specialization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from dataclasses import dataclass 5 | 6 | from openfst_python import determinize, intersect 7 | 8 | from semantic_parsing_with_constrained_lm.earley.fsa import CompiledDFA 9 | from semantic_parsing_with_constrained_lm.earley.fsa_builders import compile_nfa, re_substring_utf8 10 | from semantic_parsing_with_constrained_lm.earley.grammar import DFADottedRule, DFAGrammar, Nonterm 11 | 12 | 13 | @dataclass 14 | class SubstringIntersectingGrammarSpecializer: 15 | """Creates grammars where the rule for string literals are modified to only allow substrings of a given string. 
16 | 17 | The string is usually a user's utterance, where we are performing semantic parsing of that utterance 18 | and we know that any strings in the output program must be substrings of the utterance.""" 19 | 20 | base_grammar: DFAGrammar 21 | nonterm_to_intersect: Nonterm 22 | 23 | def __post_init__(self): 24 | assert self.nonterm_to_intersect in self.base_grammar.expansions 25 | 26 | def specialize(self, s: str) -> DFAGrammar: 27 | existing_rule = self.base_grammar.expansions[self.nonterm_to_intersect] 28 | substring_nfa = compile_nfa(re_substring_utf8(s)) 29 | assert substring_nfa.edge_indexer.num_indexed() == 0 30 | 31 | # Intersect the FSA which accepts any valid string literal with the FSA 32 | # for the substrings of `s`. This ensures that we don't take any 33 | # substrings of `s` which are incorrectly escaped, e.g. taking " rather 34 | # than \". 35 | # TODO: Have some checks that `s` has already been escaped correctly, 36 | # since otherwise we will end up excluding many substrings which should 37 | # have been included. 38 | new_dfa = CompiledDFA( 39 | determinize( 40 | intersect(existing_rule.dfa.fst, substring_nfa.fst).rmepsilon() 41 | ).minimize(), 42 | existing_rule.dfa.edge_indexer, 43 | ) 44 | new_rule = DFADottedRule(existing_rule.lhs, new_dfa, new_dfa.start_id) 45 | 46 | return DFAGrammar( 47 | self.base_grammar.root, 48 | {**self.base_grammar.expansions, self.nonterm_to_intersect: new_rule}, 49 | ) 50 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/decoding/partial_parse.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from abc import ABC, abstractmethod 5 | from typing import Optional, Tuple 6 | 7 | import torch 8 | 9 | from semantic_parsing_with_constrained_lm.tokenization import ClampTokenizer 10 | 11 | 12 | class PartialParse(ABC): 13 | @abstractmethod 14 | def allowed_next( 15 | self, ordered_ids: Optional[torch.Tensor] = None, top_k: Optional[int] = None 16 | ) -> Tuple[Optional[torch.Tensor], bool]: 17 | """Returns possible ways to extend the current prefix. 18 | 19 | The Tensor is of type long and 1-dimensional, with no duplicate values, 20 | containing the IDs of the tokens that we could append. 21 | If it is None, then any token is allowed. 22 | The bool indicates whether we are allowed to terminate here. 23 | 24 | If ordered_ids and top_k are not None, this may optionally return only 25 | the first `top_k` token IDs from ordered_ids which comports with the 26 | grammar, instead of all such token IDs. 
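        Usage sketch (an assumption about callers, not part of this interface's contract): a grammar-constrained beam search could call allowed_next at every decoding step, restrict the language model's next-token candidates to the returned IDs (leaving them unrestricted when None is returned), and only finalize a hypothesis when the returned bool is True.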
27 | """ 28 | pass 29 | 30 | @abstractmethod 31 | def append(self, token: int) -> "PartialParse": 32 | """Return a new PartialParse created by appending this token.""" 33 | pass 34 | 35 | 36 | class NullPartialParse(PartialParse): 37 | """PartialParse which admits any sequence.""" 38 | 39 | def allowed_next( 40 | self, ordered_ids: Optional[torch.Tensor] = None, top_k: Optional[int] = None 41 | ) -> Tuple[Optional[torch.Tensor], bool]: 42 | return None, True 43 | 44 | def append(self, token: int) -> "PartialParse": 45 | return self 46 | 47 | 48 | class StartsWithSpacePartialParse(PartialParse): 49 | def __init__(self, tokenizer: ClampTokenizer): 50 | valid_tokens = [] 51 | for utf8_token, token_id in tokenizer.utf8_token_to_id_map.items(): 52 | if utf8_token[0] == 32: 53 | valid_tokens.append(token_id) 54 | self.valid_tokens = torch.tensor(valid_tokens) 55 | 56 | def allowed_next( 57 | self, ordered_ids: Optional[torch.Tensor] = None, top_k: Optional[int] = None 58 | ) -> Tuple[Optional[torch.Tensor], bool]: 59 | return self.valid_tokens, False 60 | 61 | def append(self, token: int) -> "PartialParse": 62 | return NullPartialParse() 63 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/scfg/parser/macro.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from dataclasses import dataclass 5 | from typing import Dict, List, Optional, Tuple, cast 6 | 7 | from semantic_parsing_with_constrained_lm.scfg.parser.token import MacroToken, NonterminalToken, SCFGToken 8 | from semantic_parsing_with_constrained_lm.scfg.parser.types import Expansion 9 | 10 | 11 | @dataclass(frozen=True) 12 | class Macro: 13 | name: str 14 | args: Tuple[str, ...] 15 | expansion: Expansion 16 | 17 | def apply_expression( 18 | self, macro_rules: Dict[str, "Macro"], args_to_bind: List[Expansion] 19 | ) -> Expansion: 20 | """ 21 | Apply this macro to the argument. Because all macros cannot use any variables outside of its definition, 22 | we do not need to pass in an environment. 23 | """ 24 | result: List[SCFGToken] = [] 25 | for token in self.expansion: 26 | result += eval_expression( 27 | macro_rules, token, dict(zip(self.args, args_to_bind)) 28 | ) 29 | return tuple(result) 30 | 31 | 32 | def eval_expression( 33 | macros: Dict[str, Macro], 34 | token: SCFGToken, 35 | env: Optional[Dict[str, Expansion]] = None, 36 | ) -> Expansion: 37 | """ 38 | Given a token, eval it. 39 | 40 | e.g. for macros defined as 41 | f(a) 2> "(" g(a) ")" 42 | g(b) 2> "(" b ")" 43 | h(c) 2> "(" c ")" 44 | 45 | given a macro call f(z), 46 | return "(" "(" z ")" ")" 47 | 48 | or a macro call g(h(z)), 49 | return "(" "(" z ")" ")" 50 | 51 | """ 52 | env = env if env else {} 53 | if isinstance(token, MacroToken): 54 | args_to_bind = [eval_expression(macros, arg, env) for arg in token.args] 55 | macro = macros[token.name] 56 | return macro.apply_expression(macros, args_to_bind) 57 | elif isinstance(token, NonterminalToken) and token.value in env: 58 | return env[token.value] 59 | 60 | return (token,) 61 | 62 | 63 | def expand_macros(macros: Dict[str, Macro], expansion: Expansion) -> Expansion: 64 | """ 65 | Return a new rule where the plan_rhs has been rewritten with all macros expanded. 
66 | """ 67 | new_tokens: List[SCFGToken] = [] 68 | for token in expansion: 69 | if isinstance(token, MacroToken): 70 | new_tokens += eval_expression(macros, token) 71 | else: 72 | new_tokens.append(cast(SCFGToken, token)) 73 | 74 | return tuple(new_tokens) 75 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/overnight/create_benchclamp_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from semantic_parsing_with_constrained_lm.datum import BenchClampDatum 5 | from semantic_parsing_with_constrained_lm.domains.benchclamp_data_setup import ( 6 | OVERNIGHT_DOMAINS, 7 | BenchClampDataset, 8 | ) 9 | from semantic_parsing_with_constrained_lm.domains.create_benchclamp_splits import ( 10 | create_benchclamp_splits, 11 | ) 12 | from semantic_parsing_with_constrained_lm.domains.overnight import OutputType, OvernightDataPieces 13 | from semantic_parsing_with_constrained_lm.paths import ( 14 | BENCH_CLAMP_PROCESSED_DATA_DIR, 15 | OVERNIGHT_DATA_DIR, 16 | ) 17 | 18 | 19 | def main(): 20 | for domain in OVERNIGHT_DOMAINS: 21 | overnight_pieces = OvernightDataPieces.from_dir( 22 | OVERNIGHT_DATA_DIR, 23 | is_dev=True, 24 | domain=domain, 25 | output_type=OutputType.MeaningRepresentation, 26 | simplify_logical_forms=True, 27 | ) 28 | train_data = overnight_pieces.train_data 29 | dev_data = overnight_pieces.test_data 30 | overnight_pieces = OvernightDataPieces.from_dir( 31 | OVERNIGHT_DATA_DIR, 32 | is_dev=False, 33 | domain=domain, 34 | output_type=OutputType.MeaningRepresentation, 35 | simplify_logical_forms=True, 36 | ) 37 | test_data = overnight_pieces.test_data 38 | 39 | train_benchclamp_data = [] 40 | dev_benchclamp_data = [] 41 | test_benchclamp_data = [] 42 | for data, benchclamp_data in [ 43 | (train_data, train_benchclamp_data), 44 | (dev_data, dev_benchclamp_data), 45 | (test_data, test_benchclamp_data), 46 | ]: 47 | for datum in data: 48 | benchclamp_data.append( 49 | BenchClampDatum( 50 | dialogue_id=datum.dialogue_id, 51 | turn_part_index=datum.turn_part_index, 52 | utterance=datum.natural, 53 | plan=datum.canonical, 54 | ) 55 | ) 56 | 57 | create_benchclamp_splits( 58 | train_benchclamp_data, 59 | dev_benchclamp_data, 60 | test_benchclamp_data, 61 | BENCH_CLAMP_PROCESSED_DATA_DIR / BenchClampDataset.Overnight.value / domain, 62 | ) 63 | 64 | 65 | if __name__ == "__main__": 66 | main() 67 | -------------------------------------------------------------------------------- /docs/_layouts/default.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | {% seo %} 6 | 7 | 8 | 9 | 10 | 11 | 12 | {% include head-custom.html %} 13 | Data-efficient LTL Learning 14 | 15 | 16 | Skip to the content. 17 | 18 | 28 | 29 |
30 | {{ content }} 31 | 32 | 38 |
39 | 40 | 41 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/earley/recognize.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | """ 5 | Methods that illustrates the simplest ways to use EarleyChart: 6 | as a recognizer of a token string, or as a generator of grammatical sentences. 7 | """ 8 | 9 | from typing import Iterable, Iterator, List, Sequence, cast 10 | 11 | from semantic_parsing_with_constrained_lm.earley.earley import EarleyLRChart, PackedForest 12 | from semantic_parsing_with_constrained_lm.earley.grammar import Grammar, RuleResult, Terminal 13 | from semantic_parsing_with_constrained_lm.earley.input import SequencePosition, SigmaStarTriePosition 14 | 15 | 16 | def parse( 17 | sentence: Iterable[Terminal], grammar: Grammar[Terminal, RuleResult] 18 | ) -> PackedForest[Terminal]: 19 | start_pos = SequencePosition(list(sentence)) 20 | chart = EarleyLRChart(grammar=grammar, start_pos=start_pos, use_backpointers=True) 21 | return chart.parse() 22 | 23 | 24 | def is_grammatical( 25 | tokens: Sequence[Terminal], grammar: Grammar[Terminal, RuleResult] 26 | ) -> bool: 27 | """ 28 | Tests whether the given input `tokens` are grammatical under `grammar`. 29 | """ 30 | start_pos = SequencePosition(tokens) 31 | chart = EarleyLRChart(grammar, start_pos, use_backpointers=False) 32 | for _ in chart.accepting_positions(): 33 | return True # we're grammatical if the iterator is non-empty 34 | return False 35 | 36 | 37 | def top_level_rule_results( 38 | tokens: Sequence[Terminal], grammar: Grammar[Terminal, RuleResult] 39 | ) -> Iterable[RuleResult]: 40 | """ 41 | Yields the RuleResults produced by the DottedRules from the `start` nonterminal. 42 | """ 43 | start_pos = SequencePosition(tokens) 44 | chart = EarleyLRChart(grammar, start_pos, use_backpointers=False) 45 | for end_pos in chart.accepting_positions(): 46 | for item, _ in chart.completed_items(grammar.root, start_pos, end_pos): 47 | rule_result = item.dotted_rule.is_final() 48 | assert rule_result is not None 49 | yield rule_result 50 | 51 | 52 | def enumerate_sentences( 53 | grammar: Grammar[Terminal, RuleResult] 54 | ) -> Iterator[List[Terminal]]: 55 | """ 56 | Yields grammatical sentences in length order (may not terminate). 57 | """ 58 | # root of a Σ* trie with string-labeled edges (as the grammar uses Terminal=str) 59 | start_pos = SigmaStarTriePosition[Terminal]() 60 | chart = EarleyLRChart(grammar, start_pos, use_backpointers=False) 61 | for pos in chart.accepting_positions(): # enumerate nodes in the Σ* trie 62 | # necessary because current typing isn't strong enough 63 | _pos = cast(SigmaStarTriePosition[Terminal], pos) 64 | yield _pos.prefix() 65 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/scfg/earley_grammar.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | # This file is deprecated. 5 | # TODO: Replace with Grammar[np.uint8] and UInt8EarleyPartialParse. 
6 | from typing import Any, Dict, List, Tuple, Union 7 | 8 | from semantic_parsing_with_constrained_lm.earley.grammar import FixedGrammar, LinearDottedRule, Nonterm, Symbol 9 | from semantic_parsing_with_constrained_lm.scfg.char_grammar import START 10 | from semantic_parsing_with_constrained_lm.scfg.parser.token import ( 11 | EmptyToken, 12 | NonterminalToken, 13 | RegexToken, 14 | TerminalToken, 15 | ) 16 | from semantic_parsing_with_constrained_lm.scfg.parser.types import Alias, Expansion, Nonterminal 17 | from semantic_parsing_with_constrained_lm.scfg.read_grammar import GrammarRules 18 | 19 | CFTerminal = Union[str, RegexToken] 20 | 21 | 22 | class EarleyCFGrammar(FixedGrammar[CFTerminal, Any]): 23 | """A grammar for one of the two sides of an SCFG. 24 | 25 | Similar to CharGrammar, but it doesn't split up all terminals into single 26 | characters. 27 | """ 28 | 29 | @staticmethod 30 | def from_preprocessed_rules(rules: GrammarRules): 31 | aliased_grammar = { 32 | lhs: [(rhs, None) for rhs in rhss] for lhs, rhss in rules.items() 33 | } 34 | return EarleyCFGrammar.from_aliased_grammar(aliased_grammar) # type: ignore 35 | 36 | @staticmethod 37 | def from_aliased_grammar( 38 | grammar: Dict[Nonterminal, List[Tuple[Expansion, Alias]]] 39 | ) -> "EarleyCFGrammar": 40 | def convert(expansion: Expansion) -> Tuple[Symbol[CFTerminal], ...]: 41 | result = [] 42 | for token in expansion: 43 | if isinstance(token, TerminalToken): 44 | result.append(token.render()) 45 | elif isinstance(token, RegexToken): 46 | result.append(token) 47 | elif isinstance(token, NonterminalToken): 48 | result.append(Nonterm(token.value)) 49 | elif isinstance(token, EmptyToken): 50 | pass 51 | else: 52 | raise ValueError(token) 53 | return tuple(result) 54 | 55 | return EarleyCFGrammar( 56 | root=START, 57 | expansions={ 58 | Nonterm(origin): { 59 | # https://github.com/microsoft/pyright/issues/2962 60 | LinearDottedRule[CFTerminal].from_rule( 61 | Nonterm(origin), rhs=convert(rhs), alias=alias 62 | ) 63 | for rhs, alias in rhss 64 | } 65 | for origin, rhss in sorted(grammar.items()) 66 | }, 67 | ) 68 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/ltl/create_benchclamp_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | from semantic_parsing_with_constrained_lm.datum import BenchClampDatum 5 | from semantic_parsing_with_constrained_lm.domains.benchclamp_data_setup import ( 6 | LTL_DOMAINS, 7 | BenchClampDataset, 8 | ) 9 | from semantic_parsing_with_constrained_lm.domains.create_benchclamp_splits import create_benchclamp_splits 10 | from semantic_parsing_with_constrained_lm.domains.ltl import LTLOutputType, LTLDataPieces 11 | from semantic_parsing_with_constrained_lm.paths import ( 12 | BENCH_CLAMP_PROCESSED_DATA_DIR, 13 | LTL_DATA_DIR, 14 | ) 15 | 16 | 17 | def main(): 18 | for domain in LTL_DOMAINS: 19 | # domains in this case are "train-evaluation splits" 20 | 21 | """ 22 | Naming convention for ltl data files: 23 | - ltl_data.json: contains the canonical data 24 | - the denotation can look odd; it is closer to a program result 25 | - used for building the trie-constrained decoder 26 | - either ```canonical``` or ```formula``` 27 | - ltl.dev.jsonl 28 | - ltl.test.jsonl 29 | - ltl.train_with_dev.jsonl 30 | - ltl.train_without_dev.jsonl 31 | """ 32 | ltl_pieces = LTLDataPieces.from_dir( 33 | LTL_DATA_DIR, 34 | is_dev=True, 35 | domain=domain, 36 | output_type=LTLOutputType.MeaningRepresentation, 37 | simplify_logical_forms=True, 38 | ) 39 | train_data = ltl_pieces.train_data 40 | dev_data = ltl_pieces.test_data 41 | 42 | ltl_pieces = LTLDataPieces.from_dir( 43 | LTL_DATA_DIR, 44 | is_dev=False, 45 | domain=domain, 46 | output_type=LTLOutputType.MeaningRepresentation, 47 | simplify_logical_forms=True, 48 | ) 49 | test_data = ltl_pieces.test_data 50 | 51 | train_benchclamp_data = [] 52 | dev_benchclamp_data = [] 53 | test_benchclamp_data = [] 54 | 55 | for data, benchclamp_data in [ 56 | (train_data, train_benchclamp_data), 57 | (dev_data, dev_benchclamp_data), 58 | (test_data, test_benchclamp_data), 59 | ]: 60 | for datum in data: 61 | benchclamp_data.append( 62 | BenchClampDatum( 63 | dialogue_id=datum.dialogue_id, 64 | turn_part_index=datum.turn_part_index, 65 | utterance=datum.natural, 66 | plan=datum.canonical, 67 | ) 68 | ) 69 | 70 | create_benchclamp_splits( 71 | train_benchclamp_data, 72 | dev_benchclamp_data, 73 | test_benchclamp_data, 74 | BENCH_CLAMP_PROCESSED_DATA_DIR / BenchClampDataset.LTL.value / domain, 75 | ) 76 | 77 | 78 | if __name__ == "__main__": 79 | main() 80 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/calflow_eval_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | """ Functions to help run Bart cold monster model for Calflow.
""" 5 | import asyncio 6 | from pathlib import Path 7 | from typing import List 8 | 9 | import tqdm 10 | 11 | from semantic_parsing_with_constrained_lm.scfg.read_grammar import PreprocessedGrammar 12 | from semantic_parsing_with_constrained_lm.configs.lib.calflow import make_semantic_parser_for_calflow 13 | from semantic_parsing_with_constrained_lm.datum import Datum, FullDatum 14 | from semantic_parsing_with_constrained_lm.domains.calflow import CalflowOutputLanguage 15 | from semantic_parsing_with_constrained_lm.lm import ClientType 16 | from semantic_parsing_with_constrained_lm.lm_bart import Seq2SeqBart 17 | from semantic_parsing_with_constrained_lm.model import BeamSearchSemanticParser, ModelResult 18 | from semantic_parsing_with_constrained_lm.train_model_setup import BartModelConfig 19 | 20 | 21 | def instantiate_bart_eval_model( 22 | model_loc: str, grammar_dir: str 23 | ) -> BeamSearchSemanticParser: 24 | preprocessed_grammar = PreprocessedGrammar.from_folder(grammar_dir) 25 | bart_model_config = BartModelConfig(model_id="Bart", model_loc=Path(model_loc)) 26 | model, tokenizer, _ = bart_model_config.setup_model() 27 | lm = Seq2SeqBart( 28 | pretrained_model_dir=model_loc, model=model, clamp_tokenizer=tokenizer 29 | ) 30 | beam_size = 2 31 | return make_semantic_parser_for_calflow( 32 | [], 33 | lm, 34 | use_gpt3=False, 35 | beam_size=beam_size, 36 | output_type=CalflowOutputLanguage.Canonical, 37 | client_type=ClientType.BART, 38 | preprocessed_grammar=preprocessed_grammar, 39 | constrained=True, 40 | num_examples_per_prompt=0, 41 | ) 42 | 43 | 44 | def predict(model: BeamSearchSemanticParser, user_utterance: str) -> List[str]: 45 | results: List[ModelResult] = asyncio.run( 46 | model.predict( 47 | Datum( 48 | natural=user_utterance, 49 | dialogue_id=None, 50 | turn_part_index=None, 51 | agent_context=None, 52 | ) 53 | ) 54 | ) 55 | if len(results) == 0: 56 | return [] 57 | return [result.text for result in results] 58 | 59 | 60 | def evaluate(eval_examples: List[FullDatum], model: BeamSearchSemanticParser) -> None: 61 | total = len(eval_examples) 62 | correct = 0 63 | for example in tqdm.tqdm(eval_examples): 64 | predicted = predict(model, example.natural)[0] 65 | if ( 66 | example.canonical is not None 67 | and example.canonical.strip() == predicted.strip() 68 | ): 69 | correct += 1 70 | else: 71 | print(f"Utterance: {example.natural}") 72 | print(f"Canonical: {example.canonical}") 73 | print(f"Predicted: {predicted}") 74 | 75 | acc = correct * 1.0 / total 76 | for _ in range(100): 77 | print(f"Accuracy = {correct} / {total} = {acc}") 78 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/finetune/check_for_unks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | import importlib 5 | from typing import Callable, Dict, List, Optional, Set 6 | 7 | import typer 8 | 9 | from semantic_parsing_with_constrained_lm.lm import Surround 10 | from semantic_parsing_with_constrained_lm.run_exp import filter_exp_dict 11 | from semantic_parsing_with_constrained_lm.tokenization import ClampTokenizer 12 | from semantic_parsing_with_constrained_lm.finetune.lm_finetune import TrainExperiment 13 | 14 | 15 | def find_and_record_unk_tokens( 16 | tokenizer: ClampTokenizer, 17 | surround: Surround, 18 | ids: List[int], 19 | orig: str, 20 | unk_tokens: Set[bytes], 21 | ) -> None: 22 | if surround.starts_with_space: 23 | orig = " " + orig 24 | tokens = tokenizer.tokenize(orig) 25 | ids = ids[len(surround.bos) : len(ids) - len(surround.eos)] 26 | assert len(tokens) == len(ids) 27 | 28 | for token_bytes, token_id in zip(tokens, ids): 29 | if token_id == tokenizer.unk_token_id: 30 | unk_tokens.add(token_bytes) 31 | 32 | 33 | def main( 34 | config_name: str = typer.Option(...), 35 | exp_names: Optional[List[str]] = typer.Option(None), 36 | exp_name_pattern: Optional[List[str]] = typer.Option(None), 37 | ): 38 | config_mod = importlib.import_module(config_name) 39 | exps: Dict[str, Callable[[], TrainExperiment]] = config_mod.build_config() # type: ignore 40 | filtered_exp_dict = filter_exp_dict(exps, exp_names, exp_name_pattern) 41 | 42 | for exp_name in filtered_exp_dict: 43 | exp = filtered_exp_dict[exp_name]() 44 | unk_id = exp.tokenizer.unk_token_id 45 | assert unk_id is not None 46 | 47 | train_dataset = exp.make_clamp_dataset(exp.train_data) 48 | num_unks_in_inputs = 0 49 | num_unks_in_labels = 0 50 | 51 | unk_tokens: Set[bytes] = set() 52 | 53 | for i, _ in enumerate(train_dataset): # type: ignore 54 | datum = train_dataset[i] 55 | input_ids = datum["input_ids"] 56 | labels = datum["labels"] 57 | 58 | if any(t == unk_id for t in input_ids): 59 | num_unks_in_inputs += 1 60 | find_and_record_unk_tokens( 61 | exp.tokenizer, 62 | exp.seq2seq_settings.input_surround, 63 | input_ids, 64 | exp.train_data[i].natural, 65 | unk_tokens, 66 | ) 67 | 68 | if any(t == unk_id for t in labels): 69 | num_unks_in_labels += 1 70 | find_and_record_unk_tokens( 71 | exp.tokenizer, 72 | exp.seq2seq_settings.output_surround, 73 | labels, 74 | exp.train_data[i].canonical, 75 | unk_tokens, 76 | ) 77 | 78 | print( 79 | f"{exp_name}: {num_unks_in_inputs}/{len(train_dataset)} unks in inputs, " 80 | f"{num_unks_in_labels}/{len(train_dataset)} unks in labels." 
81 | ) 82 | print(f"unk_tokens = {unk_tokens}") 83 | 84 | 85 | if __name__ == "__main__": 86 | typer.run(main) 87 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/calflow/grammar/fluenter.scfg: -------------------------------------------------------------------------------- 1 | # non-empty results 2 | boolean -> "does there exist an " constraint_event_ empty , " (> (size (:results (FindEventWrapperWithDefaults :constraint" constraint_event_ "))) #(Number 0))" empty 3 | 4 | # singleton List[Path] 5 | list_path_ -> path , " (append #(List[Path] [])" path ")" 6 | list_path_ -> path , " (append (List.Nil)" path ")" 7 | 8 | # singleton List[Recipient] 9 | list_recipient_ -> recipient , " (append #(List[Recipient] [])" recipient ")" 10 | list_recipient_ -> recipient , " (append (List.Nil)" recipient ")" 11 | 12 | unit -> "Yes, do the " number " one" , " (Yield :output (Execute :intension (ChooseCreateEvent :index" number " :intension (refer (ActionIntensionConstraint)))))" 13 | 14 | constraint_list_attendee__ -> " with attendees" , " (negate (AlwaysFalseConstraint[List[Attendee]]))" 15 | 16 | constraint_event__constraint_event__args -> " with location unspecified" constraint_event__constraint_event__args , " :location (AlwaysFalseConstraint[LocationKeyphrase])" constraint_event__constraint_event__args 17 | 18 | # TODO: ambiguous with `(getIntraSalient (AlwaysTrueConstraint))`, does that matter? 19 | dynamic -> "that thing" , "(:item (getIntraSalient (AlwaysTrueConstraint[Dynamic])))" 20 | 21 | # comparisons that aren't "before"/"after" 22 | constraint_duration_ -> " longer than " duration , " (?>" duration ")" 23 | constraint_duration_ -> " no shorter than " duration , " (?>=" duration ")" 24 | constraint_duration_ -> " shorter than " duration , " (?<" duration ")" 25 | constraint_duration_ -> " no longer than " duration , " (?<=" duration ")" 26 | 27 | # update event without restating type 28 | constraint_event_wo_type -> constraint_event__constraint_event__args , " (Constraint[Event]" constraint_event__constraint_event__args ")" 29 | updateeventresponse -> "update " event " so it is" constraint_event_wo_type , " (UpdateWrapper" " :findArg" event " :updateArg" constraint_event_wo_type ")" 30 | updatecommitevent -> "update " eventid " so it is" constraint_event_wo_type , " (UpdatePreflightEventWrapper" " :id" eventid " :update" constraint_event_wo_type ")" 31 | 32 | # clobber event without restating type 33 | dynamic -> "Change my request so the " constraint_calflowintension_constraint_event___ " is" constraint_event_wo_type , " (ClobberWrapper" " :oldLocation" constraint_calflowintension_constraint_event___ " :new" constraint_event_wo_type ")" 34 | 35 | # update duration 36 | duration -> duration " longer" , " (addDurations (:duration (getIntraSalient (AlwaysTrueConstraint[Event])))" duration ")" 37 | 38 | # bare date constraint 39 | constraint_date_ -> "date" , " (Constraint[Date])" 40 | 41 | # nonEmptyBase 42 | non_empty_base_ -> constraint_event_ , " :nonEmptyBase" constraint_event_ 43 | constraint_event__constraint_event__args_a -> constraint_event__constraint_event__args , constraint_event__constraint_event__args 44 | constraint_event__constraint_event__args_b -> constraint_event__constraint_event__args , constraint_event__constraint_event__args 45 | constraint_event_ -> non_empty_base_ " but" constraint_event__constraint_event__args_a constraint_event__constraint_event__args_b , " (Constraint[Event]" 
constraint_event__constraint_event__args_a non_empty_base_ constraint_event__constraint_event__args_b ")" 46 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/scfg/parser/token.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import json 5 | from abc import ABC, abstractmethod 6 | from dataclasses import dataclass 7 | from typing import List, Tuple 8 | 9 | import regex 10 | from cached_property import cached_property 11 | 12 | ASCII_CHARS: List[str] = [chr(i) for i in range(32, 127)] 13 | 14 | 15 | class SCFGToken(ABC): 16 | def render(self) -> str: 17 | """ 18 | How to render this token when generating it. In most cases, you can just render the underlying value. 19 | Sometimes you want to modify it, like in TerminalToken. 20 | """ 21 | return self.value 22 | 23 | @property 24 | @abstractmethod 25 | def value(self) -> str: 26 | """The underlying value of the token.""" 27 | pass 28 | 29 | @property 30 | def lark_value(self) -> str: 31 | return self.value 32 | 33 | 34 | class OptionableSCFGToken(SCFGToken, ABC): 35 | optional: bool 36 | 37 | 38 | @dataclass(frozen=True) 39 | class NonterminalToken(OptionableSCFGToken): 40 | underlying: str 41 | optional: bool 42 | 43 | @property 44 | def value(self): 45 | return self.underlying 46 | 47 | def is_regex(self): 48 | return self.underlying[0] == "/" 49 | 50 | 51 | @dataclass(frozen=True) 52 | class TerminalToken(OptionableSCFGToken): 53 | underlying: str 54 | optional: bool 55 | 56 | def render(self): 57 | """ 58 | Remove the outermost quotes and unescape the rest of the quotes. 59 | """ 60 | return json.loads(self.underlying) 61 | 62 | @property 63 | def value(self): 64 | return self.underlying 65 | 66 | @property 67 | def lark_value(self): 68 | return self.value + "i" 69 | 70 | 71 | @dataclass(frozen=True) 72 | class MacroToken(SCFGToken): 73 | name: str 74 | args: Tuple[SCFGToken, ...] 75 | 76 | @property 77 | def value(self): 78 | return f"{self.name}({','.join([a.value for a in self.args])})" 79 | 80 | 81 | @dataclass(frozen=True) 82 | class EmptyToken(SCFGToken): 83 | @property 84 | def value(self): 85 | return "" 86 | 87 | 88 | @dataclass(frozen=True) 89 | class RegexToken(NonterminalToken): 90 | prefix: str 91 | 92 | def render_matching_value(self, value: str) -> str: 93 | return self.prefix + value 94 | 95 | @property 96 | def value(self) -> str: 97 | return self.underlying 98 | 99 | @property 100 | def lark_value(self) -> str: 101 | if ( 102 | self.prefix 103 | ): # We need to have this condition because lark gets mad if you give it an empty token. 104 | return f'"{self.prefix}" ' + self.underlying 105 | else: 106 | return self.underlying 107 | 108 | @cached_property 109 | def compiled(self) -> "regex.Pattern[str]": 110 | assert self.underlying.startswith("/") and self.underlying.endswith("/") 111 | return regex.compile(self.underlying[1:-1]) 112 | 113 | @cached_property 114 | def ascii_chars(self) -> List[str]: 115 | return [c for c in ASCII_CHARS if self.compiled.match(c)] 116 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/util/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | import contextlib 5 | import io 6 | import pathlib 7 | import sys 8 | import traceback 9 | import warnings 10 | from dataclasses import dataclass 11 | from io import SEEK_SET, TextIOBase 12 | from typing import Iterator, List, Optional, TextIO 13 | 14 | 15 | @dataclass 16 | class Tee(TextIOBase): 17 | """ "A write-only file-like object that forwards writes to `sinks`.""" 18 | 19 | sinks: List[TextIO] 20 | closed: bool = False 21 | 22 | def close(self) -> None: 23 | self.flush() 24 | self.closed = True 25 | 26 | def fileno(self) -> int: 27 | raise OSError 28 | 29 | def flush(self) -> None: 30 | for sink in self.sinks: 31 | sink.flush() 32 | 33 | def isatty(self) -> bool: 34 | return False 35 | 36 | def readable(self) -> bool: 37 | return False 38 | 39 | def readline(self, size=-1) -> str: # type: ignore 40 | raise io.UnsupportedOperation 41 | 42 | def readlines(self, hint=-1) -> List[str]: # type: ignore 43 | raise io.UnsupportedOperation 44 | 45 | def seek(self, offset, whence=SEEK_SET) -> int: 46 | raise io.UnsupportedOperation 47 | 48 | def seekable(self) -> bool: 49 | return False 50 | 51 | def tell(self) -> int: 52 | raise io.UnsupportedOperation 53 | 54 | def truncate(self, size=None): 55 | raise io.UnsupportedOperation 56 | 57 | def writable(self) -> bool: 58 | return True 59 | 60 | def writelines(self, lines: List[str]) -> None: 61 | for sink in self.sinks: 62 | sink.writelines(lines) 63 | 64 | @property 65 | def encoding(self) -> str: 66 | return self.sinks[0].encoding 67 | 68 | @property 69 | def errors(self) -> Optional[str]: 70 | return self.sinks[0].errors 71 | 72 | def detach(self) -> None: 73 | raise io.UnsupportedOperation 74 | 75 | def read(self, size=-1) -> str: 76 | raise io.UnsupportedOperation 77 | 78 | def write(self, s: str) -> int: 79 | results: List[int] = [] 80 | for sink in self.sinks: 81 | results.append(sink.write(s)) 82 | if not all(r == results[0] for r in results[1:]): 83 | warnings.warn("Sinks wrote different number of characters", ResourceWarning) 84 | return results[0] 85 | 86 | 87 | @contextlib.contextmanager 88 | def intercept_output( 89 | stdout_path: pathlib.Path, stderr_path: pathlib.Path 90 | ) -> Iterator[None]: 91 | """Write all stdout and stderr to both the screen and these files.""" 92 | 93 | with open(stdout_path, "a") as stdout_file, open(stderr_path, "a") as stderr_file: 94 | true_stdout = sys.stdout 95 | true_stderr = sys.stderr 96 | sys.stdout = Tee([true_stdout, stdout_file]) # type: ignore 97 | sys.stderr = Tee([true_stdout, stderr_file]) # type: ignore 98 | try: 99 | yield 100 | except: 101 | traceback.print_exc(file=stderr_file) 102 | raise 103 | finally: 104 | sys.stdout = true_stdout 105 | sys.stderr = true_stderr 106 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/sql/cosql/seq2seq.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | """ 5 | ported from 6 | https://github.com/ElementAI/picard/blob/5ddd6cb9f74efca87d4604d5ddddc1b638459466/seq2seq/utils/cosql.py#L10 7 | Under the Apache 2 licence: 8 | https://github.com/ElementAI/picard/blob/main/LICENSE 9 | """ 10 | from typing import Any, Dict, List 11 | 12 | from semantic_parsing_with_constrained_lm.domains.sql.cosql.content_encoder import ( 13 | get_database_matches, 14 | ) 15 | 16 | 17 | def get_input( 18 | utterances: List[str], serialized_schema: str, prefix: str, sep: str = " | " 19 | ) -> str: 20 | # "[prefix] [utterance n] [serialized schema] || [utterance n-1] | [utterance n-2] | ..." 21 | if len(utterances) > 1: 22 | reversed_utterance_head = ( 23 | utterance.strip() for utterance in reversed(utterances[:-1]) 24 | ) 25 | serialized_reversed_utterance_head = " || " + sep.join(reversed_utterance_head) 26 | else: 27 | serialized_reversed_utterance_head = "" 28 | return ( 29 | prefix 30 | + utterances[-1].strip() 31 | + " " 32 | + serialized_schema.strip() 33 | + serialized_reversed_utterance_head 34 | ) 35 | 36 | 37 | def serialize_schema( 38 | question: str, 39 | db_path: str, 40 | db_id: str, 41 | db_column_names: Dict[str, Any], 42 | db_table_names: List[str], 43 | schema_serialization_with_db_content: bool = True, 44 | ) -> str: 45 | # schema_serialization_with_db_id: bool = True 46 | # schema_serialization_randomized: bool = False 47 | # normalize_query: bool = True 48 | # schema_serialization_type = "peteshaw" 49 | # see https://github.com/google-research/language/blob/master/language/nqg/tasks/spider/append_schema.py#L42 50 | db_id_str = " | {db_id}" 51 | table_sep = "" 52 | table_str = " | {table} : {columns}" 53 | column_sep = " , " 54 | column_str_with_values = "{column} ( {values} )" 55 | column_str_without_values = "{column}" 56 | value_sep = " , " 57 | 58 | def get_column_str(table_name: str, column_name: str) -> str: 59 | column_name_str = column_name.lower() 60 | if schema_serialization_with_db_content: 61 | matches = get_database_matches( 62 | question=question, 63 | table_name=table_name, 64 | column_name=column_name, 65 | db_path=(db_path + "/" + db_id + "/" + db_id + ".sqlite"), 66 | ) 67 | if matches: 68 | return column_str_with_values.format( 69 | column=column_name_str, values=value_sep.join(matches) 70 | ) 71 | else: 72 | return column_str_without_values.format(column=column_name_str) 73 | else: 74 | return column_str_without_values.format(column=column_name_str) 75 | 76 | tables = [ 77 | table_str.format( 78 | table=table_name.lower(), 79 | columns=column_sep.join( 80 | map( 81 | lambda y: get_column_str( 82 | table_name=table_name, # pylint: disable=cell-var-from-loop 83 | column_name=y[1], 84 | ), 85 | filter( 86 | lambda y: y[0] 87 | == table_id, # pylint: disable=cell-var-from-loop 88 | zip( 89 | db_column_names["table_id"], db_column_names["column_name"] 90 | ), 91 | ), 92 | ) 93 | ), 94 | ) 95 | for table_id, table_name in enumerate(db_table_names) 96 | ] 97 | return db_id_str.format(db_id=db_id) + table_sep.join(tables) 98 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/sql/cosql/grammar.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | from typing import Iterator, List 5 | 6 | from semantic_parsing_with_constrained_lm.util.types import StrPath 7 | from semantic_parsing_with_constrained_lm.scfg.parser.rule import ( 8 | Rule, 9 | SyncRule, 10 | mirrored_rule, 11 | nonterm, 12 | term, 13 | ) 14 | from semantic_parsing_with_constrained_lm.scfg.read_grammar import PreprocessedGrammar 15 | from semantic_parsing_with_constrained_lm.scfg.scfg import SCFG 16 | from semantic_parsing_with_constrained_lm.domains.sql.cosql.paths import SQL_GRAMMAR_DIR 17 | from semantic_parsing_with_constrained_lm.domains.sql.cosql.schema import DbSchema 18 | 19 | # NTs in the SQL scfg 20 | SCHEMA_NAME = "schema_name" 21 | TABLE_NAME = "table_name" 22 | TABLE_ALIAS = "table_alias" 23 | COLUMN_NAME = "column_name" 24 | POSSIBLY_QUALIFIED = "possibly_qualified_column_name" 25 | DOT = "DOT" 26 | WS_STAR = "ws_star" 27 | SCHEMA_DOT_WS = "schema_dot_ws" 28 | 29 | 30 | def export_sync_rules(db: DbSchema) -> Iterator[SyncRule]: 31 | """ 32 | Exports schema-specific SCFG rules for the NTs 33 | `schema`, `table_name`, `column_name`, and `possibly_qualified_column_name`. 34 | Ensures that `(({s}.)?{t}.)?{c}` is accepted iff `c` is a column in table `t` in schema `s`. 35 | """ 36 | # helpers 37 | def to_lower(lhs: str, rhs: str): 38 | """case-insensitive match, normalize to lowercase""" 39 | return SyncRule( 40 | lhs=lhs, 41 | # case-insensitive 42 | utterance_rhss=( 43 | tuple(nonterm(c.upper()) if c.isalpha() else term(c) for c in rhs), 44 | ), 45 | # return lower case 46 | plan_rhs=(term(rhs.lower()),), 47 | ) 48 | 49 | rhs = (term(db.name),) 50 | yield mirrored_rule(SCHEMA_NAME, rhs) 51 | for table in db.tables: 52 | tbl_name = table.name 53 | my_tbl_name_nt = f"table_called_{tbl_name}" 54 | yield mirrored_rule(TABLE_NAME, (nonterm(my_tbl_name_nt),)) 55 | # case-insensitive, normalize to lowercase 56 | yield to_lower(lhs=my_tbl_name_nt, rhs=tbl_name) 57 | # allow table alias 58 | yield mirrored_rule(my_tbl_name_nt, (nonterm(TABLE_ALIAS),)) 59 | qualifier = f"qualifier_for_{tbl_name}" 60 | rhs3 = ( 61 | nonterm(SCHEMA_DOT_WS, optional=True), 62 | nonterm(my_tbl_name_nt), 63 | nonterm(WS_STAR), 64 | nonterm(DOT), 65 | nonterm(WS_STAR), 66 | ) 67 | yield mirrored_rule(qualifier, rhs3) 68 | col_nt = f"column_for_{tbl_name}" 69 | yield mirrored_rule(COLUMN_NAME, ((nonterm(col_nt)),)) 70 | # ensures that `table.column` is accepted iff column is in table 71 | poss_qual_rhs = (nonterm(qualifier, optional=True), (nonterm(col_nt))) 72 | yield mirrored_rule(POSSIBLY_QUALIFIED, poss_qual_rhs) 73 | for column in table.all_columns(): 74 | col_name = column.name 75 | # case-insensitive, normalize to lowercase 76 | yield to_lower(lhs=col_nt, rhs=col_name) 77 | 78 | # This will allow copying from database values. 
79 | for val in db.values: 80 | t_val = (term(val),) 81 | yield mirrored_rule("non_single_quote_star", t_val) 82 | yield mirrored_rule("non_double_quote_star", t_val) 83 | 84 | 85 | def load_base_grammar(folder_path: StrPath = SQL_GRAMMAR_DIR) -> PreprocessedGrammar: 86 | return PreprocessedGrammar.from_folder(folder_path) 87 | 88 | 89 | def preprocessed_grammar_for_schema( 90 | db: DbSchema, base_grammar: PreprocessedGrammar 91 | ) -> PreprocessedGrammar: 92 | rules: List[Rule] = list(export_sync_rules(db)) 93 | return base_grammar.merge(PreprocessedGrammar.from_rules(rules)) 94 | 95 | 96 | def grammar_for_schema(db: DbSchema, base_grammar: PreprocessedGrammar) -> SCFG: 97 | return SCFG(preprocessed_grammar_for_schema(db, base_grammar)) 98 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/async_tools/batch_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import abc 5 | import dataclasses 6 | from abc import ABC 7 | from dataclasses import dataclass 8 | from typing import Callable, Dict, Generic, List, Tuple, TypeVar, Union 9 | 10 | from semantic_parsing_with_constrained_lm.util.unit import UNIT, Unit 11 | from semantic_parsing_with_constrained_lm.async_tools.limits import TimeoutBarrier 12 | 13 | I = TypeVar("I") 14 | O = TypeVar("O") 15 | 16 | 17 | class BatchMaker(Generic[I, O], ABC): 18 | """Identifies which elements could be batched together, and specifies how to do the batched operation. 19 | 20 | This should also derive from collections.abc.Hashable, but then it is not 21 | possible to use a dataclass to create a BatchMaker. 22 | """ 23 | 24 | @property 25 | @abc.abstractmethod 26 | def max_batch_size(self) -> int: 27 | """Maximum number of elements to batch together.""" 28 | pass 29 | 30 | @property 31 | @abc.abstractmethod 32 | def timeout(self) -> float: 33 | """Time to wait before running `execute`.""" 34 | pass 35 | 36 | @abc.abstractmethod 37 | async def execute(self, inputs: List[I]) -> O: 38 | """Batched operation on inputs.""" 39 | pass 40 | 41 | 42 | @dataclass 43 | class PendingContainer(Generic[I, O]): 44 | """Used inside BatchingHelper.""" 45 | 46 | batch_key: BatchMaker[I, O] 47 | batching_helper: "BatchingHelper[I, O]" 48 | 49 | inputs: List[I] = dataclasses.field(default_factory=list) 50 | barrier: TimeoutBarrier = dataclasses.field(init=False) 51 | result: Union[O, Unit] = UNIT 52 | 53 | def __post_init__(self): 54 | self.barrier = TimeoutBarrier( 55 | self.batch_key.max_batch_size, self.batch_key.timeout, self._execute 56 | ) 57 | 58 | async def enqueue_and_wait(self, inp: I) -> Tuple[O, int]: 59 | i = len(self.inputs) 60 | self.inputs.append(inp) 61 | await self.barrier.arrive_and_wait() 62 | assert not isinstance(self.result, Unit) 63 | return self.result, i 64 | 65 | @property 66 | def closed(self) -> bool: 67 | return bool(self.barrier.currently_releasing or self.result is not UNIT) 68 | 69 | async def _execute(self) -> None: 70 | self.result = await self.batch_key.execute(self.inputs) 71 | self.batching_helper._del_pending_container( # pylint: disable=protected-access 72 | self 73 | ) 74 | 75 | 76 | @dataclass 77 | class BatchingHelper(Generic[I, O]): 78 | """Helper for running functions on batched inputs.""" 79 | 80 | # Creates a BatchMaker from the input. Inputs with BatchMakers that compare equal are eligible for batching together. 
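    # Illustrative example (an assumption, not code from this repo): a BatchMaker
    # for a seq2seq language model could compare equal whenever two requests share
    # the same model and decoding settings, so that those requests are merged into
    # a single batched forward pass by `execute` below.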
input_to_batch_maker: Callable[[I], BatchMaker[I, O]] 82 | 83 | # Pending operations per BatchMaker. 84 | pending: Dict[BatchMaker[I, O], PendingContainer[I, O]] = dataclasses.field( 85 | default_factory=dict 86 | ) 87 | 88 | async def execute(self, inp: I) -> Tuple[O, int]: 89 | """Given an input of type I, this class uses `input_to_batch_maker` to create a BatchMaker. 90 | Inputs whose BatchMakers compare equal are coalesced together. 91 | After a certain number of inputs are collected, or a timeout passes, 92 | we run BatchMaker.execute to produce a batched output of type O. 93 | 94 | Returns the output O with a batch index to locate the result for the input within O.""" 95 | 96 | batch_maker = self.input_to_batch_maker(inp) # type: ignore[call-arg] 97 | pending_container = self.pending.get(batch_maker) 98 | 99 | if pending_container is None or pending_container.closed: 100 | self.pending[batch_maker] = pending_container = PendingContainer( 101 | batch_maker, self 102 | ) 103 | 104 | return await pending_container.enqueue_and_wait(inp) 105 | 106 | def _del_pending_container(self, container: PendingContainer[I, O]): 107 | """When a PendingContainer is done executing, immediately remove it from `self.pending`""" 108 | if self.pending.get(container.batch_key) is container: 109 | del self.pending[container.batch_key] 110 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: default 3 | --- 4 | 5 | ## Intro 6 | 7 | To make robots accessible to a broad audience, it is critical to endow them with the ability to take universal modes of communication, like commands given in natural language, and extract a concrete desired task specification, defined using a formal language like linear temporal logic (LTL). 8 | 9 | In this paper, we present a learning-based approach to translate from natural language commands to LTL specifications with very limited human-labeled training data by leveraging Large Language Models (LLMs). Our model can translate natural language commands at 75% accuracy with only about 12 annotations, and achieves state-of-the-art performance when given full training data. We also show how its outputs can be used to plan long-horizon, multi-stage tasks on a 12D quadrotor. 10 | 11 | 12 | 13 | ## Approach 14 | 15 | ![Semantic Parsing](https://i.imgur.com/EZEtGOZ.jpg) 16 | 17 | Given a predefined set of possible LTL formulas and atomic 18 | propositions, and up to one natural language annotation for 19 | each formula, we first translate these pre-defined formulas to 20 | (structured) English and then use the paraphrasing abilities of modern LLMs to synthesize a large corpus of diverse natural language commands. 21 | 22 | Given this corpus, we use the data to fine-tune an LLM. Here, we explore two variants, where for training labels we use 1) raw LTL formulas, or 2) a canonical form of the LTL formulas (an intermediate representation that is more similar to English). 23 | At evaluation time, we enforce that the LLM’s output is syntactically correct via constrained decoding. 24 | 25 | ## Results 26 | 27 | We evaluate our methods on three datasets, each associated with a different task and environment. We show that our model can translate natural language commands at 75% accuracy with about 12 annotations, and consistently obtains state-of-the-art performance when given full training data.
In the paper, we also present comprehensive ablation studies that validate the necessity of each design decision we make. 28 | 29 | ### Results in low-data regimes 30 | 31 | In this setup, models are trained with only 12 human annotations at most, and evaluated on the entire original dataset. Our model significantly advances the state-of-the-art in this regime, achieving 75% average accuracy. 32 | 33 | | Model architecture | Training data | Test data | Drone Dataset | Cleanup Dataset | Pick Dataset | 34 | | ------------------------ | ------------- | ----------- | ------------- | --------------- | ------------ | 35 | | RNN | synthetic | full golden | 22.41 | 52.54 | 32.39 | 36 | | CopyNet | synthetic | full golden | 36.41 | 53.40 | 40.36 | 37 | | BART-FT-Raw (ours) | synthetic | full golden | **69.39** | **78.00** | **81.45** | 38 | | BART-FT-Canonical (ours) | synthetic | full golden | 68.38 | 77.90 | 78.23 | 39 | 40 | ### Results in standard data regimes 41 | 42 | In this setup, we follow the settings of previous works, where models are evaluated by five-fold cross-validation on the entire dataset. Our model consistently outperforms the state-of-the-art in this regime, with about 1% accuracy improvement on average. 43 | 44 | | Model architecture | Training data | Test data | Drone Dataset | Cleanup Dataset | Pick Dataset | 45 | | ------------------------ | ------------- | ---------- | ------------- | --------------- | ------------ | 46 | | RNN | 4/5 golden | 1/5 golden | 87.18 | 95.51 | 93.78 | 47 | | CopyNet | 4/5 golden | 1/5 golden | 88.97 | 95.47 | 93.14 | 48 | | BART-FT-Raw (ours) | 4/5 golden | 1/5 golden | 90.78 | **97.84** | **95.97** | 49 | | BART-FT-Canonical (ours) | 4/5 golden | 1/5 golden | **90.86** | 97.81 | 95.70 | 50 | 51 | ### Demo 52 | 53 | ![Demo](https://i.imgur.com/ynNTrpf.png) 54 | 55 | Finally, we also show how its outputs can be used to plan long-horizon, multi-stage tasks on a 12D quadrotor in simulation. 
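As a concrete example of the translation task, the command "Look for and pick up any non green cubes and put them in crate." from the pick-and-place data in this repository is paired with the LTL formula `globally ( and ( until ( scan , not ( any non green cubes ) ) , finally ( any non green cubes ) ) )`.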
56 | 57 | ## Cite 58 | 59 | If you find this work useful, please cite our paper: 60 | 61 | ``` 62 | TO BE RELEASED 63 | ``` 64 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/calflow/data/ids_dev_100_uniform.txt: -------------------------------------------------------------------------------- 1 | a4363ada-270f-453a-92c3-72dd95f6b8c4,1 2 | c244d5d1-9728-447c-8cb1-875fc92a6bc1,3 3 | 20e6b0e7-f1e6-46e4-bb4a-d60bd9e11100,2 4 | 775152fc-214c-4528-af2d-e70ab7d52304,0 5 | fc861e3d-9804-4942-91fd-5b292cf9aff7,1 6 | 81cf24a3-88ed-4545-b3ee-53168f2d2caf,3 7 | 0f29448a-3bd2-4271-99ce-d44428b2d638,0 8 | 0dc2c70e-35e8-4012-a170-196a6615aa70,2 9 | d746be41-3bb5-497b-a9b0-b264be569990,0 10 | 3aaa3b10-e339-4b31-b1fb-afcc42f8770f,0 11 | 67e98197-84a8-4434-84d6-28fc38a917cc,6 12 | b93552d6-b070-4e0a-9e9f-49681dc2c594,4 13 | 0dc438b5-b8be-4502-bc90-0f732589edfb,2 14 | cb13f9a2-e233-4711-8895-fb5620785880,1 15 | fcca3bce-e680-4685-af40-d979661842d8,0 16 | 9ed7b59f-f57d-4c71-8010-9b96d1aa993a,1 17 | 0771eed4-3107-4485-be63-79247f3f984f,4 18 | c4e474fb-4f6c-4182-b688-604f6d1a9fdf,2 19 | 15b0f0e5-023c-4a2d-9f66-4f47c9b95e51,4 20 | ce50f9e6-d765-438a-b4d7-bc940facd561,2 21 | df190a44-c2de-4e66-a1a4-b12c092d7e86,1 22 | 3f4520ea-ccbc-42a3-9fc2-25d948ee680d,0 23 | a1cdff70-2b42-48f5-9971-410e9888796a,3 24 | 24139643-6c8d-4dcb-827a-542ce25d7cdb,3 25 | 6c4d9f4f-79c2-4436-9de0-bcfeb3a81b4f,1 26 | 8d97df11-7b3f-46a9-b05f-25d187638537,0 27 | 822a7a20-b65c-4225-b460-528f18468e00,4 28 | 9b19eb00-516d-4df6-9be1-d281aa8628c6,4 29 | 2a20ce5f-21dd-424a-9191-79e4f077f59c,3 30 | 65e11734-3e24-437a-9aba-ac59c0971d6f,3 31 | 6eae204f-796f-4d6d-a32b-da0e894c8b74,3 32 | df6d7f07-d714-4d7f-86b8-b4e18ed09c92,3 33 | 463b48d9-5bde-4609-8775-8fc1b28b61de,1 34 | fceb4fa5-a82f-4683-a017-08abaf715294,0 35 | 8c51ae7b-4516-458c-93b6-ce39d049101e,1 36 | fee73a1c-c636-4064-aad1-e876a252bf80,3 37 | b71de4dd-38e2-48ec-a677-97cf7333d8cb,1 38 | 0b712121-8c14-4f5a-97df-51ad6db9bc82,0 39 | 65022059-0d96-4d78-b27c-5211e8183d52,2 40 | 0c219822-49b8-4c4c-b28c-be7a03f5b913,1 41 | ec9ab74c-a3df-456b-bcd8-09a35319581f,1 42 | b57c2ff5-3c6a-49c1-bf51-74e046002b05,1 43 | 71df0bc7-60d0-4fd3-9c18-3859eccb15f3,2 44 | c747df82-0e9a-4ba1-8220-4f7678ca7a55,1 45 | f20d19c2-5a46-477b-a373-0fcabab2dbda,4 46 | ca77251b-093b-48ef-85d4-e337d8a6336e,5 47 | 4c6da53c-baa1-4b95-b679-0cd1c1c190e5,2 48 | 6aec2974-84b3-4800-a1f7-01b504b7b20f,1 49 | a26af847-4626-4102-aeb4-6bc0dce34d89,0 50 | 99eea3c6-d9d0-4856-8393-ee33bb8bb489,6 51 | c6c9f5d3-b799-4cd2-84ab-57eaaca3987f,4 52 | 71d87d1f-e9d2-4bd1-b8d2-4b071d794f2e,6 53 | d69c16d4-385b-4823-8c81-79f560bac3cc,2 54 | 50baad9f-2fde-4e30-9f91-40819246995d,1 55 | dd54981c-5c4e-4c8e-9feb-85ab880bb879,1 56 | 1fd83edb-2cf2-48c1-9efe-b15fc01e4138,3 57 | 7d1cbec1-d6d2-4f0b-8734-527192f1d509,1 58 | 402cc271-4ae3-4b02-858a-0d2363ffedf0,4 59 | 9260a5e9-d72d-4131-8bac-c5ba518f83a7,3 60 | 0ba7de27-a741-443f-947f-51ffc7a5662c,2 61 | 98f2551d-a382-46c6-9c82-3d466d299d7d,0 62 | 527d5ff9-2e02-4dae-9484-b82844e87f1c,1 63 | 00f61607-9691-43d6-ab00-ba3e09e2195e,2 64 | 28e9f1f2-8dd2-416c-a0f1-101f290e5823,2 65 | 31dd1a17-7f18-445b-9e59-6b7c052adaa6,0 66 | 746c622d-9bbd-449a-bf45-d7d78dda32e8,4 67 | 80f7625b-cc09-4666-9607-1f27127adf97,2 68 | 7da05ea0-9959-49e7-b0c4-a771e64eeaed,2 69 | fadc029a-082c-4bd1-aab2-ea488afa8cc9,0 70 | 69168e02-76e1-4dc1-9410-d5d3e34eadef,3 71 | c296c911-2589-4982-8863-e06909a01080,5 72 | ebffca74-0373-421a-a422-f58406243f1e,3 73 | 
9b19eb00-516d-4df6-9be1-d281aa8628c6,1 74 | 1ef4362f-0bcd-4209-b610-0864a24d2fdf,1 75 | 8613da88-0a12-4855-a420-4eea91fe6a32,3 76 | e4894296-8e2e-4759-ae55-9dc6c1339777,2 77 | de9aed72-9598-46d0-b24a-50d2e7f49440,2 78 | 3ac07aec-8c21-46c6-817d-1794b082a62f,2 79 | 3fdf3aad-a740-40c1-938d-95a449126ae5,2 80 | d14c5147-b7c9-4edb-89fb-79632ebc3ef9,12 81 | 31811a76-59ea-494f-8b8d-ec477f7bebca,2 82 | e8be360b-fd9d-4504-96c3-526a02bbe46d,3 83 | b4b34e07-f1c8-4c89-b051-e525ec79c8d1,12 84 | cbfd6f78-2f13-41f2-9a3e-3d7844fbee57,3 85 | 1ad516fd-f452-4af0-b877-b5d4f6b47806,0 86 | 2a4c7de3-a990-43da-93a6-851370ed1c3f,2 87 | 3745161d-c0d1-41a0-9baa-6bb77be2b876,3 88 | caf31ca2-8e7b-40d8-91b9-a8e2a4fe6503,3 89 | 2f56fe44-b08d-4fad-9d06-e027eb237ab0,1 90 | 5ce97ea6-a008-4e8b-97ba-164fdbdf4a8c,2 91 | 54fc80f9-efb9-4266-9cfa-7c7d0660f672,6 92 | 26841c21-d3ed-4541-aa59-907e93cc38d2,22 93 | 45d5c757-2bc3-40be-8c5b-f2ebc3cd9572,1 94 | cd010b9d-266f-4df2-aa50-65eace32d430,3 95 | 1e4a2550-e308-41ac-a247-7299e6e2db31,6 96 | 8e52d0e3-d751-4a02-9f68-45be7161be19,4 97 | d010d149-5b92-4a3f-ab20-ec9ab8bd3391,7 98 | 78437d70-5ea5-4df8-a95b-aa46733519d2,0 99 | bbc39510-f8e5-4fe2-b65b-5dc4c5fe5fd7,1 100 | 7bdc0ac1-0248-4841-8465-5e1cd44860fc,17 101 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Data-Efficient Learning of Natural Language to Linear Temporal Logic Translators for Robot Task Specification 2 | 3 | ![Demo](https://i.imgur.com/ynNTrpf.png) 4 | 5 | [[Homepage](https://um-arm-lab.github.io/Efficient-Eng-2-LTL/)] [[Paper](https://arxiv.org/abs/2303.08006)] [[Video](https://drive.google.com/file/d/14Sy5y76YglZ6X3Y3ZZBZZiMGBA9gME9G/view?usp=sharing)] [[Poster](https://drive.google.com/file/d/1j0aZoROb1EKC0oRYYBSwBIx4Xp8ElowN/view?usp=sharing)] 6 | 7 | > The associated repo for paper "Data-Efficient Learning of Natural Language to Linear Temporal Logic Translators for Robot Task Specification". 8 | 9 | ## Repo Structure 10 | 11 | - Root 12 | - datasets 13 | - [drone-planning](https://arxiv.org/abs/1905.12096) 14 | - [clean-up](http://www.roboticsproceedings.org/rss14/p67.html) 15 | - [pick-and-place](http://www.roboticsproceedings.org/rss14/p67.html) 16 | - augmentation 17 | - paraphrase with GPT-3 18 | - run 19 | - train the models 20 | - inference with constrained decoding 21 | 22 | The constrained decoding inference code is based on: [microsoft/semantic_parsing_with_constrained_lm](https://github.com/microsoft/semantic_parsing_with_constrained_lm) 23 | 24 | ## Reproduce the Results 25 | 26 | Following are the instructions for reproducing the result for our model. For baselines, come and check out [this link](https://github.com/UM-ARM-Lab/Efficient-Eng-2-LTL/issues/1). 
27 | ### Environment Setup
28 | 
29 | Install the dependencies:
30 | 
31 | ```bash
32 | pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 # make sure the version is compatible with your CUDA version
33 | pip install transformers datasets
34 | pip install sentencepiece
35 | pip install jsons appdirs blobfile cached-property httpx typer whoosh more_itertools
36 | pip install --upgrade protobuf==3.20.0
37 | ```
38 | 
39 | Download the BART-large model:
40 | 
41 | ```bash
42 | cd ./run
43 | python ./semantic_parsing_with_constrained_lm/finetune/download_huggingface_lms.py
44 | ```
45 | 
46 | ### Prepare the Dataset
47 | 
48 | The processed dataset (with LLM-based augmentation) is already included in the repo, so this step is only needed if you want to reprocess the dataset.
49 | 
50 | To process the raw dataset from scratch, follow the steps below:
51 | 
52 | 1. Pre-process: In each of the three dataset folders, run all cells in "preprocess.ipynb" to generate the processed dataset (the annotation results are included in the notebook).
53 | 2. Augmentation: For each of the three datasets, run all commands in "augment.ipynb" to generate the augmented dataset. Note that this step requires a GPT-3 API key.
54 | 3. Move to training folder: Reformat the dataset and move it to the `run/semantic_parsing_with_constrained_lm/domains/ltl/data` folder (see the data-format sketch at the end of this section). A script will be provided later to automate this process.
55 | 
56 | ### Train
57 | 
58 | In our paper, we use the [BART-large model](https://huggingface.co/facebook/bart-large) because it is efficient to fine-tune on a single GPU. Our proposed method can be easily applied to other, potentially stronger language models such as [T5-XXL](https://arxiv.org/abs/1910.10683) or [GPT-3](https://arxiv.org/abs/2005.14165).
59 | 
60 | ```sh
61 | export PRETRAINED_MODEL_DIR=huggingface_models/bart-large
62 | export TRAINED_MODEL_DIR=trained_models/
63 | 
64 | cd ./run
65 | DOMAIN=TODO # for example, DOMAIN=pick-syn-aug
66 | python -m semantic_parsing_with_constrained_lm.finetune.lm_finetune \
67 |     --config-name semantic_parsing_with_constrained_lm.finetune.configs.emnlp_train_config \
68 |     --exp-names ltl_${DOMAIN}_utterance
69 | ```
70 | 
71 | Here, DOMAIN determines which experiment to run.
72 | DOMAIN has the form {dataset_name}-{experiment_name}:
73 | 
74 | - dataset_name: {drone, cleanup, pick}
75 | - experiment_name:
76 |   - syn-aug: synthetic with augmentation
77 |   - syn: synthetic without augmentation
78 |   - golden-cross0-split{0,1,2,3,4}: golden dataset with cross-validation
79 | 
80 | ### Inference
81 | 
82 | ```sh
83 | export PRETRAINED_MODEL_DIR=huggingface_models/bart-large
84 | export TRAINED_MODEL_DIR=trained_models/
85 | 
86 | DOMAIN=TODO
87 | 
88 | python -m semantic_parsing_with_constrained_lm.run_exp \
89 |     --config-name semantic_parsing_with_constrained_lm.configs.ltl_config \
90 |     --log-dir logs/ \
91 |     --model Bart \
92 |     --eval-split test-full \
93 |     --exp-names "ltl_Bart_test-full_${DOMAIN}_constrained_utterance_train-0"
94 | ```
95 | 
96 | The DOMAIN name is the same as in the training step. 
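
### Data Format (Reference)

The LTL data loader in `run/semantic_parsing_with_constrained_lm/domains/ltl/__init__.py` expects, for each DOMAIN, a `{DOMAIN}.canonical.json` file (a JSON object whose keys are the allowed canonical outputs) plus `{DOMAIN}.train.jsonl` and `{DOMAIN}.test.jsonl` files with one JSON object per line containing `natural` and `canonical` fields. The snippet below is a minimal sketch of the reformatting step in step 3 above; the input file name `augmented.jsonl` and its field names are assumptions about the notebook output, not part of the released pipeline.

```python
# Minimal reformatting sketch (hypothetical helper, not shipped with this repo).
# Assumes the augmentation notebook wrote `augmented.jsonl` with
# "natural" and "canonical" fields; adjust paths and DOMAIN to your setup.
import json
from pathlib import Path

SRC = Path("datasets/pick-and-place/augmented.jsonl")  # assumed notebook output
DST = Path("run/semantic_parsing_with_constrained_lm/domains/ltl/data")
DOMAIN = "pick-syn-aug"

examples = [json.loads(line) for line in SRC.read_text().splitlines() if line.strip()]

# {DOMAIN}.train.jsonl: one JSON object per line with "natural" and "canonical".
# (A {DOMAIN}.test.jsonl with the same format is needed for evaluation.)
with open(DST / f"{DOMAIN}.train.jsonl", "w") as f:
    for ex in examples:
        f.write(json.dumps({"natural": ex["natural"], "canonical": ex["canonical"]}) + "\n")

# {DOMAIN}.canonical.json: a JSON object whose keys are all allowed canonical
# outputs (its values are not used when training on utterances).
with open(DST / f"{DOMAIN}.canonical.json", "w") as f:
    json.dump({ex["canonical"]: "" for ex in examples}, f, indent=2)
```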
97 | 98 | ## Cite 99 | ```bibtex 100 | @article{pan2023data, 101 | title={Data-Efficient Learning of Natural Language to Linear Temporal Logic Translators for Robot Task Specification}, 102 | author={Pan, Jiayi and Chou, Glen and Berenson, Dmitry}, 103 | journal={arXiv preprint arXiv:2303.08006}, 104 | year={2023} 105 | } 106 | ``` 107 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/index/bm25_index.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import random 5 | import tempfile 6 | from typing import Callable, Generic, Iterable, List, Sequence, Tuple 7 | 8 | import whoosh.index 9 | from whoosh.fields import STORED, TEXT, SchemaClass 10 | from whoosh.qparser import OrGroup, QueryParser 11 | 12 | from semantic_parsing_with_constrained_lm.datum import DatumSub, FullDatumSub 13 | from semantic_parsing_with_constrained_lm.index import Candidate, DynamicIndex, Query 14 | from semantic_parsing_with_constrained_lm.model import DataRetriever 15 | 16 | 17 | class PromptSchema(SchemaClass): 18 | text = TEXT() 19 | key = STORED() 20 | 21 | 22 | class BM25Index(Generic[Query, Candidate], DynamicIndex[int, Query, Candidate]): 23 | def __init__( 24 | self, get_content: Callable[[Candidate], str], get_query: Callable[[Query], str] 25 | ): 26 | # TODO: custom tokenizer from ClampTokenizer 27 | # TODO: now indexed in a temp dir 28 | tmp_indexer_loc = tempfile.mkdtemp() 29 | self.index = whoosh.index.create_in(tmp_indexer_loc, schema=PromptSchema) 30 | self.get_content = get_content 31 | self.get_query = get_query 32 | 33 | @classmethod 34 | def create( 35 | cls, 36 | candidates: Iterable[Candidate], 37 | get_content: Callable[[Candidate], str], 38 | get_query: Callable[[Query], str], 39 | ) -> "BM25Index": 40 | index = BM25Index(get_content, get_query) 41 | with index.index.writer() as writer: 42 | for i, candidate in enumerate(candidates): 43 | writer.add_document(text=get_content(candidate), key=i) 44 | return index 45 | 46 | def add(self, candidates: Iterable[Candidate]): 47 | n = self.index.doc_count() 48 | with self.index.writer() as writer: 49 | for i, candidate in enumerate(candidates): 50 | writer.add_document( 51 | text=self.get_content(candidate), key=n + i 52 | ) # auto-increment key 53 | 54 | def search(self, query: Query, top_k: int = 10) -> List[Tuple[int, float]]: 55 | with self.index.searcher() as searcher: 56 | query_parser = QueryParser("text", schema=searcher.schema, group=OrGroup) 57 | q = query_parser.parse(self.get_query(query)) 58 | results = searcher.search(q, limit=top_k) 59 | return [(result["key"], result.score) for result in results] 60 | 61 | 62 | class BM25Retriever(DataRetriever[FullDatumSub, DatumSub]): 63 | def __init__( 64 | self, 65 | train_data: Sequence[FullDatumSub], 66 | top_k: int = 20, 67 | best_first: bool = True, 68 | seed: int = 12345, 69 | ): 70 | self.index: BM25Index[DatumSub, FullDatumSub] = BM25Index.create( 71 | train_data, 72 | get_content=lambda c: c.natural, # type: ignore 73 | get_query=lambda q: q.natural, # type: ignore 74 | ) 75 | self.data: List[FullDatumSub] = list(train_data) 76 | self.top_k = top_k 77 | self.best_first = best_first 78 | self.prng = random.Random( 79 | seed 80 | ) # a random number generator to ensure deterministic behavior 81 | 82 | def augment_with_random_samples( 83 | self, data: Sequence[FullDatumSub], retrieved_keys: 
Sequence[int] 84 | ) -> Sequence[FullDatumSub]: 85 | 86 | if len(retrieved_keys) < self.top_k: 87 | print( 88 | f"Could not retrieve {self.top_k} examples, got only {len(retrieved_keys)}" 89 | ) 90 | keys_to_sample = sorted(set(range(len(data))).difference(retrieved_keys)) 91 | sampled_keys = self.prng.sample( 92 | keys_to_sample, 93 | k=min(self.top_k - len(retrieved_keys), len(keys_to_sample)), 94 | ) 95 | augmented_keys = list(retrieved_keys) + list(sampled_keys) 96 | print(f"Added samples to make it of size {len(augmented_keys)}") 97 | items = [data[k] for k in augmented_keys[: self.top_k]] 98 | else: 99 | items = [data[k] for k in retrieved_keys[: self.top_k]] 100 | 101 | return items if self.best_first else list(reversed(items)) 102 | 103 | def add(self, data: Sequence[FullDatumSub]): 104 | self.index.add(data) 105 | self.data.extend(data) 106 | 107 | async def __call__(self, test_datum: DatumSub) -> Sequence[FullDatumSub]: 108 | results = self.augment_with_random_samples( 109 | data=self.data, 110 | retrieved_keys=[ 111 | key for key, _ in self.index.search(test_datum, top_k=self.top_k) 112 | ], # score discarded 113 | ) 114 | return results 115 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/configs/lib/benchclamp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import os 5 | from pathlib import Path 6 | 7 | from semantic_parsing_with_constrained_lm.earley.cfg import load_grammar_from_directory 8 | from semantic_parsing_with_constrained_lm.datum import DatumSub 9 | from semantic_parsing_with_constrained_lm.decoding.earley_partial_parse import ( 10 | GrammarTokenizerInfo, 11 | UTF8EarleyPartialParse, 12 | ) 13 | from semantic_parsing_with_constrained_lm.decoding.partial_parse import StartsWithSpacePartialParse 14 | from semantic_parsing_with_constrained_lm.domains.benchclamp_data_setup import ( 15 | BenchClampDataset, 16 | BenchClampDatasetConfig, 17 | ) 18 | from semantic_parsing_with_constrained_lm.domains.lispress_v2.grammar import ( 19 | create_partial_parse_builder as create_partial_parse_builder_lispress_v2, 20 | ) 21 | from semantic_parsing_with_constrained_lm.domains.mtop.grammar import ( 22 | create_partial_parse_builder as create_partial_parse_builder_mtop, 23 | ) 24 | from semantic_parsing_with_constrained_lm.domains.sql.cosql.grammar import ( 25 | load_base_grammar, 26 | preprocessed_grammar_for_schema, 27 | ) 28 | from semantic_parsing_with_constrained_lm.domains.sql.cosql.schema import load_schemas 29 | from semantic_parsing_with_constrained_lm.model import PartialParseBuilder 30 | from semantic_parsing_with_constrained_lm.paths import BENCH_CLAMP_GRAMMAR_DATA_DIR_AZURE 31 | from semantic_parsing_with_constrained_lm.tokenization import ClampTokenizer 32 | 33 | TEST_SUITE_PATH = Path("/mnt/my_input/test-suite-sql-eval") 34 | TEST_SUITE_DATABASE_PATH = Path("/mnt/my_input/test-suite-sql-eval/database/") 35 | SPIDER_DATABASE_PATH = Path("/mnt/my_input/Spider/database/") 36 | SPIDER_TABLES_FILE = Path("/mnt/my_input/Spider/tables.json") 37 | COSQL_DATABASE_PATH = Path("/mnt/my_input/CoSQL/database/") 38 | COSQL_TABLES_FILE = Path("/mnt/my_input/CoSQL/tables.json") 39 | 40 | 41 | def create_partial_parse_builder( 42 | constrained: bool, data_config: BenchClampDatasetConfig, tokenizer: ClampTokenizer 43 | ) -> PartialParseBuilder[DatumSub]: 44 | if constrained: 45 | domain_str 
= data_config.domain if data_config.domain is not None else "" 46 | if data_config.dataset_name in [ 47 | BenchClampDataset.Spider.value, 48 | BenchClampDataset.CoSQL.value, 49 | ]: 50 | print("Loading database schemas ...") 51 | if data_config.dataset_name == BenchClampDataset.Spider.value: 52 | schemas = load_schemas( 53 | schemas_path=SPIDER_TABLES_FILE, 54 | db_path=SPIDER_DATABASE_PATH, 55 | ) 56 | else: 57 | schemas = load_schemas( 58 | schemas_path=COSQL_TABLES_FILE, db_path=COSQL_DATABASE_PATH 59 | ) 60 | print("Done") 61 | 62 | base_grammar = load_base_grammar() 63 | pre_grammars = { 64 | name: preprocessed_grammar_for_schema(db, base_grammar) 65 | for name, db in schemas.items() 66 | } 67 | grammar_tok_info = { 68 | name: GrammarTokenizerInfo.create(tokenizer, preprocessed_grammar, True) 69 | for name, preprocessed_grammar in pre_grammars.items() 70 | } 71 | partial_parse_builder = lambda datum: UTF8EarleyPartialParse.initial( 72 | grammar_tok_info[datum.schema_name], datum.natural # type: ignore 73 | ) 74 | 75 | elif data_config.dataset_name in ( 76 | BenchClampDataset.CalFlowV2.value, 77 | BenchClampDataset.TreeDST.value, 78 | ): 79 | partial_parse_builder = create_partial_parse_builder_lispress_v2( 80 | load_grammar_from_directory( 81 | os.path.join( 82 | BENCH_CLAMP_GRAMMAR_DATA_DIR_AZURE, 83 | data_config.dataset_name, 84 | domain_str, 85 | ) 86 | ), 87 | tokenizer, 88 | ) 89 | elif data_config.dataset_name == BenchClampDataset.MTOP.value: 90 | partial_parse_builder = create_partial_parse_builder_mtop( 91 | load_grammar_from_directory( 92 | os.path.join( 93 | BENCH_CLAMP_GRAMMAR_DATA_DIR_AZURE, 94 | data_config.dataset_name, 95 | domain_str, 96 | ) 97 | ), 98 | tokenizer, 99 | ) 100 | else: 101 | raise ValueError(f"{data_config.dataset_name} not supported") 102 | else: 103 | partial_parse = StartsWithSpacePartialParse(tokenizer) 104 | partial_parse_builder = lambda _: partial_parse 105 | return partial_parse_builder 106 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/train_model_setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | import abc 5 | from abc import abstractmethod 6 | from dataclasses import dataclass 7 | from pathlib import Path 8 | from typing import Dict, List, Optional, Tuple 9 | 10 | import torch 11 | from transformers import ( 12 | BartForConditionalGeneration, 13 | GPT2LMHeadModel, 14 | PreTrainedModel, 15 | T5ForConditionalGeneration, 16 | ) 17 | 18 | from semantic_parsing_with_constrained_lm.lm import Seq2SeqSettings, Surround 19 | from semantic_parsing_with_constrained_lm.tokenization import ( 20 | ClampTokenizer, 21 | GPT2ClampTokenizer, 22 | T5ClampTokenizer, 23 | ) 24 | 25 | 26 | class TrainedModelNotFoundError(FileNotFoundError): 27 | pass 28 | 29 | 30 | @dataclass # type: ignore 31 | class ClampModelConfig(abc.ABC): 32 | model_id: str 33 | model_loc: Path 34 | device_map: Optional[Dict[int, List[int]]] = None 35 | 36 | @abstractmethod 37 | def setup_model(self) -> Tuple[PreTrainedModel, ClampTokenizer, Seq2SeqSettings]: 38 | pass 39 | 40 | def maybe_parallelize(self, model: PreTrainedModel) -> None: 41 | if torch.cuda.is_available(): 42 | if self.device_map is not None: 43 | print(f"Parallelizing model with {self.device_map}") 44 | model.parallelize(self.device_map) 45 | else: 46 | print("Entire model to GPU 0") 47 | model.to(torch.device("cuda:0")) 48 | else: 49 | model.to(torch.device("cpu")) 50 | 51 | 52 | class BartModelConfig(ClampModelConfig): 53 | def setup_model(self) -> Tuple[PreTrainedModel, ClampTokenizer, Seq2SeqSettings]: 54 | if not self.model_loc.exists(): 55 | raise TrainedModelNotFoundError( 56 | f"Model files not found in {self.model_loc}" 57 | ) 58 | model = BartForConditionalGeneration.from_pretrained(self.model_loc) 59 | tokenizer = GPT2ClampTokenizer.from_pretrained(str(self.model_loc)) 60 | seq2seq_settings = Seq2SeqSettings( 61 | input_surround=Surround(bos=[0], eos=[2], starts_with_space=True), 62 | output_surround=Surround(bos=[0], eos=[2], starts_with_space=True), 63 | decoder_start_token_id=2, 64 | ) 65 | self.maybe_parallelize(model) 66 | model.eval() 67 | return model, tokenizer, seq2seq_settings 68 | 69 | 70 | class T5ModelConfig(ClampModelConfig): 71 | def setup_model(self) -> Tuple[PreTrainedModel, ClampTokenizer, Seq2SeqSettings]: 72 | if not self.model_loc.exists(): 73 | raise TrainedModelNotFoundError( 74 | f"Model files not found in {self.model_loc}" 75 | ) 76 | print(f"Loading model from {self.model_loc}") 77 | model = T5ForConditionalGeneration.from_pretrained(self.model_loc) 78 | print("Done") 79 | tokenizer = T5ClampTokenizer.from_pretrained(str(self.model_loc)) 80 | seq2seq_settings = Seq2SeqSettings( 81 | input_surround=Surround(bos=[], eos=[1], starts_with_space=True), 82 | output_surround=Surround(bos=[], eos=[1], starts_with_space=True), 83 | decoder_start_token_id=tokenizer.pad_token_id, 84 | ) 85 | self.maybe_parallelize(model) 86 | model.eval() 87 | return model, tokenizer, seq2seq_settings 88 | 89 | 90 | class CodeT5ModelConfig(ClampModelConfig): 91 | def setup_model(self) -> Tuple[PreTrainedModel, ClampTokenizer, Seq2SeqSettings]: 92 | if not self.model_loc.exists(): 93 | raise TrainedModelNotFoundError( 94 | f"Model files not found in {self.model_loc}" 95 | ) 96 | model = T5ForConditionalGeneration.from_pretrained(self.model_loc) 97 | tokenizer = GPT2ClampTokenizer.from_pretrained(str(self.model_loc)) 98 | seq2seq_settings = Seq2SeqSettings( 99 | input_surround=Surround(bos=[1], eos=[2], starts_with_space=True), 100 | output_surround=Surround(bos=[1], eos=[2], starts_with_space=True), 101 | decoder_start_token_id=0, 102 | ) 103 | 
self.maybe_parallelize(model) 104 | model.eval() 105 | return model, tokenizer, seq2seq_settings 106 | 107 | 108 | class GPT2ModelConfig(ClampModelConfig): 109 | def setup_model(self) -> Tuple[PreTrainedModel, ClampTokenizer, Seq2SeqSettings]: 110 | if not self.model_loc.exists(): 111 | raise TrainedModelNotFoundError( 112 | f"Model files not found in {self.model_loc}" 113 | ) 114 | model = GPT2LMHeadModel.from_pretrained(self.model_loc) 115 | tokenizer = GPT2ClampTokenizer.from_pretrained(str(self.model_loc)) 116 | seq2seq_settings = Seq2SeqSettings( 117 | input_surround=Surround( 118 | bos=[20490, 25], eos=[198], starts_with_space=True 119 | ), # bos: "Human:", eos: "\n" 120 | output_surround=Surround( 121 | bos=[34556, 25], eos=[198], starts_with_space=True 122 | ), # bos: "Computer:", eos: "\n" 123 | decoder_start_token_id=None, 124 | ) 125 | self.maybe_parallelize(model) 126 | model.eval() 127 | return model, tokenizer, seq2seq_settings 128 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/lispress_v2/grammar.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import functools 5 | import re 6 | from typing import Iterable, List, Optional, Set 7 | 8 | from semantic_parsing_with_constrained_lm.domains import dfa_grammar_utils 9 | from semantic_parsing_with_constrained_lm.domains.lispress_v2.lispress_exp import ( 10 | BooleanExpr, 11 | CallExpr, 12 | DialogueV2, 13 | LambdaExpr, 14 | LetExpr, 15 | LispressExpr, 16 | LongExpr, 17 | NumberExpr, 18 | ReferenceExpr, 19 | StringExpr, 20 | TypeName, 21 | parse_fully_typed_lispress_v2, 22 | ) 23 | 24 | 25 | def get_nt_from_type(type_name: TypeName) -> str: 26 | segments = ( 27 | str(type_name) 28 | .replace(" ", " SP ") 29 | .replace("(", " LP ") 30 | .replace(")", " RP ") 31 | .replace(".", " DOT ") 32 | .split() 33 | ) 34 | return "_".join(segments + ["NT"]) 35 | 36 | 37 | def extract_grammar_rules(lispress_expr: LispressExpr) -> Set[str]: 38 | lhs = get_nt_from_type(lispress_expr.type) # type: ignore 39 | rules = set() 40 | 41 | if isinstance(lispress_expr, (NumberExpr, LongExpr, StringExpr, BooleanExpr)): 42 | pass 43 | elif isinstance(lispress_expr, ReferenceExpr): 44 | rules.add(f'{lhs} -> "{lispress_expr.var_name}"') 45 | 46 | elif isinstance(lispress_expr, LambdaExpr): 47 | rhs_items = [ 48 | f'"(lambda (^{str(lispress_expr.var_type)} {lispress_expr.var_name}) "', 49 | get_nt_from_type(lispress_expr.main_expr.type), # type: ignore 50 | '")"', 51 | ] 52 | rhs = " ".join(rhs_items) 53 | rules.add(f"{lhs} -> {rhs}") 54 | rules.update(extract_grammar_rules(lispress_expr.main_expr)) 55 | 56 | elif isinstance(lispress_expr, LetExpr): 57 | var_name_expr_nts = [] 58 | for var_name, var_expr in lispress_expr.var_assignments: 59 | var_name_expr_nts.extend([f'"{var_name}"', get_nt_from_type(var_expr.type)]) # type: ignore 60 | rules.update(extract_grammar_rules(var_expr)) 61 | var_name_expr_nts_str = ' " " '.join(var_name_expr_nts) 62 | rhs = f'"(let (" {var_name_expr_nts_str} ") " {get_nt_from_type(lispress_expr.main_expr.type)} ")"' # type: ignore 63 | rules.add(f"{lhs} -> {rhs}") 64 | rules.update(extract_grammar_rules(lispress_expr.main_expr)) 65 | 66 | elif isinstance(lispress_expr, CallExpr): 67 | rhs_items: List[str] = [] 68 | if lispress_expr.instantiation_type is not None: 69 | rhs_items.append(f'"^{lispress_expr.instantiation_type} "?') 
70 | 71 | rhs_items.append(f'"{lispress_expr.name}"') 72 | 73 | for k, v in lispress_expr.args: 74 | rhs_items.extend([f'" :{k}"?', '" "', get_nt_from_type(v.type)]) # type: ignore 75 | rules.update(extract_grammar_rules(v)) 76 | 77 | rhs = " ".join(rhs_items) 78 | rules.add(f'{lhs} -> "(" {rhs} ")"') 79 | 80 | return rules 81 | 82 | 83 | def extract_grammar( 84 | dataflow_dialogues: Iterable[DialogueV2], 85 | whitelisted_dialogue_ids: Optional[Set[str]] = None, 86 | ) -> Set[str]: 87 | grammar_rules = set() 88 | for dataflow_dialogue in dataflow_dialogues: 89 | if ( 90 | whitelisted_dialogue_ids is not None 91 | and dataflow_dialogue.dialogue_id not in whitelisted_dialogue_ids 92 | ): 93 | continue 94 | 95 | for turn in dataflow_dialogue.turns: 96 | lispress_expr = parse_fully_typed_lispress_v2(turn.fully_typed_lispress) 97 | grammar_rules.update(extract_grammar_rules(lispress_expr)) 98 | root_type_nt = get_nt_from_type(lispress_expr.type) # type: ignore 99 | grammar_rules.add(f'start -> " " {root_type_nt}') 100 | 101 | # find string literals 102 | for match in re.finditer(r'Path\.apply "([^"]*)"', turn.lispress): 103 | start = match.start(1) 104 | end = match.end(1) 105 | item = turn.lispress[start:end] 106 | # We use `repr` because the .cfg parser uses `ast.literal_eval` 107 | # to parse the strings, since that will handle backslash escape 108 | # sequences. Without `repr` the resulting grammar will have one 109 | # level of escaping removed. 110 | grammar_rules.add(f"path_literal -> {repr(item)}") 111 | grammar_rules.update( 112 | [ 113 | 'Boolean_NT -> "true"', 114 | 'Boolean_NT -> "false"', 115 | r'String_NT -> "\"" (String_NT_content | path_literal | "output" | "place" | "start") "\""', 116 | # Lispress V2 string literals are JSON string literals, so we follow this grammar: 117 | # https://datatracker.ietf.org/doc/html/rfc8259#section-7 118 | r'String_NT_content -> ([[\u0020-\U0010FFFF]--[\u0022\u005C]] | "\\" (["\u005C/bfnrt] | u[0-9A-Fa-f]{4}))*', 119 | 'Number_NT -> ("0" | [1-9][0-9]*) ("." [0-9]+)?', 120 | 'Long_NT -> ("0" | [1-9][0-9]*) "L"', 121 | ] 122 | ) 123 | return grammar_rules 124 | 125 | 126 | create_partial_parse_builder = functools.partial( 127 | dfa_grammar_utils.create_partial_parse_builder, 128 | utterance_nonterm_name="String_NT_content", 129 | ) 130 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/sql/sql_metric.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import dataclasses 5 | import subprocess 6 | from collections import defaultdict 7 | from dataclasses import dataclass 8 | from typing import Dict, List, Optional, Sequence, Tuple 9 | 10 | from semantic_parsing_with_constrained_lm.datum import FullDatumSub 11 | from semantic_parsing_with_constrained_lm.domains.sql.sql_datum import SqlDatum 12 | from semantic_parsing_with_constrained_lm.eval import Metric 13 | 14 | 15 | @dataclass 16 | class SQLTestSuiteMatch(Metric[Sequence[str], FullDatumSub]): 17 | """ 18 | Metric to evaluate SQL predictions. Uses the test-suite available here: 19 | https://github.com/taoyds/test-suite-sql-eval. To use this metric, clone this repo to a local 20 | directory and set `test_suite_path` to that directory. 
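    A hypothetical construction (the paths below are placeholders, not values
    shipped with this repository):

        metric = SQLTestSuiteMatch(
            db_path="/path/to/Spider/database",
            test_suite_path="/path/to/test-suite-sql-eval",
            table_file="/path/to/Spider/tables.json",
            log_dir="logs/",
        )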
21 | """ 22 | 23 | db_path: str 24 | test_suite_path: str 25 | table_file: str 26 | log_dir: str 27 | schema_map: Dict[str, str] = dataclasses.field(init=False) 28 | predictions_map: Dict[Tuple[str, int], str] = dataclasses.field(init=False) 29 | gold_map: Dict[Tuple[str, int], str] = dataclasses.field(init=False) 30 | dialogue_to_turn_indices_map: Dict[str, List[int]] = dataclasses.field(init=False) 31 | 32 | def __post_init__(self): 33 | self.reset() 34 | 35 | def _is_correct(self, pred: str, target: SqlDatum) -> bool: 36 | """Can be overridden by child classes.""" 37 | return pred == target.canonical 38 | 39 | def update( 40 | self, preds: Sequence[str], target: SqlDatum 41 | ) -> Dict[str, Optional[str]]: 42 | schema_name = target.schema_name 43 | self.schema_map[target.dialogue_id] = schema_name # type: ignore 44 | self.predictions_map[(target.dialogue_id, target.turn_part_index)] = ( # type: ignore 45 | preds[0] if len(preds) > 0 else "dummy" 46 | ) 47 | self.gold_map[(target.dialogue_id, target.turn_part_index)] = target.canonical # type: ignore 48 | self.dialogue_to_turn_indices_map[target.dialogue_id].append( # type: ignore 49 | target.turn_part_index # type: ignore 50 | ) 51 | return {} 52 | 53 | def compute(self, gold_file=None, pred_file=None) -> Dict[str, float]: 54 | # Run test suite using subprocess 55 | is_interaction = any( 56 | [ 57 | len(turn_indices) > 1 58 | for _, turn_indices in self.dialogue_to_turn_indices_map.items() 59 | ] 60 | ) 61 | if gold_file is None and pred_file is None: 62 | gold_file = self.log_dir + "/gold.txt" 63 | pred_file = self.log_dir + "/pred.txt" 64 | with open(gold_file, "w") as fp_gold, open(pred_file, "w") as fp_pred: 65 | for dialogue_id in self.dialogue_to_turn_indices_map: 66 | for turn_index in self.dialogue_to_turn_indices_map[dialogue_id]: 67 | gold = self.gold_map[(dialogue_id, turn_index)] 68 | if gold.count(")") == 1 and gold.count("(") == 0: 69 | gold = gold.replace(")", "") 70 | if "faculty_participates_in" in gold: 71 | gold = gold.replace( 72 | "faculty_participates_in", "Faculty_participates_in" 73 | ) 74 | fp_gold.write( 75 | gold.replace(" . ", ".") 76 | + "\t" 77 | + self.schema_map[dialogue_id] 78 | + "\n" 79 | ) 80 | fp_pred.write( 81 | self.predictions_map[(dialogue_id, turn_index)].replace( 82 | " . ", "." 
83 | ) 84 | + "\n" 85 | ) 86 | 87 | if is_interaction: 88 | fp_gold.write("\n") 89 | fp_pred.write("\n") 90 | 91 | process = subprocess.run( 92 | [ 93 | "python3", 94 | "evaluation.py", 95 | "--gold", 96 | gold_file, 97 | "--pred", 98 | pred_file, 99 | "--db", 100 | self.db_path, 101 | "--table", 102 | self.table_file, 103 | "--etype", 104 | "all", 105 | ], 106 | cwd=self.test_suite_path, 107 | capture_output=True, 108 | text=True, 109 | check=True, 110 | ) 111 | print("stdout:", process.stdout) 112 | print("stderr:", process.stderr) 113 | execution_acc = 0.0 114 | for line in process.stdout.split("\n"): 115 | if line.startswith("execution"): 116 | execution_acc = float(line.split()[5].strip()) 117 | break 118 | result = {"execution_acc": execution_acc} 119 | print(result) 120 | return result 121 | 122 | def reset(self) -> None: 123 | self.predictions_map = {} 124 | self.gold_map = {} 125 | self.schema_map = {} 126 | self.dialogue_to_turn_indices_map = defaultdict(list) 127 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/fit_max_steps.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import collections 5 | import itertools 6 | import random 7 | from dataclasses import dataclass 8 | from typing import Dict, Iterator, List, Tuple 9 | 10 | import numpy as np 11 | 12 | 13 | @dataclass(frozen=True) 14 | class Quantiles: 15 | # THe quantile used to find the intercept. 16 | # In effect, we sort the list of lengths, and take 17 | # `lengths[int(len(lengths) * intercept_quantile)]` 18 | # to use as the intercept. 19 | intercept_quantile: float 20 | # The quantile used for the slope. 21 | # For each datum, we compute the slope that would be needed so that the 22 | # predicted length matches the gold length. 23 | slope_quantile: float 24 | 25 | 26 | @dataclass 27 | class Result: 28 | # Number of test instances that had predicted lengths smaller than gold lengths. 
29 | num_unreachable: int 30 | # predicted length - gold length 31 | surplus_steps: List[int] 32 | 33 | 34 | def fit(pairs: List[Tuple[int, int]]) -> Iterator[Tuple[Quantiles, float, float]]: 35 | # 0, 0.01, 0.02, ..., 0.99, 1.00 36 | intercept_quantiles = np.linspace(0, 1, 101) 37 | # 1, 0.99, 0.98, ..., 0.90 38 | slope_quantiles = np.linspace(1, 0.9, 11) 39 | 40 | intercept_by_quantile = dict( 41 | zip( 42 | intercept_quantiles, np.quantile([o for _, o in pairs], intercept_quantiles) 43 | ) 44 | ) 45 | intercept_by_quantile[-1] = 0.0 46 | 47 | for intercept_quantile, intercept in intercept_by_quantile.items(): 48 | slopes = [(o - intercept) / i for i, o in pairs] 49 | max_slopes = np.quantile(slopes, slope_quantiles) 50 | for slope_quantile, slope in zip(slope_quantiles, max_slopes): 51 | yield Quantiles(intercept_quantile, slope_quantile), intercept, slope 52 | 53 | 54 | def cross_validation_fit( 55 | pairs: List[Tuple[int, int]], num_splits: int 56 | ) -> Dict[Quantiles, List[Result]]: 57 | all_results: Dict[Quantiles, List[Result]] = collections.defaultdict(list) 58 | 59 | # Perform n-fold cross-validation 60 | test_set_size = len(pairs) / num_splits # This is a real number 61 | random.Random(0).shuffle(pairs) 62 | for test_start, test_end in zip( 63 | np.arange(0, len(pairs), test_set_size), 64 | np.arange(test_set_size, len(pairs), test_set_size), 65 | ): 66 | # Round down first so that we get an integer size for the test set 67 | test_start = int(test_start) 68 | test_end = int(test_end) 69 | 70 | train_data = pairs[:test_start] + pairs[test_end:] 71 | test_data = pairs[test_start:test_end] 72 | 73 | for quantiles, intercept, slope in fit(train_data): 74 | predicted = [int(i * slope + intercept) for i, _ in test_data] 75 | num_missed = sum(p < o for p, (_, o) in zip(predicted, test_data)) 76 | surplus_steps = [p - o for p, (_, o) in zip(predicted, test_data)] 77 | all_results[quantiles].append(Result(num_missed, surplus_steps)) 78 | 79 | return all_results 80 | 81 | 82 | def filter_fit( 83 | all_results: Dict[Quantiles, List[Result]], max_unreachable: int 84 | ) -> List[Tuple[Quantiles, int, float]]: 85 | filtered_results: List[Tuple[Quantiles, int, float]] = [] 86 | for q, results in all_results.items(): 87 | total_unreachable = sum(r.num_unreachable for r in results) 88 | if total_unreachable <= max_unreachable: 89 | mean_surplus_steps = np.average( 90 | list(itertools.chain.from_iterable(r.surplus_steps for r in results)) 91 | ) 92 | filtered_results.append((q, total_unreachable, mean_surplus_steps)) 93 | filtered_results.sort(key=lambda x: x[2]) 94 | 95 | return filtered_results 96 | 97 | 98 | def compute_and_print_fit( 99 | pairs: List[Tuple[int, int]], num_splits: int, max_unreachable: int 100 | ) -> Tuple[float, float]: 101 | all_results: Dict[Quantiles, List[Result]] = cross_validation_fit(pairs, num_splits) 102 | 103 | filtered_results: List[Tuple[Quantiles, int, float]] = filter_fit( 104 | all_results, max_unreachable 105 | ) 106 | 107 | print(f"Best params for total unreachable < {max_unreachable}:") 108 | for q, total_unreachable, mean_surplus_steps in filtered_results[:5]: 109 | print( 110 | f"Intercept quantile {q.intercept_quantile:.2f}, " 111 | f"slope quantile {q.slope_quantile:.2f}: " 112 | f"num unreachable = {total_unreachable}, " 113 | f"mean surplus steps = {mean_surplus_steps:.3g}" 114 | ) 115 | 116 | min_surplus_steps = filtered_results[0][2] 117 | all_params_with_min_surplus_steps = [ 118 | q 119 | for q, _, mean_surplus_steps in filtered_results 120 | if 
mean_surplus_steps == min_surplus_steps 121 | ] 122 | best_params = min( 123 | all_params_with_min_surplus_steps, 124 | key=lambda q: (q.intercept_quantile, q.slope_quantile), 125 | ) 126 | 127 | if best_params.intercept_quantile == -1: 128 | intercept = 0 129 | else: 130 | intercept = np.quantile([o for _, o in pairs], best_params.intercept_quantile) 131 | slope = np.quantile( 132 | [(o - intercept) / i for i, o in pairs], best_params.slope_quantile 133 | ) 134 | 135 | print() 136 | print(f"Final results: intercept = {intercept}, slope = {slope}") 137 | return (intercept, slope) 138 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/scfg/parser/rule.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import json 5 | from abc import ABC 6 | from dataclasses import dataclass, replace 7 | from typing import Dict, Set, Tuple, cast 8 | 9 | from semantic_parsing_with_constrained_lm.scfg.parser.token import ( 10 | EmptyToken, 11 | NonterminalToken, 12 | OptionableSCFGToken, 13 | SCFGToken, 14 | TerminalToken, 15 | ) 16 | from semantic_parsing_with_constrained_lm.scfg.parser.types import Expansion 17 | 18 | MAYBE_PREFIX = "maybe__" 19 | 20 | 21 | @dataclass(frozen=True) 22 | class Rule(ABC): 23 | lhs: str 24 | 25 | 26 | @dataclass(frozen=True) 27 | class PlanRule(Rule): 28 | lhs: str 29 | rhs: Expansion 30 | 31 | 32 | @dataclass(frozen=True) 33 | class SyncRule(Rule): 34 | lhs: str 35 | utterance_rhss: Tuple[Expansion, ...] 36 | plan_rhs: Expansion 37 | 38 | 39 | @dataclass(frozen=True) 40 | class UtteranceRule(Rule): 41 | lhs: str 42 | utterance_rhss: Tuple[Expansion, ...] 43 | 44 | 45 | # helpers 46 | def term(s: str, optional=False) -> TerminalToken: 47 | return TerminalToken(underlying=json.dumps(s), optional=optional) 48 | 49 | 50 | def nonterm(s: str, optional=False) -> NonterminalToken: 51 | return NonterminalToken(underlying=s, optional=optional) 52 | 53 | 54 | def mirrored_rule(lhs: str, rhs: Expansion) -> SyncRule: 55 | """Creates a SyncRule with identical utterance and plan expansions.""" 56 | return SyncRule(lhs=lhs, utterance_rhss=(rhs,), plan_rhs=rhs) 57 | 58 | 59 | def expand_optionals(rule: Rule) -> Set[Rule]: 60 | """ 61 | For each optional token `t` in rule, creates a fresh NT that expands to `t` or epsilon. 
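    Schematically (the generated name is simplified here; the real name comes
    from `mk_optional_nt` and also encodes non-identifier characters), a sync
    rule such as

        n -> "a"? B , "a"? B

    is rewritten to use a fresh optional nonterminal:

        n -> maybe__t_a B , maybe__t_a B
        maybe__t_a -> "a" , "a"
        maybe__t_a -> #e , #e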
62 | """ 63 | 64 | def clean(s: str) -> str: 65 | return "".join(c if c.isidentifier() else f"_chr{ord(c)}_" for c in s) 66 | 67 | def mk_optional_nt(s: OptionableSCFGToken) -> NonterminalToken: 68 | nt_or_t = "nt" if isinstance(s, NonterminalToken) else "t" 69 | return nonterm(f"{MAYBE_PREFIX}_{nt_or_t}_{clean(s.render())}") 70 | 71 | utterance_rhss = ( 72 | rule.utterance_rhss if isinstance(rule, (SyncRule, UtteranceRule)) else () 73 | ) 74 | plan_rhss = ( 75 | (rule.plan_rhs,) 76 | if isinstance(rule, SyncRule) 77 | else (rule.rhs,) 78 | if isinstance(rule, PlanRule) 79 | else () 80 | ) 81 | all_utterance_optionals: Set[OptionableSCFGToken] = { 82 | s 83 | for rhs in utterance_rhss 84 | for s in rhs 85 | if isinstance(s, OptionableSCFGToken) and s.optional 86 | } 87 | all_plan_optionals: Set[OptionableSCFGToken] = { 88 | s 89 | for rhs in plan_rhss 90 | for s in rhs 91 | if isinstance(s, OptionableSCFGToken) and s.optional 92 | } 93 | sync_optionals: Dict[SCFGToken, SCFGToken] = { 94 | s: mk_optional_nt(s) 95 | for s in all_utterance_optionals.intersection(all_plan_optionals) 96 | } 97 | utterance_only_optionals: Dict[SCFGToken, SCFGToken] = { 98 | s: mk_optional_nt(s) 99 | for s in all_utterance_optionals.difference(all_plan_optionals) 100 | } 101 | plan_only_optionals: Dict[SCFGToken, SCFGToken] = { 102 | s: mk_optional_nt(s) 103 | for s in all_plan_optionals.difference(all_utterance_optionals) 104 | } 105 | 106 | all_optionals: Dict[SCFGToken, SCFGToken] = { 107 | **utterance_only_optionals, 108 | **plan_only_optionals, 109 | **sync_optionals, 110 | } 111 | new_sync_rules: Set[Rule] = { 112 | r 113 | for s, nt in sync_optionals.items() 114 | for non_opt_s in [(replace(s, optional=False),)] 115 | for r in [ 116 | mirrored_rule(nt.value, non_opt_s), 117 | mirrored_rule(nt.value, (EmptyToken(),)), 118 | ] 119 | } 120 | new_utt_rules: Set[Rule] = { 121 | r 122 | for s, nt in utterance_only_optionals.items() 123 | for non_opt_s in [(replace(s, optional=False),)] 124 | for r in [ 125 | UtteranceRule(lhs=nt.value, utterance_rhss=(non_opt_s,)), 126 | UtteranceRule(lhs=nt.value, utterance_rhss=((EmptyToken(),),)), 127 | ] 128 | } 129 | new_plan_rules: Set[Rule] = { 130 | r 131 | for s, nt in plan_only_optionals.items() 132 | for non_opt_s in [(replace(s, optional=False),)] 133 | for r in [ 134 | PlanRule(lhs=nt.value, rhs=non_opt_s), 135 | PlanRule(lhs=nt.value, rhs=(EmptyToken(),)), 136 | ] 137 | } 138 | new_maybe_rules: Set[Rule] = new_utt_rules | new_plan_rules | new_sync_rules 139 | 140 | def transform_rhs(rhs: Expansion) -> Expansion: 141 | return tuple(all_optionals.get(t, t) for t in rhs) 142 | 143 | new_main_rule: Rule 144 | if isinstance(rule, SyncRule): 145 | new_main_rule = SyncRule( 146 | lhs=rule.lhs, 147 | utterance_rhss=tuple(transform_rhs(rhs) for rhs in rule.utterance_rhss), 148 | plan_rhs=transform_rhs(rule.plan_rhs), 149 | ) 150 | elif isinstance(rule, UtteranceRule): 151 | new_main_rule = UtteranceRule( 152 | lhs=rule.lhs, 153 | utterance_rhss=tuple(transform_rhs(rhs) for rhs in rule.utterance_rhss), 154 | ) 155 | else: 156 | assert isinstance(rule, PlanRule), rule 157 | new_main_rule = PlanRule(lhs=rule.lhs, rhs=transform_rhs(rule.rhs)) 158 | # pyright can't figure out that `new_main_rule` is a Rule 159 | return {cast(Rule, new_main_rule)} | new_maybe_rules 160 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/create_benchclamp_splits.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import random 5 | from collections import defaultdict 6 | from pathlib import Path 7 | from typing import Callable, Dict, List, Optional 8 | 9 | import torch 10 | from dataflow.core.io_utils import save_jsonl_file 11 | 12 | from semantic_parsing_with_constrained_lm.datum import BenchClampDatum, Datum, FullDatum 13 | from semantic_parsing_with_constrained_lm.decoding.partial_parse import PartialParse 14 | from semantic_parsing_with_constrained_lm.tokenization import GPT2ClampTokenizer 15 | 16 | 17 | def can_force_decode( 18 | clamp_tokenizer: GPT2ClampTokenizer, 19 | partial_parse_builder: Callable[[Datum], PartialParse], 20 | datum: FullDatum, 21 | ) -> bool: 22 | token_ids = clamp_tokenizer.encode(" " + datum.canonical) 23 | partial_parse = partial_parse_builder(datum) # type: ignore 24 | for i, token_id in enumerate(token_ids): 25 | next_tokens, can_end = partial_parse.allowed_next( 26 | torch.tensor([token_id]), top_k=1 27 | ) 28 | if token_id not in next_tokens: # type: ignore 29 | print() 30 | print(f"Input: {repr(datum.natural)}") 31 | print(f"Output: {repr(datum.canonical)}") 32 | print(f"Prefix: {repr(clamp_tokenizer.decode(token_ids[:i]))}") 33 | print(f"Rejected: {clamp_tokenizer.decode([token_id])}") 34 | return False 35 | partial_parse = partial_parse.append(token_id) 36 | next_tokens, can_end = partial_parse.allowed_next(torch.tensor([0]), top_k=1) 37 | if not can_end: 38 | print() 39 | print(f"Input: {repr(datum.natural)}") 40 | print(f"Output: {repr(datum.canonical)}") 41 | print("Ending not allowed here") 42 | return False 43 | 44 | return True 45 | 46 | 47 | def gather_subset(data: List[BenchClampDatum], size: int) -> List[BenchClampDatum]: 48 | dialogue_id_to_turn_count: Dict[str, int] = defaultdict(int) 49 | for datum in data: 50 | dialogue_id_to_turn_count[datum.dialogue_id] += 1 # type: ignore 51 | 52 | dialogue_ids = sorted(dialogue_id_to_turn_count.keys()) 53 | random.shuffle(dialogue_ids) 54 | selected_dialogue_ids = [] 55 | num_turns_covered = 0 56 | for dialogue_id in dialogue_ids: 57 | selected_dialogue_ids.append(dialogue_id) 58 | num_turns_covered += dialogue_id_to_turn_count[dialogue_id] 59 | if num_turns_covered >= size: 60 | break 61 | 62 | if num_turns_covered < size: 63 | print(f"Not enough data to create subset of size {size}") 64 | 65 | selected_dialogue_ids_set = set(selected_dialogue_ids) 66 | return [datum for datum in data if datum.dialogue_id in selected_dialogue_ids_set] 67 | 68 | 69 | def create_benchclamp_splits( 70 | train_data: List[BenchClampDatum], 71 | dev_data: List[BenchClampDatum], 72 | test_data: Optional[List[BenchClampDatum]], 73 | output_dir: Path, 74 | ): 75 | """ 76 | Sample splits for BenchClamp experiments. 77 | 1. 5 low data train splits of size 500, single dev set of size 50 78 | 2. 3 medium data train splits of size 5000, single dev set of size 500 79 | 3. Full data split. Reuses dev set for medium data. 80 | 4. Single test split of size 2000. 
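    Concretely, the code below writes these files to `output_dir`:
    test_all.jsonl, test.jsonl, dev_all.jsonl, dev_medium.jsonl, dev_low.jsonl,
    train_all.jsonl, train_low_0.jsonl ... train_low_4.jsonl, and
    train_medium_0.jsonl ... train_medium_2.jsonl.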
81 | """ 82 | random.seed(0) 83 | if test_data is None: 84 | train_dialogue_ids = sorted({datum.dialogue_id for datum in train_data}) # type: ignore 85 | random.shuffle(train_dialogue_ids) 86 | test_data = dev_data 87 | num_train_dialogues = len(train_dialogue_ids) 88 | dev_dialogue_ids = set(train_dialogue_ids[: int(0.1 * num_train_dialogues)]) 89 | dev_data = [ 90 | datum for datum in train_data if datum.dialogue_id in dev_dialogue_ids 91 | ] 92 | train_data = [ 93 | datum for datum in train_data if datum.dialogue_id not in dev_dialogue_ids 94 | ] 95 | 96 | print( 97 | f"Input sizes for creating benchclamp splits: " 98 | f"Train {len(train_data)}, Dev {len(dev_data)}, Test {len(test_data)}" 99 | ) 100 | train_turn_ids = [ 101 | (datum.dialogue_id, datum.turn_part_index) for datum in train_data 102 | ] 103 | dev_turn_ids = [(datum.dialogue_id, datum.turn_part_index) for datum in dev_data] 104 | test_turn_ids = [(datum.dialogue_id, datum.turn_part_index) for datum in test_data] 105 | assert ( 106 | len(set(train_turn_ids)) == len(train_turn_ids) 107 | and len(set(dev_turn_ids)) == len(dev_turn_ids) 108 | and len(set(test_turn_ids)) == len(test_turn_ids) 109 | ), "Multiple data points have same data id, make sure all input data have unique data ids." 110 | 111 | output_dir.mkdir(parents=True, exist_ok=True) 112 | save_jsonl_file(test_data, str(output_dir / "test_all.jsonl")) 113 | save_jsonl_file(gather_subset(test_data, 2000), str(output_dir / "test.jsonl")) 114 | 115 | save_jsonl_file(dev_data, str(output_dir / "dev_all.jsonl")) 116 | save_jsonl_file(gather_subset(dev_data, 500), str(output_dir / "dev_medium.jsonl")) 117 | save_jsonl_file(gather_subset(dev_data, 50), str(output_dir / "dev_low.jsonl")) 118 | 119 | save_jsonl_file(train_data, str(output_dir / "train_all.jsonl")) 120 | for split_size_category, num_splits, train_split_size in [ 121 | ("low", 5, 500), 122 | ("medium", 3, 5000), 123 | ]: 124 | for split_id in range(num_splits): 125 | train_subset_file = ( 126 | output_dir / f"train_{split_size_category}_{split_id}.jsonl" 127 | ) 128 | train_subset = gather_subset(train_data, train_split_size) 129 | save_jsonl_file(train_subset, str(train_subset_file)) 130 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/ltl/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | import json 5 | from dataclasses import dataclass 6 | from enum import Enum 7 | from typing import Callable, Dict, List 8 | 9 | from blobfile import BlobFile 10 | 11 | from semantic_parsing_with_constrained_lm.util.trie import Trie 12 | from semantic_parsing_with_constrained_lm.util.types import StrPath 13 | from semantic_parsing_with_constrained_lm.datum import Datum, FullDatum 14 | from semantic_parsing_with_constrained_lm.decoding.trie_partial_parse import TriePartialParse 15 | 16 | from semantic_parsing_with_constrained_lm.eval import TopKExactMatch 17 | from semantic_parsing_with_constrained_lm.tokenization import ClampTokenizer 18 | 19 | # NOTE: get rid of the catflow/dataflow dependency 20 | from appdirs import user_cache_dir 21 | CACHE_DIR = user_cache_dir("semantic_parsing_as_constrained_lm") 22 | 23 | 24 | class LTLOutputType(str, Enum): 25 | Utterance = "utterance" 26 | MeaningRepresentation = "meaningRepresentation" 27 | 28 | 29 | @dataclass 30 | class TopKDenotationMatch(TopKExactMatch[FullDatum]): 31 | canonical_to_denotation: Dict[str, str] 32 | 33 | def _is_correct(self, pred: str, datum: FullDatum) -> bool: 34 | target = datum.canonical 35 | pred_denotation = self.canonical_to_denotation.get(pred) 36 | target_denotation = self.canonical_to_denotation.get(target, None) 37 | if pred_denotation is None and target_denotation is None: 38 | return pred == target 39 | else: 40 | return pred_denotation == target_denotation 41 | 42 | 43 | @dataclass 44 | class LTLPieces: 45 | train_data: List[FullDatum] 46 | test_data: List[FullDatum] 47 | partial_parse_builder: Callable[[Datum], TriePartialParse] 48 | denotation_metric: TopKDenotationMatch 49 | max_length: int 50 | 51 | @staticmethod 52 | def from_dir( 53 | tokenizer: ClampTokenizer, 54 | root_dir: StrPath, 55 | domain: str, 56 | is_dev: bool, 57 | k: int, 58 | output_type: LTLOutputType = LTLOutputType.Utterance, 59 | simplify_logical_forms=False, 60 | prefix_with_space=False, 61 | ) -> "LTLPieces": 62 | data_pieces = LTLDataPieces.from_dir( 63 | root_dir, domain, is_dev, output_type, simplify_logical_forms 64 | ) 65 | decoder_pieces = LTLDecoderPieces.create( 66 | data_pieces, tokenizer, k, prefix_with_space 67 | ) 68 | 69 | return LTLPieces( 70 | data_pieces.train_data, 71 | data_pieces.test_data, 72 | # https://github.com/python/mypy/issues/5485 73 | decoder_pieces.partial_parse_builder, # type: ignore 74 | decoder_pieces.denotation_metric, 75 | decoder_pieces.max_length, 76 | ) 77 | 78 | 79 | @dataclass 80 | class LTLDataPieces: 81 | train_data: List[FullDatum] 82 | test_data: List[FullDatum] 83 | target_output_to_denotation: Dict[str, str] 84 | 85 | @staticmethod 86 | def from_dir( 87 | root_dir: StrPath, 88 | domain: str, 89 | is_dev: bool, 90 | output_type: LTLOutputType = LTLOutputType.MeaningRepresentation, 91 | simplify_logical_forms: bool = False, 92 | ) -> "LTLDataPieces": 93 | with BlobFile(str(root_dir) + f"/{domain}.canonical.json") as bf: 94 | canonical_data = json.load(bf) 95 | 96 | if output_type == LTLOutputType.Utterance: 97 | target_output_to_denotation = { 98 | k: "DO NOT NEED" for k, v in canonical_data.items() 99 | } 100 | datum_key = "canonical" 101 | else: 102 | raise ValueError(output_type) 103 | 104 | train_data, test_data = [ 105 | [ 106 | FullDatum( 107 | dialogue_id=f"{dataset_name}-{i}", 108 | turn_part_index=None, 109 | agent_context=None, 110 | natural=d["natural"], 111 | canonical=d[datum_key] 112 | if simplify_logical_forms 113 | else d[datum_key], 114 | ) 115 | for i, line in 
enumerate( 116 | BlobFile(path, streaming=False, cache_dir=CACHE_DIR) 117 | ) 118 | for d in [json.loads(line)] 119 | ] 120 | for dataset_name, path in ( 121 | ( 122 | "train", 123 | f"{root_dir}/{domain}.train.jsonl", 124 | ), 125 | ("eval", f"{root_dir}/{domain}.test.jsonl"), 126 | ) 127 | ] 128 | 129 | return LTLDataPieces(train_data, test_data, target_output_to_denotation) 130 | 131 | 132 | @dataclass 133 | class LTLDecoderPieces: 134 | data_pieces: LTLDataPieces 135 | partial_parse_builder: Callable[[Datum], TriePartialParse] 136 | denotation_metric: TopKDenotationMatch 137 | max_length: int 138 | 139 | @staticmethod 140 | def create( 141 | data_pieces: LTLDataPieces, 142 | tokenizer: ClampTokenizer, 143 | k: int, 144 | prefix_with_space: bool = False, 145 | ) -> "LTLDecoderPieces": 146 | if prefix_with_space: 147 | canonical_trie = Trie( 148 | tokenizer.encode(" " + canon) 149 | for canon in data_pieces.target_output_to_denotation 150 | ) 151 | else: 152 | canonical_trie = Trie( 153 | tokenizer.encode(canon) 154 | for canon in data_pieces.target_output_to_denotation 155 | ) 156 | 157 | def partial_parse_builder(_): return TriePartialParse(canonical_trie) 158 | 159 | denotation_metric = TopKDenotationMatch( 160 | k, data_pieces.target_output_to_denotation 161 | ) 162 | max_length = max(len(x) for x in canonical_trie) 163 | 164 | return LTLDecoderPieces( 165 | data_pieces, partial_parse_builder, denotation_metric, max_length 166 | ) 167 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/scfg/parser/parse.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from pathlib import Path 5 | from typing import Any, List, Tuple, Union, cast 6 | 7 | from lark import Lark, Transformer, v_args 8 | from lark.exceptions import UnexpectedEOF # type: ignore[attr-defined] 9 | 10 | from semantic_parsing_with_constrained_lm.scfg.parser.macro import Macro 11 | from semantic_parsing_with_constrained_lm.scfg.parser.rule import ( 12 | PlanRule, 13 | Rule, 14 | SyncRule, 15 | UtteranceRule, 16 | mirrored_rule, 17 | ) 18 | from semantic_parsing_with_constrained_lm.scfg.parser.token import ( 19 | EmptyToken, 20 | MacroToken, 21 | NonterminalToken, 22 | OptionableSCFGToken, 23 | RegexToken, 24 | SCFGToken, 25 | TerminalToken, 26 | ) 27 | from semantic_parsing_with_constrained_lm.scfg.parser.types import Expansion 28 | 29 | 30 | def parse_string(parser: Lark, string: str) -> Any: # -> Union[Rule, Macro]: 31 | """ 32 | Parse a string into a Rule or an Expansion. 33 | 34 | We annotate the return type as Any because this can return a Rule, Macro, or Expansion, and 35 | it seems silly to do a cast at each call site. 
36 | """ 37 | try: 38 | return TreeToRule().transform(parser.parse(string)) 39 | except UnexpectedEOF as e: 40 | raise Exception(f"Line could not be parsed: {string}") from e 41 | 42 | 43 | class TreeToRule(Transformer): 44 | @v_args(inline=True) 45 | def start(self, arg) -> Union[Macro, Rule]: 46 | return arg 47 | 48 | @v_args(inline=True) 49 | def start_for_test(self, arg) -> Union[Macro, Expansion, Rule]: 50 | return arg 51 | 52 | @v_args(inline=True) 53 | def terminal(self, underlying) -> TerminalToken: 54 | return TerminalToken(underlying.value, optional=False) 55 | 56 | @v_args(inline=True) 57 | def optional_terminal(self, underlying) -> TerminalToken: 58 | return TerminalToken(underlying.value, optional=True) 59 | 60 | @v_args(inline=True) 61 | def nonterminal(self, name_token) -> NonterminalToken: 62 | return NonterminalToken(name_token.value, optional=False) 63 | 64 | @v_args(inline=True) 65 | def optional_nonterminal(self, name_token) -> NonterminalToken: 66 | return NonterminalToken(name_token.value, optional=True) 67 | 68 | @v_args(inline=True) 69 | def empty(self) -> EmptyToken: 70 | return EmptyToken() 71 | 72 | @v_args(inline=True) 73 | def regex(self, arg) -> RegexToken: 74 | return RegexToken(arg.value, optional=False, prefix="") 75 | 76 | def plan_expansion(self, args) -> Expansion: 77 | return tuple(args) 78 | 79 | def utterance_expansion(self, args) -> Expansion: 80 | return tuple(args) 81 | 82 | def utterance_expansions(self, no_macro_expansions) -> Tuple[Expansion, ...]: 83 | return tuple(no_macro_expansions) 84 | 85 | @v_args(inline=True) 86 | def token(self, arg: SCFGToken) -> SCFGToken: 87 | return arg 88 | 89 | @v_args(inline=True) 90 | def rule(self, name_token) -> str: 91 | return name_token.value 92 | 93 | @v_args(inline=True) 94 | def sync_rule( 95 | self, lhs: str, expansions: List[Expansion], expansion: Expansion 96 | ) -> SyncRule: 97 | return SyncRule(lhs=lhs, utterance_rhss=tuple(expansions), plan_rhs=expansion) 98 | 99 | @v_args(inline=True) 100 | def mirrored_rule(self, lhs: str, rhs: Expansion) -> SyncRule: 101 | return mirrored_rule(lhs, rhs) 102 | 103 | @v_args(inline=True) 104 | def utterance_rule(self, rule, expansions) -> UtteranceRule: 105 | return UtteranceRule(rule, tuple(expansions)) 106 | 107 | @v_args(inline=True) 108 | def macro_rule(self, macro_def, expansion) -> Macro: 109 | return Macro(macro_def[0], macro_def[1], expansion) 110 | 111 | def macro_def(self, args) -> Tuple[str, Tuple[str, ...]]: 112 | return cast(str, args[0].value), tuple(cast(str, a.value) for a in args[1:]) 113 | 114 | def macro_apply(self, args) -> MacroToken: 115 | return MacroToken(args[0].value, tuple(args[1:])) 116 | 117 | 118 | def get_scfg_parser(start_symbol: str = "start") -> Lark: 119 | """ 120 | Get a parser based on the SCFG grammar. The start rule that gets appended to the grammar 121 | at the end depends on whether we are testing or not. If we are testing, then we want to be 122 | able to parse expansions outside of rules so that in our tests, we don't have to write 123 | lists of tokens. 124 | """ 125 | scfg_grammar_path = Path(__file__).parent / "scfg_grammar.lark" 126 | scfg_grammar: str 127 | with open(scfg_grammar_path, "r") as cf_grammar_file: 128 | scfg_grammar = cf_grammar_file.read() 129 | 130 | # Type ignoring because mypy doesn't play well with Lark. 
131 | return Lark(scfg_grammar, ambiguity="explicit", start=start_symbol) # type: ignore 132 | 133 | 134 | # RENDERING 135 | 136 | 137 | def render_token(token: SCFGToken) -> str: 138 | if isinstance(token, EmptyToken): 139 | return "#e" 140 | elif isinstance(token, OptionableSCFGToken): 141 | optional_str = "?" if token.optional else "" 142 | value = token.lark_value if isinstance(token, RegexToken) else token.value 143 | return value + optional_str 144 | else: 145 | assert isinstance(token, MacroToken) 146 | return token.value 147 | 148 | 149 | def render_expansion(rhs: Expansion) -> str: 150 | return " ".join(render_token(t) for t in rhs) 151 | 152 | 153 | def render_expansions(expansions: Tuple[Expansion, ...]): 154 | return " | ".join(render_expansion(rhs) for rhs in expansions) 155 | 156 | 157 | def render_rule(rule: Union[Rule, Macro]) -> str: 158 | if isinstance(rule, Macro): 159 | arg_str = f"({', '.join(rule.args)})" if rule.args else "" 160 | return f"{rule.name}{arg_str} 2> {render_expansion(rule.expansion)}" 161 | elif isinstance(rule, PlanRule): 162 | return f"{rule.lhs} 2> {render_expansion(rule.rhs)}" 163 | elif isinstance(rule, UtteranceRule): 164 | return f"{rule.lhs} 1> {render_expansions(rule.utterance_rhss)}" 165 | else: 166 | assert isinstance(rule, SyncRule) 167 | return f"{rule.lhs} -> {render_expansions(rule.utterance_rhss)} , {render_expansion(rule.plan_rhs)}" 168 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/calflow/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import dataclasses 5 | import json 6 | import signal 7 | from dataclasses import dataclass 8 | from enum import Enum 9 | from itertools import islice 10 | from operator import itemgetter 11 | from typing import Dict, Iterable, List, Optional, Set 12 | 13 | from dataflow.core.lispress import ( 14 | lispress_to_program, 15 | parse_lispress, 16 | program_to_lispress, 17 | render_compact, 18 | try_round_trip, 19 | ) 20 | from lark import GrammarError, ParseError, UnexpectedCharacters 21 | 22 | from semantic_parsing_with_constrained_lm.util.types import StrPath 23 | from semantic_parsing_with_constrained_lm.scfg.generate import parse_and_render 24 | from semantic_parsing_with_constrained_lm.scfg.scfg import SCFG 25 | from semantic_parsing_with_constrained_lm.datum import FullDatum 26 | from semantic_parsing_with_constrained_lm.domains.calflow.disambiguate import score_auto_grammar_plan 27 | from semantic_parsing_with_constrained_lm.eval import TopKExactMatch 28 | 29 | 30 | class CalflowOutputLanguage(str, Enum): 31 | Canonical = "canonicalUtterance" 32 | Lispress = "lispress" 33 | 34 | 35 | @dataclass(frozen=True) 36 | class CalflowDatum(FullDatum): 37 | lispress: str 38 | 39 | 40 | def read_calflow_jsonl( 41 | filename: StrPath, 42 | model_output_type: CalflowOutputLanguage, 43 | whitelisted_dialogue_ids: Optional[Set[str]] = None, 44 | ) -> List[CalflowDatum]: 45 | """ 46 | Reads CalflowDatum lists from file `filename` with `canonical` based on model_output_type. Selects based on 47 | whitelisted_dialogue_ids when set, reads all data otherwise. 
48 | """ 49 | with open(filename) as test_file: 50 | return [ 51 | CalflowDatum( 52 | agent_context=json_line.get("context", ""), 53 | natural=json_line["utterance"], 54 | canonical=json_line[model_output_type], 55 | dialogue_id=json_line["dialogueId"], 56 | turn_part_index=json_line["turnIndex"], 57 | lispress=json_line["lispress"], 58 | ) 59 | for line in test_file 60 | for json_line in [json.loads(line)] 61 | if whitelisted_dialogue_ids is None 62 | or json_line["dialogueId"] in whitelisted_dialogue_ids 63 | ] 64 | 65 | 66 | def predict_plan_from_canonical( 67 | scfg: SCFG, 68 | utterance: str, 69 | k: int = 1000, 70 | max_depth: int = 15, 71 | fallback_plan: str = " (FenceScope)", 72 | ) -> str: 73 | """ 74 | Predicts a single Lispress surface string from the given canonical 75 | `utterance`. 76 | Finds possible parses using `scfg`, truncates to the top `k`, then picks 77 | the highest scoring possible parse under `score_auto_grammar_plan`. 78 | """ 79 | unscored_plans: Iterable[str] 80 | try: 81 | unscored_plans = parse_and_render( 82 | scfg, utterance, source_is_plan=False, max_depth=max_depth 83 | ) 84 | except (GrammarError, UnexpectedCharacters, ParseError, AttributeError): 85 | unscored_plans = [] 86 | try: 87 | scored_plans = ( 88 | (plan, score_auto_grammar_plan(plan)) for plan in unscored_plans 89 | ) 90 | filtered_plans = ( 91 | (plan, score) for plan, score in scored_plans if score != -float("inf") 92 | ) 93 | truncated_plans = islice(filtered_plans, k) 94 | except AttributeError: 95 | truncated_plans = iter([]) 96 | try: 97 | best_plan, _best_score = max(truncated_plans, key=itemgetter(1)) 98 | except ValueError: 99 | # no candidates 100 | best_plan = fallback_plan 101 | 102 | # Remove leading space 103 | # TODO: Remove need for this by removing the space from the grammar 104 | assert len(best_plan) > 0 and best_plan[0] == " " 105 | best_plan = best_plan[1:] 106 | 107 | return best_plan 108 | 109 | 110 | class ParseTimeout(Exception): 111 | pass 112 | 113 | 114 | @dataclass 115 | class CalflowMetrics(TopKExactMatch[CalflowDatum]): 116 | scfg: SCFG 117 | data_type: CalflowOutputLanguage 118 | 119 | cached_parses: Dict[str, str] = dataclasses.field(default_factory=dict) 120 | 121 | # Only attempt to convert predictions with the same length as the gold, 122 | # which saves a lot of time with parsing. 123 | require_exact_length: bool = False 124 | 125 | @staticmethod 126 | def parse_timeout_handler(sig, frame): 127 | raise ParseTimeout 128 | 129 | def cached_parse(self, pred: str, gold: Optional[str]) -> str: 130 | """ 131 | Given a canonical utterance, convert it into a lispress plan. 
132 | """ 133 | if pred in self.cached_parses: 134 | return self.cached_parses[pred] 135 | if self.require_exact_length and (gold is None or len(pred) != len(gold)): 136 | return "(FenceScope)" 137 | 138 | prev_signal = signal.getsignal(signal.SIGALRM) 139 | if prev_signal == signal.SIG_DFL: 140 | signal.signal(signal.SIGALRM, self.parse_timeout_handler) 141 | signal.alarm(300) 142 | 143 | try: 144 | predicted_plan = predict_plan_from_canonical(self.scfg, " " + pred) 145 | signal.signal(signal.SIGALRM, prev_signal) 146 | except Exception as e: # pylint: disable=broad-except 147 | print(e) 148 | predicted_plan = "(FenceScope)" 149 | finally: 150 | signal.signal(signal.SIGALRM, prev_signal) 151 | 152 | self.cached_parses[pred] = predicted_plan 153 | return predicted_plan 154 | 155 | def _is_correct(self, pred: str, target: CalflowDatum) -> bool: 156 | if self.data_type == CalflowOutputLanguage.Canonical: 157 | predicted_plan = self.cached_parse(pred, target.canonical) 158 | else: 159 | try: 160 | # round-trip to canonicalize 161 | predicted_plan = render_compact( 162 | program_to_lispress(lispress_to_program(parse_lispress(pred), 0)[0]) 163 | ) 164 | except Exception: # pylint: disable=W0703 165 | predicted_plan = "(FenceScope)" 166 | 167 | return try_round_trip(target.lispress) == try_round_trip(predicted_plan) 168 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/sql/create_benchclamp_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import dataclasses 5 | from pathlib import Path 6 | from typing import Dict, List 7 | 8 | import tqdm 9 | from transformers import GPT2Tokenizer 10 | 11 | from semantic_parsing_with_constrained_lm.scfg.scfg import SCFG 12 | from semantic_parsing_with_constrained_lm.datum import BenchClampDatum 13 | from semantic_parsing_with_constrained_lm.decoding.earley_partial_parse import ( 14 | GrammarTokenizerInfo, 15 | UTF8EarleyPartialParse, 16 | ) 17 | from semantic_parsing_with_constrained_lm.domains.benchclamp_data_setup import BenchClampDataset 18 | from semantic_parsing_with_constrained_lm.domains.create_benchclamp_splits import ( 19 | can_force_decode, 20 | create_benchclamp_splits, 21 | ) 22 | from semantic_parsing_with_constrained_lm.domains.sql.cosql.dialogue import ( 23 | canonicalize_sql_with_grammar, 24 | convert_cosql_to_datum_format, 25 | convert_spider_to_datum_format, 26 | load_cosql_data, 27 | load_spider_data, 28 | ) 29 | from semantic_parsing_with_constrained_lm.domains.sql.cosql.grammar import ( 30 | load_base_grammar, 31 | preprocessed_grammar_for_schema, 32 | ) 33 | from semantic_parsing_with_constrained_lm.domains.sql.cosql.schema import DbSchema, load_schemas 34 | from semantic_parsing_with_constrained_lm.domains.sql.sql_datum import SqlDatum 35 | from semantic_parsing_with_constrained_lm.paths import ( 36 | BENCH_CLAMP_PROCESSED_DATA_DIR, 37 | BENCH_CLAMP_RAW_DATA_DIR, 38 | ) 39 | from semantic_parsing_with_constrained_lm.tokenization import GPT2ClampTokenizer 40 | 41 | 42 | def write_data_and_test_grammar( 43 | train_data: List[BenchClampDatum], 44 | dev_data: List[BenchClampDatum], 45 | schemas: Dict[str, DbSchema], 46 | datum_output_dir: Path, 47 | ) -> None: 48 | base_grammar = load_base_grammar() 49 | pre_grammars = { 50 | name: preprocessed_grammar_for_schema(db, base_grammar) 51 | for name, db in schemas.items() 52 | } 53 | 
grammars = {name: SCFG(pg) for name, pg in pre_grammars.items()} 54 | train_data_with_canonical_sql: List[BenchClampDatum] = [] 55 | dev_data_with_canonical_sql: List[BenchClampDatum] = [] 56 | print("Canonicalizing SQL ...") 57 | for data, data_with_canonical_sql in [ 58 | (train_data, train_data_with_canonical_sql), 59 | (dev_data, dev_data_with_canonical_sql), 60 | ]: 61 | for datum in tqdm.tqdm(data): 62 | grammar = grammars[datum.schema_name] 63 | data_with_canonical_sql.append( 64 | dataclasses.replace( 65 | datum, plan=canonicalize_sql_with_grammar(datum.plan, grammar) 66 | ) 67 | ) 68 | 69 | # Create data splits 70 | print("Creating data splits ...") 71 | create_benchclamp_splits( 72 | train_data_with_canonical_sql, 73 | dev_data_with_canonical_sql, 74 | None, 75 | datum_output_dir, 76 | ) 77 | 78 | print("Testing ...") 79 | clamp_tokenizer = GPT2ClampTokenizer(GPT2Tokenizer.from_pretrained("gpt2")) 80 | grammar_tok_info = { 81 | name: GrammarTokenizerInfo.create(clamp_tokenizer, preprocessed_grammar, True) 82 | for name, preprocessed_grammar in pre_grammars.items() 83 | } 84 | partial_parse_builder = lambda datum: UTF8EarleyPartialParse.initial( 85 | grammar_tok_info[datum.schema_name], datum.natural 86 | ) 87 | total = 0 88 | wrong = 0 89 | print("Testing if force decoding possible for first 100 examples") 90 | for datum in train_data_with_canonical_sql: 91 | if total >= 100: 92 | break 93 | total += 1 94 | if not can_force_decode( 95 | clamp_tokenizer, 96 | partial_parse_builder, 97 | SqlDatum( 98 | natural=datum.utterance, 99 | canonical=datum.plan, 100 | dialogue_id="", 101 | turn_part_index=0, 102 | agent_context="", 103 | schema_name=datum.schema_name, # type: ignore 104 | ), 105 | ): 106 | print(f"Error: {datum.plan}") 107 | print(f"Schema: {datum.schema_name}") 108 | print(f"Utterance: {datum.utterance}") 109 | print() 110 | wrong += 1 111 | 112 | print(f"Force Decode Errors %: {wrong} / {total}") 113 | 114 | 115 | def main(): 116 | raw_spider_dir = BENCH_CLAMP_RAW_DATA_DIR / BenchClampDataset.Spider.value 117 | spider_schemas = load_schemas( 118 | schemas_path=raw_spider_dir / "tables.json", db_path=raw_spider_dir / "database" 119 | ) 120 | spider_train, train_others, spider_dev = [ 121 | convert_spider_to_datum_format( 122 | load_spider_data(raw_spider_dir / fn), 123 | db_map=spider_schemas, 124 | db_path=str(raw_spider_dir / "database"), 125 | ) 126 | for fn in ["train_spider.json", "train_others.json", "dev.json"] 127 | ] 128 | spider_train.extend( 129 | [ 130 | dataclasses.replace(datum, dialogue_id=f"other-{datum.dialogue_id}") 131 | for datum in train_others 132 | ] 133 | ) 134 | write_data_and_test_grammar( 135 | train_data=spider_train, 136 | dev_data=spider_dev, 137 | schemas=spider_schemas, 138 | datum_output_dir=BENCH_CLAMP_PROCESSED_DATA_DIR 139 | / BenchClampDataset.Spider.value, 140 | ) 141 | 142 | raw_cosql_dir = BENCH_CLAMP_RAW_DATA_DIR / BenchClampDataset.CoSQL.value 143 | cosql_schemas = load_schemas( 144 | schemas_path=raw_cosql_dir / "tables.json", db_path=raw_cosql_dir / "database" 145 | ) 146 | cosql_train, cosql_dev = [ 147 | convert_cosql_to_datum_format( 148 | load_cosql_data(raw_cosql_dir / "sql_state_tracking" / fn), 149 | db_map=cosql_schemas, 150 | db_path=str(raw_cosql_dir / "database"), 151 | ) 152 | for fn in ["cosql_train.json", "cosql_dev.json"] 153 | ] 154 | write_data_and_test_grammar( 155 | train_data=cosql_train, 156 | dev_data=cosql_dev, 157 | schemas=cosql_schemas, 158 | datum_output_dir=BENCH_CLAMP_PROCESSED_DATA_DIR / 
BenchClampDataset.CoSQL.value, 159 | ) 160 | 161 | 162 | if __name__ == "__main__": 163 | main() 164 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import csv 5 | import dataclasses 6 | from abc import ABC, abstractmethod 7 | from contextlib import AbstractContextManager 8 | from dataclasses import dataclass 9 | from pathlib import Path 10 | from typing import Dict, Generic, List, Optional, Sequence, TextIO, TypeVar 11 | 12 | from semantic_parsing_with_constrained_lm.datum import FullDatum, FullDatumSub 13 | from semantic_parsing_with_constrained_lm.model import ModelResult 14 | 15 | Pred = TypeVar("Pred") 16 | Target = TypeVar("Target") 17 | 18 | 19 | # TODO: Replcae this with a more flexible function suited to each domain 20 | def exact_match_with_logging( 21 | test_datum: FullDatum, kbest: Sequence[ModelResult] 22 | ) -> bool: 23 | gold = ( 24 | test_datum.canonical.strip(" ") 25 | if test_datum.canonical is not None 26 | else "UNREACHABLE" 27 | ) 28 | pred = kbest[0].text.strip(" ") if kbest else "" 29 | print() 30 | print(f"context: {test_datum.agent_context}") 31 | print(f"natural: {test_datum.natural}") 32 | print(f"predicted: {pred}") 33 | print(f"gold: {gold}") 34 | result = gold == pred 35 | print(f"is correct: {result}") 36 | beam_result = False 37 | for i, pred_i in enumerate(kbest): 38 | stripped = pred_i.text.strip(" ") 39 | beam_result = beam_result or gold == stripped 40 | print(f"Beam {i} [{pred_i.cost:.3f}]: {stripped}") 41 | print(f"is correct@{i}: {beam_result}") 42 | print() 43 | return result 44 | 45 | 46 | class Metric(Generic[Pred, Target], ABC): 47 | """Used to measure goodness of model results compared to the ground truth. 
48 | 49 | Stateful over the duration of an experiment run.""" 50 | 51 | @abstractmethod 52 | def update(self, pred: Pred, target: Target) -> Dict[str, Optional[str]]: 53 | """Uses `target` and the model predictions `pred` to update the state.""" 54 | pass 55 | 56 | @abstractmethod 57 | def compute(self) -> Dict[str, float]: 58 | """Uses the state to compute the final results.""" 59 | pass 60 | 61 | @abstractmethod 62 | def reset(self) -> None: 63 | """Reinitializes the state.""" 64 | pass 65 | 66 | 67 | @dataclass 68 | class TopKExactMatch(Metric[Sequence[str], FullDatumSub]): 69 | k: int 70 | correct: List[int] = dataclasses.field(init=False) 71 | total: int = dataclasses.field(init=False) 72 | 73 | def __post_init__(self): 74 | self.reset() 75 | 76 | def _is_correct(self, pred: str, target: FullDatumSub) -> bool: 77 | """Can be overridden by child classes.""" 78 | return pred == target.canonical 79 | 80 | def update( 81 | self, preds: Sequence[str], target: FullDatumSub 82 | ) -> Dict[str, Optional[str]]: 83 | self.total += 1 84 | found_correct = False 85 | result: Dict[str, Optional[str]] = {} 86 | for i, pred in enumerate(preds[: self.k]): 87 | correct = self._is_correct(pred, target) 88 | found_correct |= correct 89 | self.correct[i] += found_correct 90 | result[f"rank{i + 1}"] = "correct" if correct else "incorrect" 91 | result[f"top{i + 1}"] = "correct" if found_correct else "incorrect" 92 | 93 | # Handle when we have fewer predictions than self.k 94 | for i in range(len(preds), self.k): 95 | self.correct[i] += found_correct 96 | result[f"rank{i + 1}"] = "incorrect" 97 | result[f"top{i + 1}"] = "correct" if found_correct else "incorrect" 98 | 99 | return result 100 | 101 | def compute(self) -> Dict[str, float]: 102 | result = {} 103 | for i in range(self.k): 104 | result[f"top{i + 1}"] = self.correct[i] / self.total 105 | return result 106 | 107 | def reset(self) -> None: 108 | self.correct = [0] * self.k 109 | self.total = 0 110 | 111 | 112 | class Logger(Generic[Pred, Target], AbstractContextManager): 113 | """Experiment logger interface for capturing model outputs for a 114 | given target and logging them in some form. Useful for things like 115 | producing error tables. 116 | 117 | The logger implements context manager and must be wrapped in a 118 | `with` statement to be used. 119 | """ 120 | 121 | @abstractmethod 122 | def log(self, predictions: Pred, target: Target, metrics: Dict[str, Optional[str]]): 123 | pass 124 | 125 | 126 | class ExactMatchErrorLogger(Logger[Sequence[ModelResult], FullDatumSub]): 127 | """A logger that logs in 2 files: 128 | - base_path / errors.csv - CSV with model error using exact match metric. 129 | - base_path / correct.scv - CSV with correct model predictions. 
130 | """ 131 | 132 | def __init__(self, base_path: Path): 133 | self._base_path = base_path 134 | self._err_file: Optional[TextIO] = None 135 | self._err_writer = None 136 | self._corr_file: Optional[TextIO] = None 137 | self._corr_writer = None 138 | 139 | def __enter__(self): 140 | assert not self._err_file and not self._corr_file 141 | self._err_file = open(self._base_path / "errors.csv", "w") 142 | self._err_writer = csv.writer( 143 | self._err_file, delimiter=",", quoting=csv.QUOTE_MINIMAL 144 | ) 145 | self._corr_file = open(self._base_path / "correct.csv", "w") 146 | self._corr_writer = csv.writer( 147 | self._corr_file, delimiter=",", quoting=csv.QUOTE_MINIMAL 148 | ) 149 | 150 | self._err_writer.writerow(["Natural", "Gold", "Predicted"]) # type: ignore 151 | self._corr_writer.writerow(["Natural", "Gold"]) # type: ignore 152 | 153 | def __exit__(self, *_): 154 | assert self._err_file and self._corr_file 155 | self._err_file.close() 156 | self._corr_file.close() 157 | 158 | def log( 159 | self, 160 | predictions: Sequence[ModelResult], 161 | target: FullDatumSub, 162 | metrics: Dict[str, Optional[str]], 163 | ): 164 | pred = predictions[0].text if len(predictions) else "" 165 | if pred != target.canonical: 166 | self._err_writer.writerow([target.natural, target.canonical, pred]) # type: ignore 167 | self._err_file.flush() # type: ignore 168 | else: 169 | self._corr_writer.writerow([target.natural, target.canonical]) # type: ignore 170 | self._corr_file.flush() # type: ignore 171 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/sql/cosql/schema.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | import os 5 | from collections import defaultdict 6 | from dataclasses import dataclass 7 | from enum import Enum 8 | from json import load 9 | from typing import Any, Dict, List, Tuple 10 | 11 | import jsons 12 | 13 | from semantic_parsing_with_constrained_lm.util.types import StrPath 14 | from semantic_parsing_with_constrained_lm.domains.sql.cosql.content_encoder import get_column_picklist 15 | from semantic_parsing_with_constrained_lm.domains.sql.cosql.paths import COSQL_DIR, SCHEMAS_FILE 16 | 17 | 18 | class ColumnType(Enum): 19 | Text = "text" 20 | Number = "number" 21 | Time = "time" 22 | Boolean = "boolean" 23 | Others = "others" 24 | 25 | 26 | @dataclass(frozen=True) 27 | class Column: 28 | name: str 29 | tpe: ColumnType 30 | # name in more natural language 31 | nl_name: str = "" 32 | 33 | @staticmethod 34 | def star() -> "Column": 35 | return Column(name="*", tpe=ColumnType.Text, nl_name="*") 36 | 37 | 38 | @dataclass(frozen=True) 39 | class ForeignKey: 40 | column_id: int 41 | other_column_id: int 42 | 43 | 44 | @dataclass(frozen=True) 45 | class Table: 46 | name: str 47 | columns: List[Column] 48 | # name in more natural language 49 | nl_name: str = "" 50 | 51 | def all_columns(self) -> List[Column]: 52 | return [Column.star()] + self.columns 53 | 54 | 55 | @dataclass(frozen=True) 56 | class DbSchema: 57 | name: str 58 | tables: List[Table] 59 | # indexes into self.tables 60 | columns: List[Tuple[int, Column]] = () # type: ignore 61 | # indexes into self.columns 62 | primary_keys: List[int] = () # type: ignore 63 | # indexes into self.columns 64 | foreign_keys: List[ForeignKey] = () # type: ignore 65 | # values in the database 66 | values: List[str] = () # type: ignore 67 | 68 | @staticmethod 69 | def from_json(schema_json: Dict[str, Any], db_path: str) -> "DbSchema": 70 | columns: List[Tuple[int, Column]] = [ 71 | (t_id, Column(name=orig, tpe=ColumnType(tpe), nl_name=name)) 72 | for (t_id, orig), (_, name), tpe in zip( 73 | schema_json["column_names_original"], 74 | schema_json["column_names"], 75 | schema_json["column_types"], 76 | ) 77 | ] 78 | columns_by_table: Dict[int, List[Column]] = defaultdict(list) 79 | for t_id, col in columns: 80 | columns_by_table[t_id].append(col) 81 | # TODO: 82 | # tables.json is corrupted in the CoSQL dataset for schema formula_1. 83 | # "table_names" and "table_names_original" are not collated: 84 | # "table_names": [ "races", "drivers", "status", "seasons", "constructors", 85 | # "constructor standings", "results", "driver standings", "constructor results", "qualifying", 86 | # "circuits", "pitstops", "laptimes" ], 87 | # "table_names_original": [ "circuits", "races", "drivers", "status", "seasons", "constructors", 88 | # "constructorStandings", "results", "driverStandings", "constructorResults", "qualifying", 89 | # "pitStops", "lapTimes" ] 90 | # in that case, nl_name will be messed up for the table 91 | tables = [ 92 | Table(name=orig, columns=columns_by_table[t_id], nl_name=name) 93 | for t_id, (orig, name) in enumerate( 94 | zip(schema_json["table_names_original"], schema_json["table_names"]) 95 | ) 96 | ] 97 | foreign_keys = [ 98 | ForeignKey(col, other) for col, other in schema_json["foreign_keys"] 99 | ] 100 | db_id = schema_json["db_id"] 101 | values = [] 102 | # Some tables mentioned in tables,json are not present in the download 103 | # This check catches errors when trying to read them. 
104 | if os.path.exists(db_path + "/" + db_id): 105 | db_path = db_path + "/" + db_id + "/" + db_id + ".sqlite" 106 | for table in tables: 107 | for column in table.columns: 108 | picklist = get_column_picklist(table.name, column.name, db_path) 109 | values.extend( 110 | [ 111 | val 112 | for val in picklist 113 | # this condition removes times from the list 114 | if isinstance(val, str) and not val.count(":") == 2 115 | ] 116 | ) 117 | 118 | return DbSchema( 119 | name=schema_json["db_id"], 120 | columns=columns, 121 | foreign_keys=foreign_keys, 122 | primary_keys=schema_json["primary_keys"], 123 | tables=tables, 124 | values=values, 125 | ) 126 | 127 | 128 | def load_schemas( 129 | schemas_path: StrPath = SCHEMAS_FILE, db_path: StrPath = COSQL_DIR / "database" 130 | ) -> Dict[str, DbSchema]: 131 | db_schema_details_file = str(db_path) + "/" + "db_schema_details.json" 132 | if os.path.exists(db_schema_details_file): 133 | with open(db_schema_details_file, "r") as db_schema_details_fp: 134 | return jsons.loads(db_schema_details_fp.read(), cls=Dict[str, DbSchema]) 135 | 136 | with open(schemas_path) as tables_file: 137 | schemas_json = load(tables_file) 138 | schemas = [ 139 | DbSchema.from_json(schema_json, str(db_path)) for schema_json in schemas_json 140 | ] 141 | return {schema.name: schema for schema in schemas} 142 | 143 | 144 | if __name__ == "__main__": 145 | cosql_schemas = load_schemas( 146 | "/Users/subhrroy/workspace/pyharbor/data/benchclamp/raw/CoSQL/tables.json", 147 | "/Users/subhrroy/workspace/pyharbor/data/benchclamp/raw/CoSQL/database/", 148 | ) 149 | with open( 150 | "/Users/subhrroy/workspace/pyharbor/data/benchclamp/raw/CoSQL/database/db_schema_details.json", 151 | "w", 152 | ) as fp: 153 | fp.write(jsons.dumps(cosql_schemas, cls=Dict[str, DbSchema])) 154 | 155 | spider_schemas = load_schemas( 156 | "/Users/subhrroy/workspace/pyharbor/data/benchclamp/raw/Spider/tables.json", 157 | "/Users/subhrroy/workspace/pyharbor/data/benchclamp/raw/Spider/database/", 158 | ) 159 | with open( 160 | "/Users/subhrroy/workspace/pyharbor/data/benchclamp/raw/Spider/database/db_schema_details.json", 161 | "w", 162 | ) as fp: 163 | fp.write(jsons.dumps(cosql_schemas, cls=Dict[str, DbSchema])) 164 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/configs/lib/calflow.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | import functools 5 | from typing import Callable, Dict, List, Optional, Tuple 6 | 7 | from transformers.tokenization_utils import PreTrainedTokenizer 8 | 9 | from semantic_parsing_with_constrained_lm.scfg.read_grammar import PreprocessedGrammar 10 | from semantic_parsing_with_constrained_lm.configs.lib.common import ( 11 | PromptOrder, 12 | SeparateLM, 13 | make_semantic_parser, 14 | ) 15 | from semantic_parsing_with_constrained_lm.datum import Datum 16 | from semantic_parsing_with_constrained_lm.decoding.earley_partial_parse import ( 17 | GrammarTokenizerInfo, 18 | UTF8EarleyPartialParse, 19 | ) 20 | from semantic_parsing_with_constrained_lm.decoding.partial_parse import StartsWithSpacePartialParse 21 | from semantic_parsing_with_constrained_lm.domains.calflow import ( 22 | CalflowDatum, 23 | CalflowOutputLanguage, 24 | read_calflow_jsonl, 25 | ) 26 | from semantic_parsing_with_constrained_lm.fewshot import PromptBuilder 27 | from semantic_parsing_with_constrained_lm.lm import ( 28 | AutoregressiveModel, 29 | ClientType, 30 | IncrementalLanguageModel, 31 | ) 32 | from semantic_parsing_with_constrained_lm.model import ( 33 | BeamSearchSemanticParser, 34 | DecodingSetup, 35 | ProblemFactory, 36 | ) 37 | 38 | # This is a magic number computed by Chris, the origins of which are no 39 | # longer exactly known. It should correspond to the maximum number of GPT-2 40 | # tokens we expect any plan would take up, when written as Lispress. 41 | # 42 | # We truncate the number of training examples put inside the prompt so that 43 | # we can add this many more tokens and still stay under the 2048 limit. 44 | # 45 | # When using canonical utterances, this number doesn't need to be as big 46 | # since canonical utterances are not as long as Lispress. However, when 47 | # using 20 training examples per prompt, we are never at risk of reaching 48 | # the 2048 limit, so this point is moot. 
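# The per-datum decoding budget computed below in get_calflow_max_steps_fn is an
# affine function of the input length, clipped to this constant:
#   max_steps(datum) = min(intercept + slope * len(tokenizer.tokenize(datum.natural)),
#                          MAX_STEPS_FOR_COMPLETION)
# with (intercept, slope) looked up from calflow_max_steps_fn_params.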
49 | MAX_STEPS_FOR_COMPLETION = 313 50 | 51 | 52 | calflow_max_steps_fn_params: Dict[ 53 | Tuple[CalflowOutputLanguage, ClientType], Tuple[float, float] 54 | ] = { 55 | # python semantic_parsing_with_constrained_lm/scripts/calflow_fit_max_steps.py \ 56 | # --data-path \ 57 | # semantic_parsing_with_constrained_lm/domains/calflow/data/train_300_stratified.jsonl \ 58 | # --tokenizer facebook/bart-large --output-type canonicalUtterance \ 59 | # --max-unreachable 3 60 | (CalflowOutputLanguage.Canonical, ClientType.BART): (8, 1.7233333333), 61 | # python semantic_parsing_with_constrained_lm/scripts/calflow_fit_max_steps.py \ 62 | # --data-path \ 63 | # semantic_parsing_with_constrained_lm/domains/calflow/data/train_300_stratified.jsonl \ 64 | # --tokenizer gpt2-xl --output-type canonicalUtterance \ 65 | # --max-unreachable 3 66 | (CalflowOutputLanguage.Canonical, ClientType.GPT3): (8, 1.7233333333), 67 | (CalflowOutputLanguage.Canonical, ClientType.SMGPT3): (8, 1.7233333333), 68 | # python semantic_parsing_with_constrained_lm/scripts/calflow_fit_max_steps.py \ 69 | # --data-path \ 70 | # semantic_parsing_with_constrained_lm/domains/calflow/data/train_300_stratified.jsonl \ 71 | # --tokenizer facebook/bart-large --output-type lispress \ 72 | # --max-unreachable 3 73 | (CalflowOutputLanguage.Lispress, ClientType.BART): (65, 7.084487179487172), 74 | # python semantic_parsing_with_constrained_lm/scripts/calflow_fit_max_steps.py \ 75 | # --data-path \ 76 | # semantic_parsing_with_constrained_lm/domains/calflow/data/train_300_stratified.jsonl \ 77 | # --tokenizer gpt2-xl --output-type lispress --max-unreachable 3 78 | (CalflowOutputLanguage.Lispress, ClientType.GPT3): (65, 7.084487179487172), 79 | (CalflowOutputLanguage.Lispress, ClientType.SMGPT3): (65, 7.084487179487172), 80 | } 81 | 82 | 83 | def get_calflow_max_steps_fn( 84 | output_type: CalflowOutputLanguage, 85 | client_type: ClientType, 86 | tokenizer: PreTrainedTokenizer, 87 | ) -> Callable[[Datum], Optional[int]]: 88 | max_steps_intercept, max_steps_slope = calflow_max_steps_fn_params[ 89 | output_type, client_type 90 | ] 91 | 92 | def fn(datum: Datum) -> Optional[int]: 93 | return min( 94 | int( 95 | len(tokenizer.tokenize(datum.natural)) * max_steps_slope 96 | + max_steps_intercept 97 | ), 98 | MAX_STEPS_FOR_COMPLETION, 99 | ) 100 | 101 | return fn 102 | 103 | 104 | def make_semantic_parser_for_calflow( 105 | train_data: List[CalflowDatum], 106 | lm: AutoregressiveModel, 107 | use_gpt3: bool, 108 | beam_size: int, 109 | output_type: CalflowOutputLanguage, 110 | client_type: ClientType, 111 | preprocessed_grammar: PreprocessedGrammar, 112 | constrained: bool, 113 | prompt_order: PromptOrder = PromptOrder.BestLast, 114 | # Settings for using autoregressive models in a few-shot in-context setting 115 | similarity_lm: Optional[IncrementalLanguageModel] = None, 116 | prompt_builder: Optional[PromptBuilder] = None, 117 | num_examples_per_prompt: int = 20, 118 | problem_factory_builder: Optional[Callable[[DecodingSetup], ProblemFactory]] = None, 119 | ) -> BeamSearchSemanticParser: 120 | if constrained: 121 | grammar_tokenizer_info = GrammarTokenizerInfo.create( 122 | lm.tokenizer, 123 | preprocessed_grammar, 124 | output_type == CalflowOutputLanguage.Lispress, 125 | ) 126 | # TODO: Refer to `lm` to decide whether to use UTF8EarleyPartialParse or a different variant 127 | partial_parse_builder = lambda datum: UTF8EarleyPartialParse.initial( 128 | grammar_tokenizer_info, datum.natural 129 | ) 130 | else: 131 | # TODO: Only impose this if we are 
using a GPT-2-style tokenizer 132 | partial_parse = StartsWithSpacePartialParse(lm.tokenizer) 133 | partial_parse_builder = lambda _: partial_parse 134 | 135 | max_steps_fn = get_calflow_max_steps_fn(output_type, client_type, lm.tokenizer) 136 | 137 | return make_semantic_parser( 138 | train_data, 139 | lm, 140 | use_gpt3, 141 | MAX_STEPS_FOR_COMPLETION, 142 | beam_size, 143 | partial_parse_builder, 144 | max_steps_fn=max_steps_fn, 145 | prompt_order=prompt_order, 146 | similarity_method=SeparateLM(similarity_lm=similarity_lm), # type: ignore 147 | prompt_builder=prompt_builder, 148 | num_examples_per_prompt=num_examples_per_prompt, 149 | problem_factory_builder=problem_factory_builder, 150 | ) 151 | 152 | 153 | cached_read_calflow_jsonl = functools.lru_cache(maxsize=None)(read_calflow_jsonl) 154 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/decoding/uint8_earley_partial_parse.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | """Constrained decoding for Earley grammars where terminals are np.uint8. 5 | 6 | UInt8EarleyPartialParse is similar to EarleyPartialParse, except that it only 7 | works for grammars where all terminals are np.uint8. 8 | It also currently lacks support for input utterance copying constraints. 9 | 10 | TODO: We can generalize UInt8EarleyPartialParse to "Atomic"EarleyPartialParse by 11 | replacing np.uint8 with a type parameter. 12 | """ 13 | 14 | import dataclasses 15 | import itertools 16 | from dataclasses import dataclass 17 | from typing import Mapping # pylint: disable=unused-import 18 | from typing import ( 19 | Any, 20 | Callable, 21 | Dict, 22 | Iterable, 23 | Iterator, 24 | Optional, 25 | Sequence, 26 | Tuple, 27 | TypeVar, 28 | ) 29 | 30 | import numpy as np 31 | import torch 32 | from cached_property import cached_property 33 | 34 | from semantic_parsing_with_constrained_lm.earley.earley import EarleyChart 35 | from semantic_parsing_with_constrained_lm.earley.grammar import Grammar 36 | from semantic_parsing_with_constrained_lm.earley.input import Position, SigmaStarTriePosition 37 | from semantic_parsing_with_constrained_lm.decoding.partial_parse import PartialParse 38 | from semantic_parsing_with_constrained_lm.tokenization import ClampTokenizer 39 | 40 | T = TypeVar("T") 41 | 42 | 43 | def get_only(items: Iterable[T]) -> T: 44 | """Returns the single value in `items`. 45 | 46 | This function raises an exception if items contains 0 or 2+ items. 47 | """ 48 | [item] = items 49 | return item 50 | 51 | 52 | @dataclass 53 | class UInt8GrammarNode: 54 | chart: EarleyChart[np.uint8, Any] 55 | lazy_pos: Callable[[], Position[np.uint8]] 56 | 57 | @cached_property 58 | def pos(self) -> Position[np.uint8]: 59 | return self.lazy_pos() 60 | 61 | @cached_property 62 | def children(self) -> "Mapping[np.uint8, UInt8GrammarNode]": 63 | # TODO: Consider using a list instead. 
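        # Each uint8 terminal that the Earley chart can consume from `self.pos`
        # maps to a lazily-built child node; the `terminal=terminal, items=items`
        # default arguments bind the loop variables so every lambda captures its
        # own (terminal, items) pair rather than the last iteration's values.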
64 | return { 65 | terminal: UInt8GrammarNode( 66 | self.chart, 67 | lambda terminal=terminal, items=items: get_only( 68 | self.chart.advance_with_terminal(self.pos, terminal, items) 69 | ), 70 | ) 71 | for terminal, items in self.chart.advance_only_nonterminals( 72 | self.pos, unpop_terminals=False 73 | ).items() 74 | } 75 | 76 | def advance(self, seq: Sequence[np.uint8]) -> "Optional[UInt8GrammarNode]": 77 | result = self 78 | for byte in seq: 79 | result = result.children.get(byte) 80 | if result is None: 81 | break 82 | return result 83 | 84 | 85 | @dataclass 86 | class UInt8GrammarTokenizerInfo: 87 | grammar: Grammar[np.uint8, Any] 88 | tokens: Sequence[Sequence[np.uint8]] 89 | 90 | @cached_property 91 | def vocab_size(self) -> int: 92 | return len(self.tokens) 93 | 94 | @staticmethod 95 | def from_clamp_tokenizer( 96 | grammar: Grammar[np.uint8, Any], tokenizer: ClampTokenizer 97 | ) -> "UInt8GrammarTokenizerInfo": 98 | encoded_tokens = UInt8GrammarTokenizerInfo.prepare_tokens_from_clamp_tokenizer( 99 | tokenizer 100 | ) 101 | return UInt8GrammarTokenizerInfo( 102 | grammar, 103 | encoded_tokens, 104 | ) 105 | 106 | @staticmethod 107 | def prepare_tokens_from_clamp_tokenizer( 108 | tokenizer: ClampTokenizer, 109 | ) -> Sequence[Sequence[np.uint8]]: 110 | return [ 111 | np.frombuffer(tokenizer.id_to_utf8_token_map[i], dtype=np.uint8) 112 | for i in range(len(tokenizer.id_to_utf8_token_map)) 113 | ] 114 | 115 | 116 | @dataclass 117 | class UInt8EarleyPartialParse(PartialParse): 118 | grammar_node: UInt8GrammarNode 119 | info: UInt8GrammarTokenizerInfo 120 | start_pos: Position[np.uint8] 121 | _next_node_cache: Dict[int, Optional[UInt8GrammarNode]] = dataclasses.field( 122 | default_factory=dict 123 | ) 124 | 125 | def allowed_next( 126 | self, ordered_ids: Optional[torch.Tensor] = None, top_k: Optional[int] = None 127 | ) -> Tuple[Optional[torch.Tensor], bool]: 128 | # TODO: Use optimizations already in EarleyPartialParse, and others identified but not implemented: 129 | # - Only check the first N tokens from ordered_ids with `token_id_is_valid`; 130 | # for the rest, intersect grammar_node with the vocab trie 131 | # - Cross-beam pruning: https://semanticmachines.slack.com/archives/C0310DTKR6J/p1644449621986019 132 | assert ordered_ids is not None 133 | ordered_ids_list = ordered_ids.tolist() 134 | all_tokens = self.info.tokens 135 | vocab_size = self.info.vocab_size 136 | node = self.grammar_node 137 | 138 | def token_id_is_valid(i: int) -> bool: 139 | if not 0 <= i < vocab_size: 140 | return False 141 | next_node = node.advance(all_tokens[i]) 142 | self._next_node_cache[i] = next_node 143 | return next_node is not None 144 | 145 | def produce_valid_tokens() -> Iterator[int]: 146 | for i in ordered_ids_list: 147 | if token_id_is_valid(i): 148 | yield i 149 | 150 | # TODO: Add special case where grammar_node.children has no elements 151 | # (i.e. 
tokens_list will be empty) 152 | tokens_list = list(itertools.islice(produce_valid_tokens(), top_k)) 153 | can_end = self.grammar_node.chart.was_found( 154 | self.grammar_node.chart.grammar.root, self.start_pos, self.grammar_node.pos 155 | ) 156 | return torch.tensor(tokens_list, dtype=torch.long), can_end 157 | 158 | def append(self, token: int) -> "UInt8EarleyPartialParse": 159 | """Return a new PartialParse created by appending this token.""" 160 | if token in self._next_node_cache: 161 | node_for_result = self._next_node_cache[token] 162 | else: 163 | if not 0 <= token < self.info.vocab_size: 164 | raise ValueError("token was not in the vocabulary") 165 | node_for_result = self.grammar_node.advance(self.info.tokens[token]) 166 | 167 | if node_for_result is None: 168 | raise ValueError("invalid token to continue with") 169 | 170 | return UInt8EarleyPartialParse(node_for_result, self.info, self.start_pos) 171 | 172 | @staticmethod 173 | def initial(info: UInt8GrammarTokenizerInfo) -> "UInt8EarleyPartialParse": 174 | chart = EarleyChart(info.grammar, use_backpointers=False) 175 | start_pos = SigmaStarTriePosition[np.uint8]() 176 | chart.seek(info.grammar.root, start_pos) 177 | grammar_node = UInt8GrammarNode(chart, lambda: start_pos) 178 | return UInt8EarleyPartialParse(grammar_node, info, start_pos) 179 | -------------------------------------------------------------------------------- /run/semantic_parsing_with_constrained_lm/domains/overnight/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import json 5 | from dataclasses import dataclass 6 | from enum import Enum 7 | from typing import Callable, Dict, List 8 | 9 | from blobfile import BlobFile 10 | 11 | from semantic_parsing_with_constrained_lm.util.trie import Trie 12 | from semantic_parsing_with_constrained_lm.util.types import StrPath 13 | from semantic_parsing_with_constrained_lm.datum import Datum, FullDatum 14 | from semantic_parsing_with_constrained_lm.decoding.trie_partial_parse import TriePartialParse 15 | # from semantic_parsing_with_constrained_lm.domains.calflow.write_data import CACHE_DIR 16 | from semantic_parsing_with_constrained_lm.eval import TopKExactMatch 17 | from semantic_parsing_with_constrained_lm.tokenization import ClampTokenizer 18 | 19 | # NOTE: get rid of the catflow/dataflow dependency 20 | from appdirs import user_cache_dir 21 | CACHE_DIR = user_cache_dir("semantic_parsing_as_constrained_lm") 22 | 23 | 24 | class OutputType(str, Enum): 25 | Utterance = "utterance" 26 | MeaningRepresentation = "meaningRepresentation" 27 | 28 | 29 | @dataclass 30 | class TopKDenotationMatch(TopKExactMatch[FullDatum]): 31 | canonical_to_denotation: Dict[str, str] 32 | 33 | def _is_correct(self, pred: str, datum: FullDatum) -> bool: 34 | target = datum.canonical 35 | pred_denotation = self.canonical_to_denotation.get(pred) 36 | target_denotation = self.canonical_to_denotation.get(target, None) 37 | if pred_denotation is None and target_denotation is None: 38 | return pred == target 39 | else: 40 | return pred_denotation == target_denotation 41 | 42 | 43 | @dataclass 44 | class OvernightPieces: 45 | train_data: List[FullDatum] 46 | test_data: List[FullDatum] 47 | partial_parse_builder: Callable[[Datum], TriePartialParse] 48 | denotation_metric: TopKDenotationMatch 49 | max_length: int 50 | 51 | @staticmethod 52 | def from_dir( 53 | tokenizer: ClampTokenizer, 54 | root_dir: StrPath, 55 | domain: 
str, 56 | is_dev: bool, 57 | k: int, 58 | output_type: OutputType = OutputType.Utterance, 59 | simplify_logical_forms=False, 60 | prefix_with_space=False, 61 | ) -> "OvernightPieces": 62 | data_pieces = OvernightDataPieces.from_dir( 63 | root_dir, domain, is_dev, output_type, simplify_logical_forms 64 | ) 65 | decoder_pieces = OvernightDecoderPieces.create( 66 | data_pieces, tokenizer, k, prefix_with_space 67 | ) 68 | 69 | return OvernightPieces( 70 | data_pieces.train_data, 71 | data_pieces.test_data, 72 | # https://github.com/python/mypy/issues/5485 73 | decoder_pieces.partial_parse_builder, # type: ignore 74 | decoder_pieces.denotation_metric, 75 | decoder_pieces.max_length, 76 | ) 77 | 78 | 79 | @dataclass 80 | class OvernightDataPieces: 81 | train_data: List[FullDatum] 82 | test_data: List[FullDatum] 83 | target_output_to_denotation: Dict[str, str] 84 | 85 | @staticmethod 86 | def from_dir( 87 | root_dir: StrPath, 88 | domain: str, 89 | is_dev: bool, 90 | output_type: OutputType = OutputType.MeaningRepresentation, 91 | simplify_logical_forms: bool = False, 92 | ) -> "OvernightDataPieces": 93 | # TODO make this configurable? 94 | with BlobFile(str(root_dir) + f"/{domain}.canonical.json") as bf: 95 | canonical_data = json.load(bf) 96 | 97 | if output_type == OutputType.Utterance: 98 | target_output_to_denotation = { 99 | k: v["denotation"] for k, v in canonical_data.items() 100 | } 101 | datum_key = "canonical" 102 | elif output_type == OutputType.MeaningRepresentation: 103 | target_output_to_denotation = {} 104 | for program_info in canonical_data.values(): 105 | formula = program_info["formula"] 106 | if formula is None: 107 | continue 108 | if simplify_logical_forms: 109 | formula = OvernightDataPieces.simplify_lf(formula) 110 | assert formula not in target_output_to_denotation 111 | target_output_to_denotation[formula] = program_info["denotation"] 112 | datum_key = "formula" 113 | else: 114 | raise ValueError(output_type) 115 | 116 | train_data, test_data = [ 117 | [ 118 | FullDatum( 119 | dialogue_id=f"{dataset_name}-{i}", 120 | turn_part_index=None, 121 | agent_context=None, 122 | natural=d["natural"], 123 | canonical=OvernightDataPieces.simplify_lf(d[datum_key]) 124 | if simplify_logical_forms 125 | else d[datum_key], 126 | ) 127 | for i, line in enumerate( 128 | BlobFile(path, streaming=False, cache_dir=CACHE_DIR) 129 | ) 130 | for d in [json.loads(line)] 131 | ] 132 | for dataset_name, path in ( 133 | ( 134 | "train", 135 | f"{root_dir}/{domain}.train_with{'out' if is_dev else ''}_dev.jsonl", 136 | ), 137 | ("eval", f"{root_dir}/{domain}.{'dev' if is_dev else 'test'}.jsonl"), 138 | ) 139 | ] 140 | 141 | return OvernightDataPieces(train_data, test_data, target_output_to_denotation) 142 | 143 | @staticmethod 144 | def simplify_lf(lf: str) -> str: 145 | return lf.replace("edu.stanford.nlp.sempre.overnight.SimpleWorld.", "") 146 | 147 | 148 | @dataclass 149 | class OvernightDecoderPieces: 150 | data_pieces: OvernightDataPieces 151 | partial_parse_builder: Callable[[Datum], TriePartialParse] 152 | denotation_metric: TopKDenotationMatch 153 | max_length: int 154 | 155 | @staticmethod 156 | def create( 157 | data_pieces: OvernightDataPieces, 158 | tokenizer: ClampTokenizer, 159 | k: int, 160 | prefix_with_space: bool = False, 161 | ) -> "OvernightDecoderPieces": 162 | if prefix_with_space: 163 | canonical_trie = Trie( 164 | tokenizer.encode(" " + canon) 165 | for canon in data_pieces.target_output_to_denotation 166 | ) 167 | else: 168 | canonical_trie = Trie( 169 | 
tokenizer.encode(canon) 170 | for canon in data_pieces.target_output_to_denotation 171 | ) 172 | partial_parse_builder = lambda _: TriePartialParse(canonical_trie) 173 | 174 | denotation_metric = TopKDenotationMatch( 175 | k, data_pieces.target_output_to_denotation 176 | ) 177 | max_length = max(len(x) for x in canonical_trie) 178 | 179 | return OvernightDecoderPieces( 180 | data_pieces, partial_parse_builder, denotation_metric, max_length 181 | ) 182 | --------------------------------------------------------------------------------
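# Minimal end-to-end sketch of wiring the pieces above together: tokenizer ->
# OvernightPieces -> constrained-decoding trie. The data directory and domain
# name are hypothetical placeholders, not files that ship with this repository.
from transformers import GPT2Tokenizer

from semantic_parsing_with_constrained_lm.tokenization import GPT2ClampTokenizer
from semantic_parsing_with_constrained_lm.domains.overnight import OutputType, OvernightPieces

tokenizer = GPT2ClampTokenizer(GPT2Tokenizer.from_pretrained("gpt2"))
pieces = OvernightPieces.from_dir(
    tokenizer,
    root_dir="data/overnight",    # hypothetical location of {domain}.canonical.json / *.jsonl
    domain="calendar",            # hypothetical domain name
    is_dev=True,
    k=1,
    output_type=OutputType.MeaningRepresentation,
    simplify_logical_forms=True,
)
# `partial_parse_builder` yields a TriePartialParse that restricts decoding to
# known canonical programs; `denotation_metric` scores predictions by denotation.
partial_parse = pieces.partial_parse_builder(pieces.test_data[0])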