├── .github └── workflows │ └── lint_code_and_run_tests.yml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── README-LISPRESS-1.0.md ├── README-LISPRESS.md ├── README-SEMANTICS.md ├── README.md ├── SECURITY.md ├── datasets ├── README.md ├── SMCalFlow-1.0 │ ├── smcalflow-v1.0.full.data.tgz │ └── smcalflow-v1.0.inlined.data.tgz ├── SMCalFlow-2.0 │ ├── train.dataflow_dialogues.jsonl.zip │ └── valid.dataflow_dialogues.jsonl.zip └── TreeDST │ ├── dev_dst.dataflow_dialogues.jsonl.zip │ ├── test_dst.dataflow_dialogues.jsonl.zip │ └── train_dst.dataflow_dialogues.jsonl.zip ├── models ├── .gitattributes ├── multiwoz.full.checkpoint_last.pt ├── multiwoz.inline_both.checkpoint_last.pt ├── multiwoz.inline_refer.checkpoint_last.pt ├── smcalflow.full.checkpoint_last.pt └── smcalflow.inlined.checkpoint_last.pt ├── mypy.ini ├── mypy_stubs └── cached_property.pyi ├── pylintrc ├── requirements-dev.txt ├── scripts ├── evaluate.sh ├── generate_leaderboard_format.sh └── sm_model_predict.sh ├── setup.py ├── src └── dataflow │ ├── __init__.py │ ├── analysis │ ├── __init__.py │ ├── calculate_statistical_significance.py │ └── compute_data_statistics.py │ ├── core │ ├── __init__.py │ ├── constants.py │ ├── definition.py │ ├── dialogue.py │ ├── io_utils.py │ ├── linearize.py │ ├── lispress.py │ ├── prediction_report.py │ ├── program.py │ ├── program_utils.py │ ├── sexp.py │ ├── turn_prediction.py │ ├── type_inference.py │ ├── utterance_tokenizer.py │ └── utterance_utils.py │ ├── leaderboard │ ├── __init__.py │ ├── create_leaderboard_data.py │ ├── create_text_data.py │ ├── evaluate.py │ └── predict.py │ ├── multiwoz │ ├── __init__.py │ ├── belief_state_tracker_datum.py │ ├── create_belief_state_prediction_report.py │ ├── create_belief_state_tracker_data.py │ ├── create_programs.py │ ├── evaluate_belief_state_predictions.py │ ├── execute_programs.py │ ├── ontology.py │ ├── patch_trade_dialogues.py │ ├── salience_model.py │ ├── trade_dst │ │ ├── __init__.py │ │ ├── create_data.py │ │ └── mapping.pair │ └── trade_dst_utils.py │ ├── onmt_helpers │ ├── __init__.py │ ├── compute_onmt_data_stats.py │ ├── create_onmt_prediction_report.py │ ├── create_onmt_text_data.py │ ├── embeddings_to_torch.py │ └── evaluate_onmt_predictions.py │ └── py.typed └── tests ├── __init__.py ├── conftest.py ├── data └── multiwoz_2_1 │ ├── MUL1626.json │ ├── MUL2096.json │ ├── MUL2199.json │ ├── MUL2258.json │ ├── PMUL3166.json │ ├── PMUL3470.json │ ├── PMUL4478.json │ └── README.md └── test_dataflow ├── __init__.py ├── core ├── __init__.py ├── test_linearize.py ├── test_lispress.py ├── test_type_inference.py └── test_utterance_tokenizer.py └── multiwoz ├── __init__.py ├── conftest.py ├── test_cli_workflow.py ├── test_create_programs.py ├── test_evaluate_belief_state_predictions.py ├── test_execute_programs.py └── test_ontology.py /.github/workflows/lint_code_and_run_tests.yml: -------------------------------------------------------------------------------- 1 | name: Lint Code and Run Tests 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | build: 11 | 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v2 16 | - name: Set up Python 3.7 17 | uses: actions/setup-python@v2 18 | with: 19 | python-version: 3.7 20 | - name: Install Python dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install -e . 
24 | pip install -r requirements-dev.txt 25 | - name: Download spaCy model 26 | run: python -m spacy download en_core_web_md-2.2.0 --direct 27 | - name: Lint code 28 | run: make 29 | env: 30 | PYTHONPATH: tests/ 31 | - name: Run tests 32 | run: python -m pytest -n auto --durations=0 tests/ 33 | env: 34 | PYTHONPATH: tests/ 35 | 36 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .mypy_cache/ 2 | .pytest_cache/ 3 | __pycache__/ 4 | *.egg-info/ 5 | 6 | /env/ 7 | /workspace 8 | 9 | # IntelliJ specific 10 | .idea/* 11 | .idea_modules 12 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 4 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 5 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 6 | 7 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 8 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 9 | provided by the bot. You will only need to do this once across all repos using our CLA. 10 | 11 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 12 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 13 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 14 | 15 | 16 | ## Requirements 17 | * [Python](https://www.python.org/downloads/) >= 3.7 18 | * [virtualenv](https://virtualenv.pypa.io/en/latest/) >= 16.4.3 19 | * [GNU Make](https://www.gnu.org/software/make/) >= 3.81 20 | 21 | ## Build 22 | * Create a new sandbox (aka venv) and install the required Python libraries into it 23 | ```bash 24 | virtualenv --python=python3.7 venv 25 | source venv/bin/activate 26 | 27 | # Install sm-dataflow as an editable package 28 | pip install -e . 29 | 30 | # Install additional Python packages for development 31 | pip install -r requirements-dev.txt 32 | 33 | # Download the spaCy model for tokenization 34 | python -m spacy download en_core_web_md-2.2.0 --direct 35 | 36 | # Add tests/ to PYTHONPATH 37 | export PYTHONPATH="./tests:$PYTHONPATH" 38 | ``` 39 | * Run `make test` to execute all tests. 40 | 41 | ## IntelliJ/PyCharm Setup 42 | * Set up the Python interpreter 43 | - For PyCharm, go to `Settings -> Project Settings -> Project Interpreter`.
44 | - For IntelliJ, go to `File -> Project Structure -> Project -> Project SDK`. 45 | - Add a `Virtualenv Environment` from an `Existing environment` and set the Interpreter to `YOUR_REPO_ROOT/venv/bin/python`. 46 | - Configure `pytest` 47 | * In `Preferences -> Tools -> Python Integration Tools`, set the default test runner to `pytest`. 48 | - Set up source folders so that `pytest` in the IDE becomes aware of the source code. 49 | * Right-click the `src` folder and choose `Mark Directory As -> Sources Root`. 50 | * Right-click the `tests` folder and choose `Mark Directory As -> Test Sources Root`. 51 | 52 | ## Pull Requests 53 | We enforce the following checks before a pull request can be merged into master: 54 | [isort](https://pypi.org/project/isort/), 55 | [black](https://black.readthedocs.io/en/stable/), 56 | [pylint](https://www.pylint.org/), 57 | [mypy](http://mypy-lang.org/), 58 | and [pytest](https://docs.pytest.org/en/latest/). 59 | 60 | * You can run `make test` to automatically execute these checks. 61 | * You can run `make` to execute all checks except `pytest`. 62 | * To fix any formatting or import errors, you can simply run `make format`. 63 | * To fix `pylint` and `mypy` errors, you can skip the prerequisite tasks by manually running `pylint src/ tests/` and `mypy src/ tests/`. 64 | * To fix `pytest` errors, you can skip the prerequisite tasks by manually running `python -m pytest tests/`. 65 | You can also go to the test file and run a specific test (see https://www.jetbrains.com/help/pycharm/pytest.html#run-pytest-test). 66 | * For more details, see the `test` task in the [Makefile](./Makefile) 67 | and the GitHub action [lint_code_and_run_tests.yml](.github/workflows/lint_code_and_run_tests.yml). 68 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: all format format-check pylint mypy test 2 | 3 | all: mypy 4 | 5 | # sort imports and auto-format python code 6 | format: 7 | isort -rc src/ tests/ --multi-line=3 --trailing-comma --force-grid-wrap=0 --use-parentheses --line-width=88 -o onmt -o torch 8 | black -t py37 src/ tests/ 9 | 10 | format-check: 11 | (isort -rc src/ tests/ --check-only --multi-line=3 --trailing-comma --force-grid-wrap=0 --use-parentheses --line-width=88 -o onmt -o torch) && (black -t py37 --check src/ tests/) || (echo "run \"make format\" to format the code"; exit 1) 12 | 13 | pylint: format-check 14 | pylint -j0 src/ tests/ 15 | 16 | mypy: pylint 17 | mypy --show-error-codes src/ tests/ 18 | 19 | test: mypy $(shell find tests/ -name "*.py" -type f) 20 | python -m pytest -n auto --durations=0 tests/ 21 | -------------------------------------------------------------------------------- /README-LISPRESS-1.0.md: -------------------------------------------------------------------------------- 1 | This is an outdated description of Lispress ("Lispress 1.0"), left here to document the SMCalFlow 2 | 1.x datasets. For the more current description of Lispress, 3 | see [this README](README-LISPRESS.md). 4 | 5 | # Lispress 6 | 7 | *Lispress* is a lisp-like serialization format for programs. 8 | It is intended to be human-readable, easy to work with in Python, and easy to 9 | tokenize and predict with a standard seq2seq model. 10 | 11 | 12 | Here is an example program in Lispress (a response to the utterance 13 | `"what is my appointment with janice kang"`): 14 | ```clojure 15 | (yield 16 | (:id 17 | (singleton 18 | (:results 19 | (FindEventWrapperWithDefaults 20 | :constraint (StructConstraint[Event] 21 | :attendees (AttendeeListHasRecipientConstraint 22 | :recipientConstraint (RecipientWithNameLike 23 | :constraint (StructConstraint[Recipient]) 24 | :name #(PersonName "janice kang"))))))))) 25 | ``` 26 | 27 | 28 | ## Syntax 29 | 30 | A Lispress program is an s-expression: either 31 | a bare symbol, or 32 | a whitespace-separated list of s-expressions, surrounded by parentheses. 33 | 34 | ### Values 35 | 36 | Value literals are represented with a hash character followed by an 37 | s-expression containing the name of the schema (i.e. type) of the data, followed by a 38 | JSON-encoded string literal of the data surrounded by double-quotes. 39 | For example: `#(PersonName "janice kang")`. 40 | A `Number` may omit the double-quotes, e.g. `#(Number 4)`. 41 | 42 | ### Function application 43 | 44 | The most common form in Lispress is a function applied to zero or more 45 | arguments. 46 | Function application expressions are lists, 47 | with the first element of the list denoting the function, 48 | and the remainder of the elements denoting its arguments. 49 | There are two kinds of function application: 50 | 51 | #### Named arguments 52 | If the name of a function begins with an uppercase letter (`[A-Z]`), 53 | then it accepts named arguments (and only named arguments).
54 | The name of each named argument is prefixed with a colon character, 55 | and named arguments are written after the function as alternating 56 | `:name value` pairs. 57 | Named arguments can be given in any order (when rendering, we alphabetize named arguments). 58 | 59 | For example, in 60 | ```clojure 61 | (DateAtTimeWithDefaults 62 | :date (Tomorrow) 63 | :time (NumberAM :number #(Number 10))) 64 | ``` 65 | the `DateAtTimeWithDefaults` function is applied to two named arguments. 66 | `(Tomorrow)` is passed to the function as the `date` argument, and 67 | `(NumberAM :number #(Number 10))` is passed in as the `time` argument. 68 | `(Tomorrow)` is an example of a function applied to zero named arguments. 69 | Some functions accepting named arguments may not require all arguments to be present. 70 | You will often see the `StructConstraint[Event]` function being called without 71 | a `:subject` or an `:end`, for example. 72 | 73 | #### Positional arguments 74 | If the name of a function does not begin with an uppercase letter 75 | (i.e. it is lowercase or symbolic), then it accepts positional 76 | arguments (and only positional arguments). 77 | For example, 78 | ```clojure 79 | (?= #(String "soccer game")) 80 | ``` 81 | represents the function `?=` being 82 | applied to the single argument `#(String "soccer game")`. 83 | And `(toDays #(Number 10))` is the function `toDays` applied to the single 84 | argument `#(Number 10)`. 85 | 86 | 87 | ### Sugared `get` 88 | 89 | There is a common construct in our programs where the `get` function 90 | retrieves a field (specified by a `Path`) from a structured object. 91 | For example, 92 | ```clojure 93 | (get 94 | (refer (StructConstraint[Event])) 95 | #(Path "attendees")) 96 | ``` 97 | returns the `attendees` field of the salient `Event`. 98 | When the path is a valid identifier (i.e. contains no whitespace or special 99 | characters), the following sugared version is equivalent and preferred: 100 | ```clojure 101 | (:attendees 102 | (refer (StructConstraint[Event]))) 103 | ``` 104 | 105 | 106 | 107 | 108 | ### Variable binding with `let` 109 | 110 | To use a value more than once, it can be given a variable name using a `let` 111 | binding. 112 | A `let` binding is a list with three elements, 113 | - the keyword `let`, 114 | - a "binding" list containing alternating `variableName variableValue` pairs, and 115 | - a program body, in which variable names bound in the previous form can be 116 | referenced. 117 | 118 | For example, in the following response to `"Can you find some past events on my calendar?"`, 119 | ```clojure 120 | (let 121 | (x0 (Now)) 122 | (yield 123 | (FindEventWrapperWithDefaults 124 | :constraint (EventOnDateBeforeTime 125 | :date (:date x0) 126 | :event (StructConstraint[Event]) 127 | :time (:time x0))))) 128 | ``` 129 | the variable `x0` is assigned the value `(Now)` and then used twice in the body. 130 | Note that `(Now)` is only evaluated once. 131 | `let` bindings are an important mechanism to reuse the result of a 132 | side-effecting computation.
133 | For example, depending on the implementation of `Now`, the 134 | following program may reference different values in the `:date` and `:time` fields: 135 | ```clojure 136 | (FindEventWrapperWithDefaults 137 | :constraint (EventOnDateBeforeTime 138 | :date (:date (Now)) 139 | :event (StructConstraint[Event]) 140 | :time (:time (Now)))) 141 | ``` 142 | 143 | ### Performing multiple actions in a turn with `do` 144 | 145 | Two or more statements can be sequenced using the `do` keyword. 146 | Each statement in a `do` form is fully interpreted and executed before any following 147 | statements are. 148 | In 149 | ```clojure 150 | (do 151 | (ConfirmAndReturnAction) 152 | (yield 153 | (:start 154 | (FindNumNextEvent 155 | :constraint (StructConstraint[Event]) 156 | :number #(Number 1))))) 157 | ``` 158 | for example, `ConfirmAndReturnAction` is guaranteed to execute before `FindNumNextEvent`. 159 | 160 | 161 | 162 | 163 | ## Code 164 | 165 | Code for parsing and rendering Lispress is in the `dataflow.core.lispress` 166 | package. 167 | 168 | `parse_lispress` converts a string into a `Lispress` object, which is a nested 169 | list-of-lists with `str`s as leaves. 170 | `render_compact` renders `Lispress` on a single line (used in our `jsonl` data 171 | files), and `render_pretty` renders with indentation, which is easier to read. 172 | 173 | `lispress_to_program` and `program_to_lispress` convert to and from a `Program` object, 174 | which is closer to a computation DAG (rather than an abstract syntax tree), and 175 | is sometimes more convenient to work with. 176 | -------------------------------------------------------------------------------- /README-LISPRESS.md: -------------------------------------------------------------------------------- 1 | # Lispress 2 | 3 | *Lispress* is a lisp-like serialization format for programs. 4 | It is intended to be human-readable, easy to work with in Python, and easy to 5 | tokenize and predict with a standard seq2seq model. An older version, Lispress 1.0, 6 | is described in [this README](README-LISPRESS-1.0.md). The current code is backwards 7 | compatible with Lispress 1.0 programs. 8 | 9 | 10 | Here is an example program in Lispress (a response to the utterance 11 | `"what is my appointment with janice kang"`): 12 | ```clojure 13 | (Yield 14 | (Event.id 15 | (singleton 16 | (QueryEventResponse.results 17 | (FindEventWrapperWithDefaults 18 | (Event.attendees_? 19 | (AttendeeListHasRecipientConstraint 20 | (RecipientWithNameLike 21 | (^(Recipient) EmptyStructConstraint) 22 | (PersonName.apply "janice kang"))))))))) 23 | ``` 24 | 25 | 26 | ## Syntax 27 | 28 | A Lispress program is an s-expression: either 29 | a bare symbol, or 30 | a whitespace-separated list of s-expressions, surrounded by parentheses. There is a little 31 | bit of special syntax: 32 | * Strings surrounded by double-quotes (`"`) are parsed as a single symbol 33 | (including the quotes), with standard JSON escaping for strings. For example, 34 | ```clojure 35 | (MyFunc "this is a (quoted) string with a \" in it") 36 | ``` 37 | will pass the symbol `"this is a (quoted) string with a \" in it"` to `MyFunc`. 38 | Note that when converting to a Program, we trim the whitespace from either side of a 39 | string, so `(MyFunc " a ")` and `(MyFunc "a")` are the same program. 40 | * The meta character (`^`) 41 | ([borrowed from Clojure](https://clojure.org/reference/metadata)) 42 | can be used for type ascriptions and type arguments.
For example, 43 | ```clojure 44 | ^Number 1 45 | ``` 46 | would be written as `1: Number` in Scala. A list marked by the meta character 47 | in the first argument of an s-expression is interpreted as a list of type arguments. 48 | For example, 49 | ```clojure 50 | (^(Number) MyFunc 1) 51 | ``` 52 | would be written as `MyFunc[Number](1)` in Scala or `MyFunc(1)` in Swift and Rust. 53 | * (Deprecated) The reader macro character (`#`), 54 | [borrowed from Common Lisp](https://gist.github.com/chaitanyagupta/9324402) 55 | marks literal values. 56 | For example, `#(PersonName "John")` marks a value of type `PersonName` with 57 | content `"John"`. Reader macros are no longer used in Lispress 2.0. Instead, 58 | standard literals like booleans, longs, numbers, and strings can be written directly, 59 | while wrapper types (like `PersonName`) feature an explicit call to a constructor 60 | like `PersonName.apply`. The current code will interpret Lispress 1.0 61 | `Number`s, `Boolean`s, and `String`s as their bare equivalents, so `#(String "foo")` and `"foo"` 62 | will be interpreted as the same program. Similarly, `#(Number 1)` and `1` will 63 | be interpreted as the same program, and `#(Boolean true)` and `true` are the same 64 | program. 65 | * Literals of type Long are written as an integer literal followed by an `L` (e.g. `12L`) 66 | as in Java/Scala. 67 | 68 | ### Function application 69 | 70 | The most common form in Lispress is a function applied to zero or more 71 | arguments. 72 | Function application expressions are lists, 73 | with the first element of the list denoting the function, 74 | and the remainder of the elements denoting its arguments. 75 | We follow Common Lisp and Clojure in using `:` to denote named arguments. For example, 76 | `(MyFunc :foo 1)` would be `MyFunc(foo = 1)` in Scala or Python. At present, functions 77 | must either be entirely positional or entirely named, and only functions with an 78 | uppercase letter for the first character may take named arguments. 79 | 80 | ### (Deprecated) Sugared `get` 81 | 82 | There is a common construct in the SMCalFlow 1.x dataset where the `get` function 83 | retrieves a field (specified by a `Path`) from a structured object. 84 | For example, 85 | ```clojure 86 | (get 87 | (refer (StructConstraint[Event])) 88 | #(Path "attendees")) 89 | ``` 90 | returns the `attendees` field of the salient `Event`. 91 | For backwards compatibility with Lispress 1.0, the parser will accept 92 | the following equivalent form: 93 | ```clojure 94 | (:attendees (refer (StructConstraint[Event]))) 95 | ``` 96 | 97 | In updated Lispress, accessor functions contain the name of the type they access: 98 | ```clojure 99 | (Event.attendees (refer (^(Event) StructConstraint))) 100 | ``` 101 | 102 | 103 | 104 | 105 | ### Variable binding with `let` 106 | 107 | To use a value more than once, it can be given a variable name using a `let` 108 | binding. 109 | A `let` binding is a list with three elements, 110 | - the keyword `let`, 111 | - a "binding" list containing alternating `variableName variableValue` pairs, and 112 | - a program body, in which variable names bound in the previous form can be 113 | referenced.
114 | 115 | For example, in the following response to `"Can you find some past events on my calendar?"`, 116 | ```clojure 117 | (let 118 | (x0 (Now)) 119 | (Yield 120 | (FindEventWrapperWithDefaults 121 | (EventOnDateBeforeTime 122 | (DateTime.date x0) 123 | (^(Event) EmptyStructConstraint) 124 | (DateTime.time x0))))) 125 | ``` 126 | the variable `x0` is assigned the value `(Now)` and then used twice in the body. 127 | Note that `(Now)` is only evaluated once. 128 | `let` bindings are an important mechanism to reuse the result of a 129 | side-effecting computation. 130 | For example, depending on the implementation of `Now`, the 131 | following program may produce different values in the date and time fields: 132 | ```clojure 133 | (FindEventWrapperWithDefaults 134 | (EventOnDateBeforeTime 135 | (DateTime.date (Now)) 136 | (^(Event) EmptyStructConstraint) 137 | (DateTime.time (Now)))) 138 | ``` 139 | 140 | ### Performing multiple actions in a turn with `do` 141 | 142 | Two or more statements can be sequenced using the `do` keyword. 143 | Each statement in a `do` form is fully interpreted and executed before any following 144 | statements are. 145 | In 146 | ```clojure 147 | (do 148 | (ConfirmAndReturnAction) 149 | (Yield 150 | (Event.start 151 | (FindNumNextEvent 152 | (^(Event) StructConstraint) 153 | 1L)))) 154 | ``` 155 | for example, `ConfirmAndReturnAction` is guaranteed to execute before `FindNumNextEvent`. 156 | 157 | 158 | 159 | 160 | ## Code 161 | 162 | Code for parsing and rendering Lispress is in the `dataflow.core.lispress` 163 | package. 164 | 165 | `parse_lispress` converts a string into a `Lispress` object, which is a nested 166 | list-of-lists with `str`s as leaves. 167 | `render_compact` renders `Lispress` on a single line (used in our `jsonl` data 168 | files), and `render_pretty` renders with indentation, which is easier to read. 169 | 170 | `lispress_to_program` and `program_to_lispress` convert to and from a `Program` object, 171 | which is closer to a computation DAG (rather than an abstract syntax tree), and 172 | is sometimes more convenient to work with.
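As a small, self-contained sketch of this API (the program string below is just an illustrative example, and the exact whitespace of the compact rendering may differ):

```python
from dataflow.core.lispress import parse_lispress, render_compact, render_pretty

# Parse a Lispress program from its single-line string form.
lispress = parse_lispress(
    "(Yield (Event.start (FindNumNextEvent (^(Event) StructConstraint) 1L)))"
)

print(render_compact(lispress))  # one line, as in the jsonl data files
print(render_pretty(lispress))   # indented, easier to read
```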
173 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 40 | 41 | -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | # Datasets 2 | 3 | This directory contains the conversational semantic parsing datasets we used for the experiments of the following papers: 4 | 5 | ```bibtex 6 | @article{SMDataflow2020, 7 | author = {{Semantic Machines} and Andreas, Jacob and Bufe, John and Burkett, David and Chen, Charles and Clausman, Josh and Crawford, Jean and Crim, Kate and DeLoach, Jordan and Dorner, Leah and Eisner, Jason and Fang, Hao and Guo, Alan and Hall, David and Hayes, Kristin and Hill, Kellie and Ho, Diana and Iwaszuk, Wendy and Jha, Smriti and Klein, Dan and Krishnamurthy, Jayant and Lanman, Theo and Liang, Percy and Lin, Christopher H.
and Lintsbakh, Ilya and McGovern, Andy and Nisnevich, Aleksandr and Pauls, Adam and Petters, Dmitrij and Read, Brent and Roth, Dan and Roy, Subhro and Rusak, Jesse and Short, Beth and Slomin, Div and Snyder, Ben and Striplin, Stephon and Su, Yu and Tellman, Zachary and Thomson, Sam and Vorobev, Andrei and Witoszko, Izabela and Wolfe, Jason and Wray, Abby and Zhang, Yuchen and Zotov, Alexander}, 8 | title = {Task-Oriented Dialogue as Dataflow Synthesis}, 9 | journal = {Transactions of the Association for Computational Linguistics}, 10 | volume = {8}, 11 | pages = {556--571}, 12 | year = {2020}, 13 | month = sep, 14 | url = {https://doi.org/10.1162/tacl_a_00333}, 15 | } 16 | 17 | @inproceedings{SMValueAgnosticParsing2021, 18 | author = {Platanios, Emmanouil Antonios and Pauls, Adam and Roy, Subhro and Zhang, Yuchen and Kyte, Alex and Guo, Alan and Thomson, Sam and Krishnamurthy, Jayant and Wolfe, Jason and Andreas, Jacob and Klein, Dan}, 19 | title = {Value-Agnostic Conversational Semantic Parsing}, 20 | booktitle = {Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics}, 21 | month = aug, 22 | year = {2021}, 23 | address = {Online}, 24 | publisher = {Association for Computational Linguistics}, 25 | } 26 | ``` 27 | 28 | There are three datasets: 29 | 30 | - **SMCalFlow-1.0:** This is the dataset released with the [**Task-Oriented Dialogue as Dataflow Synthesis** (TACL 2020)](https://www.mitpressjournals.org/doi/full/10.1162/tacl_a_00333) paper. 31 | - **SMCalFlow-2.0:** This is an updated version of the dataset released with the [**Task-Oriented Dialogue as Dataflow Synthesis** (TACL 2020)](https://www.mitpressjournals.org/doi/full/10.1162/tacl_a_00333) paper, which removed a very small number of incorrectly annotated examples, dropped argument names for positional arguments (so that the programs are shorter), and added inferred type arguments for type-parameterized functions that were missing in the original SMCalFlow data. 32 | - **TreeDST:** This is a modified version of the [TreeDST dataset](https://github.com/apple/ml-tree-dst) which has been converted to the Lispress representation used for SMCalFlow 2.0, and transformed to make use of the `refer` and `revise` meta-computation operators. The transformation is described in the appendix of the paper referenced above. 33 | 34 | Furthermore, compared to the original release of the SMCalFlow dataset, these two datasets also provide programs which have been fully annotated with argument names for all function arguments and types for all expressions after running a Hindley-Milner based type inference algorithm (also described in the aforementioned paper). These programs are included in the new `fully_typed_lispress` field in the JSON objects that correspond to dialogue turns. It is not recommended to use these programs directly with simple Seq2Seq baselines because they are very verbose, and the additional information they contain can be derived directly from the `lispress` programs by running type inference. That is also why the `lispress` programs are the ones used by the official evaluation script in the SMCalFlow leaderboard. 35 | 36 | Note that the version uploaded before June 28, 2021 contained some minor errors. You should 37 | re-download the datasets if you downloaded them before that date.
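After unzipping, each `*.dataflow_dialogues.jsonl` file contains one JSON-serialized dialogue per line. As a minimal sketch (assuming this package has been installed with `pip install -e .` and the file below has been unzipped in place; the path is illustrative), the dialogues can be loaded with this repository's `Dialogue` dataclass and `load_jsonl_file` helper:

```python
from dataflow.core.dialogue import Dialogue
from dataflow.core.io_utils import load_jsonl_file

# Point this at any unzipped *.dataflow_dialogues.jsonl file.
dialogues = list(
    load_jsonl_file(
        data_jsonl="datasets/SMCalFlow-2.0/valid.dataflow_dialogues.jsonl",
        cls=Dialogue,
        unit=" dialogues",
    )
)
# Each dialogue has an ID and a list of turns; each turn carries a `lispress` program.
print(dialogues[0].dialogue_id, len(dialogues[0].turns))
```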
38 | -------------------------------------------------------------------------------- /datasets/SMCalFlow-1.0/smcalflow-v1.0.full.data.tgz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/task_oriented_dialogue_as_dataflow_synthesis/bbd16fe69687b1052a7b5937ee7d20b641f2e642/datasets/SMCalFlow-1.0/smcalflow-v1.0.full.data.tgz -------------------------------------------------------------------------------- /datasets/SMCalFlow-1.0/smcalflow-v1.0.inlined.data.tgz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/task_oriented_dialogue_as_dataflow_synthesis/bbd16fe69687b1052a7b5937ee7d20b641f2e642/datasets/SMCalFlow-1.0/smcalflow-v1.0.inlined.data.tgz -------------------------------------------------------------------------------- /datasets/SMCalFlow-2.0/train.dataflow_dialogues.jsonl.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/task_oriented_dialogue_as_dataflow_synthesis/bbd16fe69687b1052a7b5937ee7d20b641f2e642/datasets/SMCalFlow-2.0/train.dataflow_dialogues.jsonl.zip -------------------------------------------------------------------------------- /datasets/SMCalFlow-2.0/valid.dataflow_dialogues.jsonl.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/task_oriented_dialogue_as_dataflow_synthesis/bbd16fe69687b1052a7b5937ee7d20b641f2e642/datasets/SMCalFlow-2.0/valid.dataflow_dialogues.jsonl.zip -------------------------------------------------------------------------------- /datasets/TreeDST/dev_dst.dataflow_dialogues.jsonl.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/task_oriented_dialogue_as_dataflow_synthesis/bbd16fe69687b1052a7b5937ee7d20b641f2e642/datasets/TreeDST/dev_dst.dataflow_dialogues.jsonl.zip -------------------------------------------------------------------------------- /datasets/TreeDST/test_dst.dataflow_dialogues.jsonl.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/task_oriented_dialogue_as_dataflow_synthesis/bbd16fe69687b1052a7b5937ee7d20b641f2e642/datasets/TreeDST/test_dst.dataflow_dialogues.jsonl.zip -------------------------------------------------------------------------------- /datasets/TreeDST/train_dst.dataflow_dialogues.jsonl.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/task_oriented_dialogue_as_dataflow_synthesis/bbd16fe69687b1052a7b5937ee7d20b641f2e642/datasets/TreeDST/train_dst.dataflow_dialogues.jsonl.zip -------------------------------------------------------------------------------- /models/.gitattributes: -------------------------------------------------------------------------------- 1 | multiwoz.full.checkpoint_last.pt filter=lfs diff=lfs merge=lfs -text 2 | multiwoz.inline_both.checkpoint_last.pt filter=lfs diff=lfs merge=lfs -text 3 | multiwoz.inline_refer.checkpoint_last.pt filter=lfs diff=lfs merge=lfs -text 4 | smcalflow.full.checkpoint_last.pt filter=lfs diff=lfs merge=lfs -text 5 | smcalflow.inlined.checkpoint_last.pt filter=lfs diff=lfs merge=lfs -text 6 | -------------------------------------------------------------------------------- /models/multiwoz.full.checkpoint_last.pt: 
-------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:78bfc16db92af83c8b489bac236366eca4e2ac40d8b8ef5794e5b021fccb4907 3 | size 130057122 4 | -------------------------------------------------------------------------------- /models/multiwoz.inline_both.checkpoint_last.pt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:e515ce5d1e7fd274e682044db02f4780d7eb9de1fff34472f38b719b3c92f731 3 | size 111753901 4 | -------------------------------------------------------------------------------- /models/multiwoz.inline_refer.checkpoint_last.pt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:52c18edde879530d9cde12e7a58b10287f3042c98570ce81a4a6b2e34746bf78 3 | size 129128738 4 | -------------------------------------------------------------------------------- /models/smcalflow.full.checkpoint_last.pt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:3c0f2ce2e6f33c2111b99dd3f56a0e14c5f7d7c4b6a35f2171f84a159ec0fc38 3 | size 368655213 4 | -------------------------------------------------------------------------------- /models/smcalflow.inlined.checkpoint_last.pt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:39cf8a0bf9cef3f8aa9072c96f288980e4daa0e341ddc9d47dcf0513830bbd87 3 | size 470150160 4 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | python_version = 3.7 3 | incremental = False 4 | strict_optional = False 5 | mypy_path=./src/:./tests/:./mypy_stubs 6 | 7 | [mypy-pytest.*,_pytest.*,jsons.*,more_itertools.*,tqdm.*,glob2.*,sexpdata.*] 8 | ignore_missing_imports = True 9 | 10 | [mypy-spacy.*] 11 | ignore_missing_imports = True 12 | 13 | [mypy-matplotlib.*,numpy.*,pandas.*,sklearn.*,statsmodels.*] 14 | ignore_missing_imports = True 15 | 16 | [mypy-onmt.*,torch.*] 17 | ignore_missing_imports = True 18 | -------------------------------------------------------------------------------- /mypy_stubs/cached_property.pyi: -------------------------------------------------------------------------------- 1 | cached_property = property 2 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | black==19.10b0 2 | pylint==2.4.4 3 | mypy==0.761 4 | pytest==4.6.3 5 | pytest-xdist==1.31.0 6 | cached-property==1.5.1 7 | -------------------------------------------------------------------------------- /scripts/evaluate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Microsoft Corporation. 3 | # Licensed under the MIT license. 4 | # 5 | # Semantic Machines (TM) software. 6 | # 7 | # This script takes as input a gold answers file and a prediction file, and outputs the accuracy. 8 | 9 | gold_file=$1 10 | prediction_file=$2 11 | 12 | pip install --user . 
13 | export PATH=$PATH:.local/bin 14 | 15 | python -m dataflow.leaderboard.evaluate --predictions_jsonl ${prediction_file} --gold_jsonl ${gold_file} --scores_json scores.json 16 | 17 | rm -rf output .local .cache 18 | -------------------------------------------------------------------------------- /scripts/generate_leaderboard_format.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Microsoft Corporation. 3 | # Licensed under the MIT license. 4 | # 5 | # Semantic Machines (TM) software. 6 | # 7 | # This script converts the SMCalFlow native format to the format used by the leaderboard. 8 | # Only runs for valid and test. 9 | 10 | data_folder=$1 11 | output_folder=$2 12 | 13 | for subset in "valid" "test"; do 14 | python -m dataflow.leaderboard.create_leaderboard_data \ 15 | --dialogues_jsonl ${data_folder}/${subset}.dataflow_dialogues.jsonl \ 16 | --contextualized_turns_file ${output_folder}/${subset}.leaderboard_dialogues.jsonl \ 17 | --turn_answers_file ${output_folder}/${subset}.answers.jsonl \ 18 | --dialogue_id_prefix ${subset} 19 | done 20 | -------------------------------------------------------------------------------- /scripts/sm_model_predict.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Microsoft Corporation. 3 | # Licensed under the MIT license. 4 | # 5 | # Semantic Machines (TM) software. 6 | # 7 | # This script takes as input a model path and validation set, and generates predictions on the validation set in 8 | # a file named `predictions.jsonl`. 9 | 10 | model_path=$1 11 | data_path=$2 12 | 13 | pip install --user . 14 | pip install --user OpenNMT-py==1.0.0 15 | export PATH=$PATH:.local/bin 16 | 17 | # Prepare text data for the OpenNMT toolkit. 18 | onmt_text_data_dir="output/onmt_text_data" 19 | mkdir -p "${onmt_text_data_dir}" 20 | subset="valid" 21 | python -m dataflow.leaderboard.create_text_data \ 22 | --dialogues_jsonl ${data_path} \ 23 | --num_context_turns 2 \ 24 | --include_program \ 25 | --include_described_entities \ 26 | --onmt_text_data_outbase ${onmt_text_data_dir}/${subset} 27 | 28 | # Make predictions using a trained OpenNMT model. The model checkpoint is 29 | # provided via the `model_path` argument above. 30 | onmt_translate_outdir="output/onmt_translate_output" 31 | mkdir -p "${onmt_translate_outdir}" 32 | 33 | nbest=5 34 | tgt_max_ntokens=1000 35 | 36 | # predict programs using a trained OpenNMT model 37 | onmt_translate \ 38 | --model ${model_path} \ 39 | --max_length ${tgt_max_ntokens} \ 40 | --src ${onmt_text_data_dir}/valid.src_tok \ 41 | --replace_unk \ 42 | --n_best ${nbest} \ 43 | --batch_size 8 \ 44 | --beam_size 10 \ 45 | --gpu 0 \ 46 | --report_time \ 47 | --output ${onmt_translate_outdir}/valid.nbest 48 | 49 | # create the prediction report 50 | python -m dataflow.leaderboard.predict \ 51 | --datum_id_jsonl ${onmt_text_data_dir}/valid.datum_id \ 52 | --src_txt ${onmt_text_data_dir}/valid.src_tok \ 53 | --ref_txt ${onmt_text_data_dir}/valid.tgt \ 54 | --nbest_txt ${onmt_translate_outdir}/valid.nbest \ 55 | --nbest ${nbest} 56 | 57 | rm -rf output .local .cache 58 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Microsoft Corporation. 3 | # Licensed under the MIT license.
4 | from setuptools import find_packages, setup 5 | 6 | setup( 7 | name="sm-dataflow", 8 | version="0.1", 9 | author="Semantic Machines (TM)", 10 | description="Task-Oriented Dialogue as Dataflow Synthesis", 11 | license="MIT", 12 | packages=find_packages("src"), 13 | package_dir={"": "src"}, 14 | package_data={"dataflow": ["py.typed"]}, 15 | zip_safe=False, 16 | install_requires=[ 17 | "jsons==0.10.1", 18 | "more-itertools==8.2.0", 19 | "sexpdata==0.0.3", 20 | "pandas==1.0.0", 21 | "spacy==2.2.1", 22 | "statsmodels==0.11.1", 23 | "cached-property==1.5.1", 24 | ], 25 | extras_require={ 26 | "OpenNMT-py": ["OpenNMT-py==1.0.0", "torch>=1.2.0,<=1.4.0"] 27 | }, 28 | python_requires=">=3.7", 29 | ) 30 | -------------------------------------------------------------------------------- /src/dataflow/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | -------------------------------------------------------------------------------- /src/dataflow/analysis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/task_oriented_dialogue_as_dataflow_synthesis/bbd16fe69687b1052a7b5937ee7d20b641f2e642/src/dataflow/analysis/__init__.py -------------------------------------------------------------------------------- /src/dataflow/analysis/calculate_statistical_significance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | """ 4 | Semantic Machines\N{TRADE MARK SIGN} software. 5 | 6 | Calculates statistical significance for predictions from two experiments.
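Example invocation (the file names here are illustrative):

    python -m dataflow.analysis.calculate_statistical_significance \
        --exp0_prediction_report_tsv exp0.prediction_report.tsv \
        --exp1_prediction_report_tsv exp1.prediction_report.tsv \
        --scores_json scores.json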
7 | """ 8 | import argparse 9 | import csv 10 | import json 11 | from typing import Callable, List, Optional, Tuple 12 | 13 | import numpy as np 14 | import pandas as pd 15 | from statsmodels.stats.contingency_tables import mcnemar 16 | 17 | from dataflow.core.dialogue import TurnId 18 | from dataflow.core.io_utils import load_jsonl_file 19 | from dataflow.onmt_helpers.evaluate_onmt_predictions import evaluate_dialogue 20 | 21 | 22 | def get_report_dataframes( 23 | exp0_prediction_report_df: pd.DataFrame, exp1_prediction_report_df: pd.DataFrame, 24 | ) -> Tuple[pd.DataFrame, pd.DataFrame]: 25 | """Returns the turn-level and dialogue-level report dataframes.""" 26 | exp0_prediction_report_df.set_index( 27 | ["dialogueId", "turnIndex"], inplace=True, drop=True 28 | ) 29 | exp1_prediction_report_df.set_index( 30 | ["dialogueId", "turnIndex"], inplace=True, drop=True 31 | ) 32 | turn_report_df = exp0_prediction_report_df.join( 33 | exp1_prediction_report_df.loc[:, ["isCorrect"]], 34 | how="outer", 35 | lsuffix="_0", 36 | rsuffix="_1", 37 | ) 38 | assert not turn_report_df.isnull().any().any() 39 | assert ( 40 | len(turn_report_df) 41 | == len(exp0_prediction_report_df) 42 | == len(exp1_prediction_report_df) 43 | ) 44 | 45 | rows = [] 46 | for dialogue_id, df_for_dialogue in turn_report_df.groupby("dialogueId"): 47 | dialogue_scores0 = evaluate_dialogue( 48 | turns=[ 49 | (turn_index, row.get("isCorrect_0")) 50 | for (_, turn_index), row in df_for_dialogue.iterrows() 51 | ] 52 | ) 53 | dialogue_scores1 = evaluate_dialogue( 54 | turns=[ 55 | (turn_index, row.get("isCorrect_1")) 56 | for (_, turn_index), row in df_for_dialogue.iterrows() 57 | ] 58 | ) 59 | 60 | rows.append( 61 | { 62 | "dialogueId": dialogue_id, 63 | "isCorrect_0": dialogue_scores0.num_correct_dialogues > 0, 64 | "isCorrect_1": dialogue_scores1.num_correct_dialogues > 0, 65 | "prefix_0": dialogue_scores0.num_turns_before_first_error, 66 | "prefix_1": dialogue_scores1.num_turns_before_first_error, 67 | } 68 | ) 69 | dialogue_report_df = pd.DataFrame(rows) 70 | return turn_report_df, dialogue_report_df 71 | 72 | 73 | def run_mcnemar_test(report_df: pd.DataFrame) -> Tuple[float, float]: 74 | mask_correct_0 = report_df.loc[:, "isCorrect_0"] 75 | mask_correct_1 = report_df.loc[:, "isCorrect_1"] 76 | contingency_table = ( 77 | ( 78 | (mask_correct_0 & mask_correct_1).sum(), 79 | (mask_correct_0 & ~mask_correct_1).sum(), 80 | ), 81 | ( 82 | (~mask_correct_0 & mask_correct_1).sum(), 83 | (~mask_correct_0 & ~mask_correct_1).sum(), 84 | ), 85 | ) 86 | result = mcnemar(contingency_table) 87 | 88 | return result.statistic, result.pvalue 89 | 90 | 91 | def run_paired_permutation_test( 92 | xs: List[int], 93 | ys: List[int], 94 | samples: int = 10000, 95 | statistic: Callable[[List[int]], float] = np.mean, # type: ignore 96 | ) -> float: 97 | """Runs the two-sample permutation test to check whether the paired data xs and ys are from the same distribution (null hypothesis). 
98 | 99 | Args: 100 | xs: the data from distribution F1 101 | ys: the data from distribution F2 102 | samples: the number of samples for the Monte Carlo sampling 103 | statistic: the statistic to be used for the test (default is the mean) 104 | 105 | Returns: 106 | the p-value of the null hypothesis (two-tailed) 107 | """ 108 | 109 | def effect(xx: List[int], yy: List[int]) -> float: 110 | return np.abs(statistic(xx) - statistic(yy)) 111 | 112 | n, k = len(xs), 0 113 | diff = effect(xs, ys) # observed difference 114 | for _ in range(samples): # for each random sample 115 | swaps = np.random.randint(0, 2, n).astype(bool) # flip n coins 116 | k += diff <= effect( 117 | np.select([swaps, ~swaps], [xs, ys]), # swap elements accordingly 118 | np.select([~swaps, swaps], [xs, ys]), 119 | ) 120 | 121 | # fraction of random samples that achieved at least the observed difference 122 | return k / float(samples) 123 | 124 | 125 | def main( 126 | exp0_prediction_report_tsv: str, 127 | exp1_prediction_report_tsv: str, 128 | datum_ids_jsonl: Optional[str], 129 | scores_json: str, 130 | ) -> None: 131 | """Loads the two prediction report files and calculates statistical significance. 132 | 133 | For the turn-level and dialogue-level accuracy, we use the McNemar test. 134 | For the dialogue-level prefix length (i.e., the number of turns before the first error), we use the two-sample permutation test. 135 | 136 | If `datum_ids_jsonl` is given, we only use the subset of turns specified in the file. In this case, only turn-level 137 | metrics are used since it doesn't make sense to compute dialogue-level metrics with only a subset of turns. 138 | """ 139 | exp0_prediction_report_df = pd.read_csv( 140 | exp0_prediction_report_tsv, 141 | sep="\t", 142 | encoding="utf-8", 143 | quoting=csv.QUOTE_ALL, 144 | na_values=None, 145 | keep_default_na=False, 146 | ) 147 | assert not exp0_prediction_report_df.isnull().any().any() 148 | 149 | exp1_prediction_report_df = pd.read_csv( 150 | exp1_prediction_report_tsv, 151 | sep="\t", 152 | encoding="utf-8", 153 | quoting=csv.QUOTE_ALL, 154 | na_values=None, 155 | keep_default_na=False, 156 | ) 157 | assert not exp1_prediction_report_df.isnull().any().any() 158 | 159 | turn_report_df, dialogue_report_df = get_report_dataframes( 160 | exp0_prediction_report_df=exp0_prediction_report_df, 161 | exp1_prediction_report_df=exp1_prediction_report_df, 162 | ) 163 | 164 | if not datum_ids_jsonl: 165 | turn_statistic, turn_pvalue = run_mcnemar_test(turn_report_df) 166 | dialogue_statistic, dialogue_pvalue = run_mcnemar_test(dialogue_report_df) 167 | prefix_pvalue = run_paired_permutation_test( 168 | xs=dialogue_report_df.loc[:, "prefix_0"].tolist(), 169 | ys=dialogue_report_df.loc[:, "prefix_1"].tolist(), 170 | ) 171 | 172 | with open(scores_json, "w") as fp: 173 | fp.write( 174 | json.dumps( 175 | { 176 | "turn": {"statistic": turn_statistic, "pvalue": turn_pvalue}, 177 | "dialogue": { 178 | "statistic": dialogue_statistic, 179 | "pvalue": dialogue_pvalue, 180 | }, 181 | "prefix": {"pvalue": prefix_pvalue}, 182 | }, 183 | indent=2, 184 | ) 185 | ) 186 | fp.write("\n") 187 | 188 | else: 189 | datum_ids = set( 190 | load_jsonl_file(data_jsonl=datum_ids_jsonl, cls=TurnId, verbose=False) 191 | ) 192 | mask_datum_id = [ 193 | TurnId(dialogue_id=dialogue_id, turn_index=turn_index) in datum_ids 194 | for (dialogue_id, turn_index), row in exp1_prediction_report_df.iterrows() 195 | ] 196 | turn_report_df = turn_report_df.loc[mask_datum_id] 197 | # NOTE: We only compute turn-level statistics 
since it doesn't make sense to compute dialogue-level metrics 198 | # with only a subset of turns. 199 | turn_statistic, turn_pvalue = run_mcnemar_test(turn_report_df) 200 | 201 | with open(scores_json, "w") as fp: 202 | fp.write( 203 | json.dumps( 204 | {"turn": {"statistic": turn_statistic, "pvalue": turn_pvalue}}, 205 | indent=2, 206 | ) 207 | ) 208 | fp.write("\n") 209 | 210 | 211 | def add_arguments(argument_parser: argparse.ArgumentParser) -> None: 212 | argument_parser.add_argument( 213 | "--exp0_prediction_report_tsv", 214 | help="the prediction report tsv file for one experiment exp0", 215 | ) 216 | argument_parser.add_argument( 217 | "--exp1_prediction_report_tsv", 218 | help="the prediction report tsv file for the other experiment exp1", 219 | ) 220 | argument_parser.add_argument( 221 | "--datum_ids_jsonl", default=None, help="if set, only evaluate on these turns", 222 | ) 223 | argument_parser.add_argument("--scores_json", help="output scores json file") 224 | 225 | 226 | if __name__ == "__main__": 227 | cmdline_parser = argparse.ArgumentParser( 228 | description=__doc__, formatter_class=argparse.RawTextHelpFormatter 229 | ) 230 | add_arguments(cmdline_parser) 231 | args = cmdline_parser.parse_args() 232 | 233 | print("Semantic Machines\N{TRADE MARK SIGN} software.") 234 | main( 235 | exp0_prediction_report_tsv=args.exp0_prediction_report_tsv, 236 | exp1_prediction_report_tsv=args.exp1_prediction_report_tsv, 237 | datum_ids_jsonl=args.datum_ids_jsonl, 238 | scores_json=args.scores_json, 239 | ) 240 | -------------------------------------------------------------------------------- /src/dataflow/analysis/compute_data_statistics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | """ 4 | Semantic Machines\N{TRADE MARK SIGN} software. 5 | 6 | Compute statistics for the dataflow dialogues. 
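Example invocation (the paths here are illustrative):

    python -m dataflow.analysis.compute_data_statistics \
        --dataflow_dialogues_dir output/dataflow_dialogues \
        --subset train valid \
        --outdir output/data_statistics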
7 | """ 8 | 9 | import argparse 10 | import dataclasses 11 | import json 12 | import os 13 | import re 14 | from dataclasses import dataclass 15 | from typing import Dict, List, Tuple 16 | 17 | import numpy as np 18 | import pandas as pd 19 | 20 | from dataflow.core.dialogue import Dialogue, Turn, TurnId 21 | from dataflow.core.io_utils import load_jsonl_file, save_jsonl_file 22 | from dataflow.core.program_utils import DataflowFn 23 | 24 | 25 | @dataclass(frozen=True) 26 | class BasicStatistics: 27 | num_dialogues: int 28 | num_turns: int 29 | num_kept_turns: int 30 | num_skipped_turns: int 31 | num_refer_turns: int 32 | num_revise_turns: int 33 | 34 | 35 | def is_refer_turn(turn: Turn) -> bool: 36 | if re.search(rf"\({DataflowFn.Refer.value} ", turn.lispress): 37 | return True 38 | return False 39 | 40 | 41 | def is_revise_turn(turn: Turn) -> bool: 42 | if re.search(rf"\({DataflowFn.Revise.value} ", turn.lispress): 43 | return True 44 | return False 45 | 46 | 47 | def build_dialogue_report( 48 | dataflow_dialogues: List[Dialogue], 49 | ) -> Tuple[pd.DataFrame, List[TurnId], List[TurnId]]: 50 | refer_turn_ids = [] 51 | revise_turn_ids = [] 52 | report_rows = [] 53 | 54 | for dialogue in dataflow_dialogues: 55 | num_turns = len(dialogue.turns) 56 | num_kept_turns = 0 57 | num_skipped_turns = 0 58 | num_refer_turns = 0 59 | num_revise_turns = 0 60 | for turn in dialogue.turns: 61 | if turn.skip: 62 | num_skipped_turns += 1 63 | continue 64 | 65 | num_kept_turns += 1 66 | if is_refer_turn(turn): 67 | num_refer_turns += 1 68 | refer_turn_ids.append( 69 | TurnId(dialogue_id=dialogue.dialogue_id, turn_index=turn.turn_index) 70 | ) 71 | if is_revise_turn(turn): 72 | num_revise_turns += 1 73 | revise_turn_ids.append( 74 | TurnId(dialogue_id=dialogue.dialogue_id, turn_index=turn.turn_index) 75 | ) 76 | 77 | report_rows.append( 78 | { 79 | "dialogueId": dialogue.dialogue_id, 80 | "numTurns": num_turns, 81 | "numKeptTurns": num_kept_turns, 82 | "numSkippedTurns": num_skipped_turns, 83 | "numReferTurns": num_refer_turns, 84 | "numReviseTurns": num_revise_turns, 85 | } 86 | ) 87 | 88 | report_df = pd.DataFrame(report_rows) 89 | return report_df, refer_turn_ids, revise_turn_ids 90 | 91 | 92 | def compute_stats( 93 | dialogue_report_df: pd.DataFrame, 94 | ) -> Tuple[BasicStatistics, Dict[str, List[float]]]: 95 | basic_stats = BasicStatistics( 96 | num_dialogues=len(dialogue_report_df), 97 | num_turns=int(dialogue_report_df.loc[:, "numTurns"].sum()), 98 | num_kept_turns=int(dialogue_report_df.loc[:, "numKeptTurns"].sum()), 99 | num_skipped_turns=int(dialogue_report_df.loc[:, "numSkippedTurns"].sum()), 100 | num_refer_turns=int(dialogue_report_df.loc[:, "numReferTurns"].sum()), 101 | num_revise_turns=int(dialogue_report_df.loc[:, "numReviseTurns"].sum()), 102 | ) 103 | percentiles = list(range(0, 101, 10)) 104 | percentile_stats = { 105 | field: list( 106 | np.percentile(dialogue_report_df.loc[:, field].tolist(), percentiles) 107 | ) 108 | for field in [ 109 | "numTurns", 110 | "numKeptTurns", 111 | "numSkippedTurns", 112 | "numReferTurns", 113 | "numReviseTurns", 114 | ] 115 | } 116 | 117 | return basic_stats, percentile_stats 118 | 119 | 120 | def main(dataflow_dialogues_dir: str, subsets: List[str], outdir: str): 121 | if not os.path.exists(outdir): 122 | os.mkdir(outdir) 123 | dialogue_report_dfs = [] 124 | for subset in subsets: 125 | dataflow_dialogues = list( 126 | load_jsonl_file( 127 | data_jsonl=os.path.join( 128 | dataflow_dialogues_dir, f"{subset}.dataflow_dialogues.jsonl" 129 | ), 130 | 
cls=Dialogue, 131 | unit=" dialogues", 132 | ) 133 | ) 134 | 135 | dialogue_report_df, refer_turn_ids, revise_turn_ids = build_dialogue_report( 136 | dataflow_dialogues 137 | ) 138 | dialogue_report_dfs.append(dialogue_report_df) 139 | 140 | save_jsonl_file( 141 | data=refer_turn_ids, 142 | data_jsonl=os.path.join(outdir, f"{subset}.refer_turn_ids.jsonl"), 143 | ) 144 | save_jsonl_file( 145 | data=revise_turn_ids, 146 | data_jsonl=os.path.join(outdir, f"{subset}.revise_turn_ids.jsonl"), 147 | ) 148 | 149 | basic_stats, percentile_stats = compute_stats(dialogue_report_df) 150 | with open(os.path.join(outdir, f"{subset}.basic_stats.json"), "w") as fp: 151 | fp.write(json.dumps(dataclasses.asdict(basic_stats), indent=2)) 152 | fp.write("\n") 153 | with open(os.path.join(outdir, f"{subset}.percentile_stats.json"), "w") as fp: 154 | fp.write(json.dumps(percentile_stats, indent=2)) 155 | fp.write("\n") 156 | 157 | if len(subsets) > 1: 158 | basic_stats, percentile_stats = compute_stats(pd.concat(dialogue_report_dfs)) 159 | with open( 160 | os.path.join(outdir, f"{'-'.join(subsets)}.basic_stats.json"), "w" 161 | ) as fp: 162 | fp.write(json.dumps(dataclasses.asdict(basic_stats), indent=2)) 163 | fp.write("\n") 164 | with open( 165 | os.path.join(outdir, f"{'-'.join(subsets)}.percentile_stats.json"), "w" 166 | ) as fp: 167 | fp.write(json.dumps(percentile_stats, indent=2)) 168 | fp.write("\n") 169 | 170 | 171 | def add_arguments(argument_parser: argparse.ArgumentParser) -> None: 172 | argument_parser.add_argument( 173 | "--dataflow_dialogues_dir", help="the dataflow dialogues data directory" 174 | ) 175 | argument_parser.add_argument( 176 | "--subset", nargs="+", default=[], help="the subset to be analyzed" 177 | ) 178 | argument_parser.add_argument("--outdir", help="the output directory") 179 | 180 | 181 | if __name__ == "__main__": 182 | cmdline_parser = argparse.ArgumentParser( 183 | description=__doc__, formatter_class=argparse.RawTextHelpFormatter 184 | ) 185 | add_arguments(cmdline_parser) 186 | args = cmdline_parser.parse_args() 187 | 188 | print("Semantic Machines\N{TRADE MARK SIGN} software.") 189 | main( 190 | dataflow_dialogues_dir=args.dataflow_dialogues_dir, 191 | subsets=args.subset, 192 | outdir=args.outdir, 193 | ) 194 | -------------------------------------------------------------------------------- /src/dataflow/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | -------------------------------------------------------------------------------- /src/dataflow/core/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | 5 | class SpecialStrings: 6 | """Special strings in stringified turn parts. 
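For example (illustrative), a stringified dialogue context interleaves these markers with utterance tokens, along the lines of: __User what 's on my calendar __StartOfProgram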
7 | """ 8 | 9 | # an empty value (we need it since some library doesn't like an empty string) 10 | NULL = "__NULL" 11 | # indicates there is a break between the two utterance segments 12 | BREAK = "__BREAK" 13 | # indicates the user is the speaker for the following utterance 14 | SPEAKER_USER = "__User" 15 | # indicates the agent is the speaker for the following utterance 16 | SPEAKER_AGENT = "__Agent" 17 | # start of a program 18 | START_OF_PROGRAM = "__StartOfProgram" 19 | -------------------------------------------------------------------------------- /src/dataflow/core/definition.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, List, Tuple 3 | 4 | from dataflow.core.lispress import ( 5 | META_CHAR, 6 | Lispress, 7 | lispress_to_type_name, 8 | parse_lispress, 9 | ) 10 | from dataflow.core.program import TypeName 11 | 12 | 13 | @dataclass(frozen=True) 14 | class Definition: 15 | """A function signature. For example, 16 | Definition("foo", ["T"], [("arg1", TypeName("Long")), ("arg2", TypeName("T"))], TypeName("Double")) 17 | would be 18 | 19 | T = TypeVar("T") 20 | def foo(arg1: Long, arg2: T) -> Double: 21 | pass 22 | 23 | in Python, and 24 | 25 | (def ^(T) foo (^Long arg1 ^T arg2) ^Double ???) 26 | 27 | in Lispress. The ??? is the "body" of the def, much like `pass` in Python. 28 | It's slightly easier there's always a body because that's where return 29 | type annotations live right now. 30 | 31 | This class is currently only used in type_inference.py, but we might use 32 | it elsewhere too.""" 33 | 34 | name: str 35 | type_params: List[str] 36 | params: List[Tuple[str, TypeName]] 37 | return_type: TypeName 38 | 39 | def __post_init__(self): 40 | assert len(set(name for name, typeName in self.params)) == len( 41 | self.params 42 | ), f"Duplicate arg names found for {self.name}. Args were {self.params}" 43 | 44 | 45 | def lispress_library_to_library(lispress_str: str) -> Dict[str, Definition]: 46 | """Parses a list of lispress function defs into a indexed collection of Definitions. 47 | For example an input might look like 48 | 49 | (def + (^Long a ^Long b) ^Long ???) 50 | (package my.namespace 51 | (def - (^Long a ^Long b) ^Long ???) 52 | ) 53 | 54 | The returned library would contain an entry for '+' and 'my.namespace.-'. 55 | """ 56 | # The files are at flat list of global files and namespaces packages. 57 | # Wrap everything in parens to make it parse as a single expression. 
58 | sexp = parse_lispress("(" + lispress_str + ")") 59 | assert isinstance( 60 | sexp, list 61 | ), f"Expected list of S-Expressions in file {lispress_str}" 62 | res: Dict[str, Definition] = {} 63 | for def_or_package in sexp: 64 | if isinstance(def_or_package, list) and def_or_package[0] == "def": 65 | # def in the global namespace 66 | defn = _def_to_definition(def_or_package, namespace="") 67 | res[defn.name] = defn 68 | elif isinstance(def_or_package, list) and def_or_package[0] == "package": 69 | assert isinstance( 70 | def_or_package, list 71 | ), f"Expected list of S-Expressions in package {def_or_package}" 72 | (unused_package_kw, package_name, *defs) = def_or_package 73 | for lispress_def in defs: 74 | defn = _def_to_definition(lispress_def, namespace=package_name) 75 | res[defn.name] = defn 76 | return res 77 | 78 | 79 | def _def_to_definition(lispress_def: Lispress, namespace: str) -> Definition: 80 | (unused_keyword, func_name, param_list, body) = lispress_def 81 | if isinstance(func_name, list) and func_name[0] == META_CHAR: 82 | (unused_meta, type_params, actual_func_name) = func_name 83 | else: 84 | actual_func_name = func_name 85 | type_params = [] 86 | 87 | assert ( 88 | isinstance(body, list) and body[0] == META_CHAR 89 | ), f"Invalid function body {body}" 90 | (unused_meta, return_type, unused_body) = body 91 | 92 | params = [ 93 | _parse_param(param) 94 | for param in param_list # Skip typeclass constraints for now 95 | if not (isinstance(param, list) and param[0] == "using") 96 | ] 97 | namespace_prefix = namespace + "." if len(namespace) > 0 else "" 98 | return Definition( 99 | namespace_prefix + actual_func_name, 100 | type_params, 101 | params, 102 | lispress_to_type_name(return_type), 103 | ) 104 | 105 | 106 | def _parse_param(param: Lispress) -> Tuple[str, TypeName]: 107 | assert ( 108 | isinstance(param, list) and param[0] == META_CHAR 109 | ), f"Invalid function param {param}" 110 | (unused_meta, type_ascription, param_name_maybe_with_default) = param 111 | 112 | if isinstance(param_name_maybe_with_default, str): 113 | param_name = param_name_maybe_with_default 114 | elif isinstance(param_name_maybe_with_default, list): 115 | (param_name, unused_default) = param_name_maybe_with_default 116 | # Option[T] is an optional argument, so just make it T 117 | if isinstance(type_ascription, list) and type_ascription[0] == "Option": 118 | type_ascription = type_ascription[1] 119 | return param_name, lispress_to_type_name(type_ascription) 120 | -------------------------------------------------------------------------------- /src/dataflow/core/dialogue.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | from dataclasses import dataclass 4 | from typing import List, Optional 5 | 6 | from dataflow.core.linearize import lispress_to_seq 7 | from dataflow.core.lispress import lispress_to_program, parse_lispress 8 | from dataflow.core.program import Program 9 | 10 | 11 | @dataclass(frozen=True) 12 | class AgentUtterance: 13 | original_text: str 14 | tokens: List[str] 15 | # The IDs of entities described in the agent utterance. 16 | # For programs that have inlined refer calls in our ablation study, the lispress at the 17 | # t-th turn would contain tokens that are entity IDs (e.g., "entity@12345") from 18 | # agent utterances in previous turns.
19 | # In order to make it possible for the seq2seq model to produce such tokens, we concatenate 20 | # these entity IDs in the source sequence side so that the model can learn to "copy" them into 21 | # the target sequence. 22 | # In normal programs with non-inlined refer calls, these entities would be retrieved through 23 | # the refer calls. Thus, we do not need to use this field for normal programs. 24 | described_entities: List[str] 25 | 26 | 27 | @dataclass(frozen=True) 28 | class UserUtterance: 29 | original_text: str 30 | tokens: List[str] 31 | 32 | 33 | @dataclass(frozen=True) 34 | class TurnId: 35 | dialogue_id: str 36 | turn_index: int 37 | 38 | def __hash__(self): 39 | return hash((self.dialogue_id, self.turn_index)) 40 | 41 | 42 | @dataclass(frozen=True) 43 | class ProgramExecutionOracle: 44 | """The oracle information about the program execution. 45 | 46 | Because we have not implemented an executor in Python, we record the 47 | useful information about the execution results on the gold program annotations 48 | for evaluation purposes. 49 | """ 50 | 51 | # the flag to indicate whether the program would raise any exception during execution 52 | has_exception: bool 53 | # the flag to indicate whether all refer calls in the program would return the correct values during execution 54 | # NOTE: This flag is used in evaluation as if the refer calls are replaced with the concrete program 55 | # fragments they return. This means that a predicted plan is correct iff the plan itself matches the gold plan, 56 | # and the gold plan's refer calls are correct. 57 | refer_are_correct: bool 58 | 59 | 60 | @dataclass(frozen=True) 61 | class Turn: 62 | # the turn index 63 | turn_index: int 64 | # the current user utterance 65 | user_utterance: UserUtterance 66 | # the next agent utterance 67 | agent_utterance: AgentUtterance 68 | # the program corresponding to the user utterance 69 | # (see `dataflow.core.lispress` for the lisp string format) 70 | lispress: str 71 | # the flag to indicate whether to skip this turn when building the datum for training/prediction 72 | # Some turns are skipped to avoid biasing the model toward very common utterances, e.g., "yes", "okay". 73 | # NOTE: These turns are still used for building the dialogue context even if the flag is true. 74 | skip: bool 75 | # the oracle information about the gold program 76 | program_execution_oracle: Optional[ProgramExecutionOracle] = None 77 | 78 | def tokenized_lispress(self) -> List[str]: 79 | return lispress_to_seq(parse_lispress(self.lispress)) 80 | 81 | def program(self) -> Program: 82 | program, _ = lispress_to_program(parse_lispress(self.lispress), idx=0) 83 | return program 84 | 85 | 86 | @dataclass(frozen=True) 87 | class Dialogue: 88 | dialogue_id: str 89 | turns: List[Turn] 90 | -------------------------------------------------------------------------------- /src/dataflow/core/io_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license.
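# Illustrative usage sketch of the helpers defined below (hypothetical file paths):
#
#   from dataflow.core.dialogue import Dialogue
#   from dataflow.core.io_utils import load_jsonl_file, save_jsonl_file
#
#   dialogues = list(load_jsonl_file("train.dataflow_dialogues.jsonl", cls=Dialogue, unit=" dialogues"))
#   save_jsonl_file(dialogues, "copy.dataflow_dialogues.jsonl")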
3 | import json 4 | from collections import defaultdict 5 | from typing import Any, Callable, Dict, Iterable, Iterator, Type, TypeVar, cast 6 | 7 | import jsons 8 | from tqdm import tqdm 9 | 10 | # type of a dataclass object 11 | _TDatum = TypeVar("_TDatum") 12 | # type of the primary key for _TDatum 13 | _TPrimaryKey = TypeVar("_TPrimaryKey") 14 | # type of the secondary key for _TDatum 15 | _TSecondaryKey = TypeVar("_TSecondaryKey") 16 | 17 | 18 | def load_jsonl_file( 19 | data_jsonl: str, cls: Type[_TDatum], unit: str = " items", verbose: bool = True 20 | ) -> Iterator[_TDatum]: 21 | """Loads a jsonl file and yield the deserialized dataclass objects.""" 22 | if verbose: 23 | desc = f"Reading {cls} from {data_jsonl}" 24 | else: 25 | desc = None 26 | with open(data_jsonl, encoding="utf-8") as fp: 27 | for line in tqdm( 28 | fp, desc=desc, unit=unit, dynamic_ncols=True, disable=not verbose 29 | ): 30 | yield jsons.loads(line.strip(), cls=cls) 31 | 32 | 33 | def save_jsonl_file( 34 | data: Iterable[_TDatum], data_jsonl: str, remove_null: bool = True 35 | ) -> None: 36 | """Dumps dataclass objects into a jsonl file.""" 37 | with open(data_jsonl, "w") as fp: 38 | for datum in data: 39 | datum_dict = cast(Dict[str, Any], jsons.dump(datum)) 40 | if remove_null: 41 | fp.write(json.dumps(remove_null_fields_in_dict(datum_dict))) 42 | else: 43 | fp.write(json.dumps(datum_dict)) 44 | fp.write("\n") 45 | 46 | 47 | def load_jsonl_file_and_build_lookup( 48 | data_jsonl: str, 49 | cls: Type[_TDatum], 50 | primary_key_getter: Callable[[_TDatum], _TPrimaryKey], 51 | secondary_key_getter: Callable[[_TDatum], _TSecondaryKey], 52 | unit: str = " items", 53 | verbose: bool = True, 54 | ) -> Dict[_TPrimaryKey, Dict[_TSecondaryKey, _TDatum]]: 55 | """Loads a jsonl file of serialized dataclass objects and returns the lookup with a primary key and a secondary key.""" 56 | if verbose: 57 | desc = f"Reading {cls} from {data_jsonl}" 58 | else: 59 | desc = None 60 | data_lookup: Dict[_TPrimaryKey, Dict[_TSecondaryKey, _TDatum]] = defaultdict(dict) 61 | with open(data_jsonl) as fp: 62 | for line in tqdm( 63 | fp, desc=desc, unit=unit, dynamic_ncols=True, disable=not verbose 64 | ): 65 | datum = jsons.loads(line.strip(), cls) 66 | primary_key = primary_key_getter(datum) 67 | if primary_key not in data_lookup: 68 | data_lookup[primary_key] = {} 69 | data_lookup[primary_key][secondary_key_getter(datum)] = datum 70 | return data_lookup 71 | 72 | 73 | def remove_null_fields_in_dict(raw_dict: Dict[str, Any]) -> Dict[str, Any]: 74 | """Removes null fields in the dict object.""" 75 | return { 76 | key: remove_null_fields_in_dict(val) if isinstance(val, dict) else val 77 | for key, val in raw_dict.items() 78 | if val is not None 79 | } 80 | -------------------------------------------------------------------------------- /src/dataflow/core/linearize.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | """ 4 | Tools for linearizing a program so that it's easier to predict with a seq2seq model. 
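For example, the lispress

    (Yield (toRecipient (CurrentUser)))

linearizes to the token sequence

    ( Yield ( toRecipient ( CurrentUser ) ) )

(the operator names here are illustrative SMCalFlow functions).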
5 | """ 6 | from typing import List, Tuple 7 | 8 | from dataflow.core.lispress import ( 9 | LEFT_PAREN, 10 | META_CHAR, 11 | RIGHT_PAREN, 12 | Lispress, 13 | lispress_to_program, 14 | program_to_lispress, 15 | render_compact, 16 | strip_copy_strings, 17 | ) 18 | from dataflow.core.program import Program 19 | from dataflow.core.program_utils import Idx, OpType 20 | from dataflow.core.sexp import Sexp, parse_sexp 21 | 22 | 23 | def to_canonical_form(tokenized_lispress: str) -> str: 24 | """Returns canonical form of a tokenized lispress. 25 | 26 | The canonical form is un-tokenized and compact; it also sorts named arguments in alphabetical order. 27 | """ 28 | lispress = seq_to_lispress(tokenized_lispress.split(" ")) 29 | program, _ = lispress_to_program(lispress, 0) 30 | round_tripped = program_to_lispress(program) 31 | return render_compact(round_tripped) 32 | 33 | 34 | def seq_to_program(seq: List[str], idx: Idx) -> Tuple[Program, Idx]: 35 | lispress = seq_to_lispress(seq) 36 | return lispress_to_program(lispress, idx) 37 | 38 | 39 | def program_to_seq(program: Program) -> List[str]: 40 | lispress = program_to_lispress(program) 41 | return lispress_to_seq(lispress) 42 | 43 | 44 | def lispress_to_seq(lispress: Lispress) -> List[str]: 45 | return sexp_to_seq(lispress) 46 | 47 | 48 | def seq_to_lispress(seq: List[str]) -> Lispress: 49 | return strip_copy_strings(parse_sexp(" ".join(seq))) 50 | 51 | 52 | def sexp_to_seq(s: Sexp) -> List[str]: 53 | if isinstance(s, list): 54 | if len(s) == 3 and s[0] == META_CHAR: 55 | (_meta, type_meta, expr) = s 56 | return [META_CHAR] + [y for x in [type_meta, expr] for y in sexp_to_seq(x)] 57 | elif len(s) == 2 and s[0] == OpType.Value.value: 58 | (_value, expr) = s 59 | return [OpType.Value.value] + [y for x in [expr] for y in sexp_to_seq(x)] 60 | return [LEFT_PAREN] + [y for x in s for y in sexp_to_seq(x)] + [RIGHT_PAREN] 61 | elif isinstance(s, str) and len(s) >= 2 and s[0] == '"' and s[-1] == '"': 62 | return ['"'] + s[1:-1].strip().split(" ") + ['"'] 63 | else: 64 | return [s] 65 | -------------------------------------------------------------------------------- /src/dataflow/core/prediction_report.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import csv 3 | from typing import Dict, List, Sequence, Union 4 | 5 | import pandas as pd 6 | 7 | 8 | class PredictionReportDatum(abc.ABC): 9 | @abc.abstractmethod 10 | def flatten(self) -> Dict[str, Union[str, int]]: 11 | raise NotImplementedError() 12 | 13 | 14 | def save_prediction_report_tsv( 15 | prediction_report: Sequence[PredictionReportDatum], prediction_report_tsv: str, 16 | ) -> None: 17 | """Converts prediction results into a pandas dataframe and saves it a tsv report file. 
18 | """ 19 | prediction_report_df = pd.DataFrame( 20 | [datum.flatten() for datum in prediction_report] 21 | ) 22 | prediction_report_df.to_csv( 23 | prediction_report_tsv, 24 | sep="\t", 25 | index=False, 26 | encoding="utf-8", 27 | quoting=csv.QUOTE_ALL, 28 | ) 29 | 30 | 31 | def save_prediction_report_txt( 32 | prediction_report: Sequence[PredictionReportDatum], 33 | prediction_report_txt: str, 34 | field_names: List[str], 35 | ) -> None: 36 | """Prints prediction results into an easy-to-read text report file.""" 37 | with open(prediction_report_txt, "w") as fp: 38 | for datum in prediction_report: 39 | fp.write("=" * 16) 40 | fp.write("\n") 41 | 42 | flatten_fields = datum.flatten() 43 | for field_name in field_names: 44 | field_value = flatten_fields[field_name] 45 | # use "hypo" not "prediction" as the name here just to make it visually aligned with "gold" 46 | if field_name == "prediction": 47 | field_name = "hypo" 48 | print(f"{field_name}\t{field_value}", file=fp) 49 | -------------------------------------------------------------------------------- /src/dataflow/core/program.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | from collections import Counter 4 | from dataclasses import dataclass, field 5 | from typing import Dict, List, Optional, Set, Tuple, Union 6 | 7 | from cached_property import cached_property 8 | 9 | 10 | @dataclass(frozen=True) 11 | class ValueOp: 12 | value: str 13 | 14 | 15 | @dataclass(frozen=True) 16 | class CallLikeOp: 17 | name: str 18 | 19 | 20 | @dataclass(frozen=True) 21 | class BuildStructOp: 22 | op_schema: str 23 | # Named arguments. None for positional arguments 24 | op_fields: List[Optional[str]] 25 | empty_base: bool 26 | push_go: bool 27 | 28 | 29 | # NOTE: Currently, since these three ops have different schema and they are not convertible with each other, 30 | # it is type-safe to use `Union` for deserialization. 31 | # See explanation https://pydantic-docs.helpmanual.io/usage/types/#unions. 32 | # If there are two ops sharing the same schema or are convertible between each other, we need to chante 33 | # `Op` to a dataclass and explicitly define a `op_type` field. 34 | Op = Union[ValueOp, CallLikeOp, BuildStructOp] 35 | 36 | 37 | @dataclass(frozen=True) 38 | class TypeName: 39 | base: str 40 | # Tuples preferred so TypeNames can be hashable 41 | type_args: Tuple["TypeName", ...] = field(default_factory=tuple) 42 | 43 | def __repr__(self) -> str: 44 | if len(self.type_args) == 0: 45 | return self.base 46 | else: 47 | return f'({self.base} {" ".join(a.__repr__() for a in self.type_args)})' 48 | 49 | 50 | @dataclass(frozen=True) 51 | class Expression: 52 | id: str 53 | op: Op 54 | type_args: Optional[List[TypeName]] = None 55 | type: Optional[TypeName] = None 56 | arg_ids: List[str] = field(default_factory=list) 57 | 58 | 59 | @dataclass(frozen=True) 60 | class Program: 61 | expressions: List[Expression] 62 | 63 | @cached_property 64 | def expressions_by_id(self) -> Dict[str, Expression]: 65 | return {expression.id: expression for expression in self.expressions} 66 | 67 | 68 | def roots_and_reentrancies(program: Program) -> Tuple[List[str], Set[str]]: 69 | """ 70 | Returns ids of roots (expressions that never appear as arguments) and 71 | reentrancies (expressions that appear more than once as arguments). 72 | Now that `do` expressions get their own nodes, there should be exactly 73 | one root. 
74 | """ 75 | arg_counts = Counter(a for e in program.expressions for a in e.arg_ids) 76 | # ids that are never used as args 77 | roots = [e.id for e in program.expressions if e.id not in arg_counts] 78 | # args that are used multiple times as args 79 | reentrancies = {i for i, c in arg_counts.items() if c >= 2} 80 | return roots, reentrancies 81 | -------------------------------------------------------------------------------- /src/dataflow/core/program_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | import re 4 | from enum import Enum 5 | from json import dumps 6 | from typing import Any, List, Optional, Tuple 7 | 8 | from dataflow.core.program import ( 9 | BuildStructOp, 10 | CallLikeOp, 11 | Expression, 12 | TypeName, 13 | ValueOp, 14 | ) 15 | from dataflow.core.sexp import Sexp 16 | 17 | # revise args 18 | ROOT_LOCATION = "rootLocation" 19 | OLD_LOCATION = "oldLocation" 20 | NEW = "new" 21 | 22 | # BuildStructOp special arg 23 | NON_EMPTY_BASE = "nonEmptyBase" 24 | 25 | Idx = int 26 | 27 | 28 | class OpType(Enum): 29 | """The type of an op.""" 30 | 31 | Call = "Call" 32 | Struct = "Struct" 33 | Value = "#" 34 | 35 | 36 | class DataflowFn(Enum): 37 | """Special Dataflow functions""" 38 | 39 | Do = "do" # keyword for sequencing programs that have multiple statements 40 | Find = "find" # search 41 | Abandon = "abandon" 42 | Revise = "ReviseConstraint" 43 | Refer = "refer" 44 | RoleConstraint = "roleConstraint" 45 | Get = "get" # access a member field of an object 46 | # keyword that introduces an anonymous function with a single parameter 47 | # syntax: `(lambda (^ArgType xN) arg_body)` 48 | # where `N` is an int and `arg_body` may reference `xN` 49 | Lambda = "lambda" 50 | # special 0-arg function that stands in for a lambda's argument 51 | LambdaArg = "lambda_arg" 52 | 53 | 54 | def idx_str(idx: Idx) -> str: 55 | return f"[{idx}]" 56 | 57 | 58 | def is_idx_str(s: str) -> bool: 59 | return s.startswith("[") and s.endswith("]") 60 | 61 | 62 | def unwrap_idx_str(s: str) -> int: 63 | return int(s[1:-1]) 64 | 65 | 66 | def is_struct_op_schema(name: str) -> bool: 67 | """BuildStructOp schemas begin with a capital letter.""" 68 | if len(name) == 0: 69 | return False 70 | return re.match(r"[A-Z]", name[0]) is not None 71 | 72 | 73 | def get_named_args(e: Expression) -> List[Tuple[str, Optional[str]]]: 74 | """ 75 | Gets a list of (arg_name, arg_id) pairs. 76 | If `e` is a BuildStructOp, then `arg_names` are its `fields`, otherwise 77 | they are the 0-indexed argument position. 
78 | """ 79 | if isinstance(e.op, BuildStructOp): 80 | bso = e.op 81 | # a non-empty BuildStructOp has an implicit 0-th field name 82 | zeroth_field = [] if bso.empty_base else [NON_EMPTY_BASE] 83 | fields = zeroth_field + list(bso.op_fields) 84 | else: 85 | fields = [f"arg{i}" for i in range(len(e.arg_ids))] 86 | return list(zip(fields, e.arg_ids)) 87 | 88 | 89 | def mk_constraint( 90 | tpe: str, args: List[Tuple[Optional[str], int]], idx: Idx, 91 | ) -> Tuple[Expression, Idx]: 92 | return mk_struct_op(schema=f"Constraint[{tpe.capitalize()}]", args=args, idx=idx) 93 | 94 | 95 | def mk_equality_constraint(val: int, idx: Idx) -> Tuple[Expression, Idx]: 96 | return mk_call_op(name="?=", args=[val], idx=idx) 97 | 98 | 99 | def mk_unset_constraint(idx: Idx) -> Tuple[Expression, Idx]: 100 | return mk_struct_op(schema="EmptyConstraint", args=[], idx=idx) 101 | 102 | 103 | def mk_salience(tpe: str, idx: Idx) -> Tuple[List[Expression], Idx]: 104 | constraint_expr, constraint_idx = mk_constraint(tpe=tpe, args=[], idx=idx) 105 | salience_expr, idx = mk_call_op( 106 | name=DataflowFn.Refer.value, args=[constraint_idx], idx=constraint_idx 107 | ) 108 | return [constraint_expr, salience_expr], idx 109 | 110 | 111 | def mk_salient_action(idx: Idx) -> Tuple[List[Expression], Idx]: 112 | """ (roleConstraint #(Path "output")) """ 113 | path_expr, path_idx = mk_value_op(schema="Path", value="output", idx=idx) 114 | intension_expr, intension_idx = mk_call_op( 115 | name=DataflowFn.RoleConstraint.value, args=[path_idx], idx=path_idx, 116 | ) 117 | return [path_expr, intension_expr], intension_idx 118 | 119 | 120 | def mk_revise( 121 | root_location_idx: Idx, old_location_idx: Idx, new_idx: Idx, idx: Idx, 122 | ) -> Tuple[Expression, Idx]: 123 | """ 124 | Revises the salient constraint satisfying the constraint at `old_location_idx`, 125 | in the salient computation satisfying the constraint at `root_location_idx`, 126 | with the constraint at `new_idx`. 127 | In Lispress: 128 | ``` 129 | (Revise 130 | :rootLocation {root_location} 131 | :oldLocation {old_location} 132 | :new {new}) 133 | """ 134 | return mk_struct_op( 135 | schema=DataflowFn.Revise.value, 136 | args=[ 137 | (ROOT_LOCATION, root_location_idx), 138 | (OLD_LOCATION, old_location_idx), 139 | (NEW, new_idx), 140 | ], 141 | idx=idx, 142 | ) 143 | 144 | 145 | def mk_revise_the_main_constraint( 146 | tpe: str, new_idx: Idx 147 | ) -> Tuple[List[Expression], Idx]: 148 | """ 149 | Revises the salient constraint (on values of type `tpe`) in the salient action, with the 150 | constraint at `new_idx`. 151 | (An "action" is an argument of `Yield`). 
152 | In Lispress: 153 | ``` 154 | (ReviseConstraint 155 | :rootLocation (RoleConstraint :role #(Path "output")) 156 | :oldLocation (Constraint[Constraint[{tpe}]]) 157 | :new {new}) 158 | ``` 159 | """ 160 | salient_action_exprs, salient_action_idx = mk_salient_action(new_idx) 161 | old_loc_expr, old_loc_idx = mk_struct_op( 162 | schema=f"Constraint[Constraint[{tpe.capitalize()}]]", 163 | args=[], 164 | idx=salient_action_idx, 165 | ) 166 | revise_expr, revise_idx = mk_revise( 167 | root_location_idx=salient_action_idx, 168 | old_location_idx=old_loc_idx, 169 | new_idx=new_idx, 170 | idx=old_loc_idx, 171 | ) 172 | return salient_action_exprs + [old_loc_expr, revise_expr], revise_idx 173 | 174 | 175 | def mk_struct_op( 176 | schema: str, args: List[Tuple[Optional[str], Idx]], idx: Idx, 177 | ) -> Tuple[Expression, Idx]: 178 | new_idx = idx + 1 179 | base = next((v for k, v in args if k == NON_EMPTY_BASE), None) 180 | is_empty_base = base is None 181 | arg_names = [k for k, v in args] 182 | # nonEmptyBase always comes first 183 | arg_vals = ([] if is_empty_base else [base]) + [v for k, v in args] 184 | flat_exp = Expression( 185 | id=idx_str(new_idx), 186 | op=BuildStructOp( 187 | op_schema=schema, 188 | op_fields=arg_names, 189 | empty_base=is_empty_base, 190 | push_go=True, 191 | ), 192 | arg_ids=[idx_str(v) for v in arg_vals], 193 | ) 194 | return flat_exp, new_idx 195 | 196 | 197 | def mk_call_op( 198 | name: str, 199 | args: List[Idx], 200 | type_args: Optional[List[TypeName]] = None, 201 | idx: Idx = 0, 202 | ) -> Tuple[Expression, Idx]: 203 | new_idx = idx + 1 204 | flat_exp = Expression( 205 | id=idx_str(new_idx), 206 | op=CallLikeOp(name=name), 207 | type_args=type_args, 208 | arg_ids=[idx_str(v) for v in args], 209 | ) 210 | return flat_exp, new_idx 211 | 212 | 213 | def mk_type_name(sexp: Sexp) -> TypeName: 214 | if isinstance(sexp, str): 215 | return TypeName(sexp, ()) 216 | hd, *tl = sexp 217 | return TypeName(hd, tuple([mk_type_name(e) for e in tl])) 218 | 219 | 220 | def mk_value_op(value: Any, schema: str, idx: Idx) -> Tuple[Expression, Idx]: 221 | my_idx = idx + 1 222 | dumped = dumps({"schema": schema, "underlying": value}) 223 | expr = Expression(id=idx_str(my_idx), op=ValueOp(value=dumped)) 224 | return expr, my_idx 225 | 226 | 227 | def mk_lambda_arg(type_name: TypeName, idx: Idx = 0) -> Tuple[Expression, Idx]: 228 | return mk_call_op( 229 | name=DataflowFn.LambdaArg.value, type_args=[type_name], args=[], idx=idx 230 | ) 231 | 232 | 233 | def mk_lambda(arg_idx: Idx, body_idx: Idx, idx: Idx = 0) -> Tuple[Expression, Idx]: 234 | return mk_call_op(name=DataflowFn.Lambda.value, args=[arg_idx, body_idx], idx=idx) 235 | -------------------------------------------------------------------------------- /src/dataflow/core/sexp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
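# Illustrative behavior of parse_sexp (defined below): strings keep their
# surrounding double quotes, and ^ / # are read as meta/reader markers:
#
#   parse_sexp('(a (b "c d") ^Type e)')
#   # => ["a", ["b", '"c d"'], ["^", "Type", "e"]]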
3 | from enum import Enum 4 | from typing import List, Union 5 | 6 | LEFT_PAREN = "(" 7 | RIGHT_PAREN = ")" 8 | ESCAPE = "\\" 9 | DOUBLE_QUOTE = '"' 10 | META = "^" 11 | READER = "#" 12 | 13 | # we unwrap Symbols into strings for convenience 14 | Sexp = Union[str, List["Sexp"]] # type: ignore # Recursive type 15 | 16 | 17 | class QuoteState(Enum): 18 | """Whether a char is inside double quotes or not during parsing""" 19 | 20 | Inside = True 21 | Outside = False 22 | 23 | def flipped(self) -> "QuoteState": 24 | return QuoteState(not self.value) 25 | 26 | 27 | def _split_respecting_quotes(s: str) -> List[str]: 28 | """Splits `s` on whitespace that is not inside double quotes.""" 29 | result = [] 30 | marked_idx = 0 31 | state = QuoteState.Outside 32 | for idx, ch in enumerate(s): 33 | if ch == DOUBLE_QUOTE and (idx < 1 or s[idx - 1] != ESCAPE): 34 | if state == QuoteState.Outside: 35 | result.extend(s[marked_idx:idx].strip().split()) 36 | marked_idx = idx 37 | else: 38 | result.append(s[marked_idx : idx + 1]) # include the double quote 39 | marked_idx = idx + 1 40 | state = state.flipped() 41 | assert state == QuoteState.Outside, f"Mismatched double quotes: {s}" 42 | if marked_idx < len(s): 43 | result.extend(s[marked_idx:].strip().split()) 44 | return result 45 | 46 | 47 | def parse_sexp(s: str) -> Sexp: 48 | offset = 0 49 | 50 | # eoi = end of input 51 | def is_eoi(): 52 | nonlocal offset 53 | return offset == len(s) 54 | 55 | def peek(): 56 | nonlocal offset 57 | return s[offset] 58 | 59 | def next_char(): 60 | # pylint: disable=used-before-assignment 61 | nonlocal offset 62 | cn = s[offset] 63 | offset += 1 64 | return cn 65 | 66 | def skip_whitespace(): 67 | while (not is_eoi()) and peek().isspace(): 68 | next_char() 69 | 70 | def skip_then_peek(): 71 | skip_whitespace() 72 | return peek() 73 | 74 | def read() -> Sexp: 75 | skip_whitespace() 76 | c = next_char() 77 | if c == LEFT_PAREN: 78 | return read_list() 79 | elif c == DOUBLE_QUOTE: 80 | return read_string() 81 | elif c == META: 82 | meta = read() 83 | expr = read() 84 | return [META, meta, expr] 85 | elif c == READER: 86 | return [READER, read()] 87 | else: 88 | out_inner = "" 89 | if c != "\\": 90 | out_inner += c 91 | 92 | # TODO: is there a better loop idiom here? 93 | if not is_eoi(): 94 | next_c = peek() 95 | escaped = c == "\\" 96 | while (not is_eoi()) and ( 97 | escaped or not _is_beginning_control_char(next_c) 98 | ): 99 | if (not escaped) and next_c == "\\": 100 | next_char() 101 | escaped = True 102 | else: 103 | out_inner += next_char() 104 | escaped = False 105 | if not is_eoi(): 106 | next_c = peek() 107 | return out_inner 108 | 109 | def read_list(): 110 | out_list = [] 111 | while skip_then_peek() != RIGHT_PAREN: 112 | out_list.append(read()) 113 | next_char() 114 | return out_list 115 | 116 | def read_string(): 117 | out_str = "" 118 | while peek() != '"': 119 | c_string = next_char() 120 | out_str += c_string 121 | if c_string == "\\": 122 | out_str += next_char() 123 | next_char() 124 | return f'"{out_str}"' 125 | 126 | out = read() 127 | skip_whitespace() 128 | assert offset == len( 129 | s 130 | ), f"Failed to exhaustively parse {s}, maybe you are missing a close paren?" 
131 | return out 132 | 133 | 134 | def _is_beginning_control_char(nextC): 135 | return ( 136 | nextC.isspace() 137 | or nextC == LEFT_PAREN 138 | or nextC == RIGHT_PAREN 139 | or nextC == DOUBLE_QUOTE 140 | or nextC == READER 141 | or nextC == META 142 | ) 143 | 144 | 145 | def sexp_to_str(sexp: Sexp) -> str: 146 | """ Generates string representation from S-expression """ 147 | # Note that some of this logic is repeated in lispress.render_pretty 148 | if isinstance(sexp, list): 149 | if len(sexp) == 3 and sexp[0] == META: 150 | (_meta, type_expr, underlying_expr) = sexp 151 | return META + sexp_to_str(type_expr) + " " + sexp_to_str(underlying_expr) 152 | elif len(sexp) == 2 and sexp[0] == READER: 153 | (_reader, expr) = sexp 154 | return READER + sexp_to_str(expr) 155 | else: 156 | return "(" + " ".join(sexp_to_str(f) for f in sexp) + ")" 157 | else: 158 | if sexp.startswith('"') and sexp.endswith('"'): 159 | return sexp 160 | else: 161 | return _escape_symbol(sexp) 162 | 163 | 164 | def _escape_symbol(symbol: str) -> str: 165 | out = [] 166 | for c in symbol: 167 | if _is_beginning_control_char(c): 168 | out.append("\\") 169 | out.append(c) 170 | return "".join(out) 171 | -------------------------------------------------------------------------------- /src/dataflow/core/turn_prediction.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | 6 | from dataflow.core.dialogue import ( 7 | Dialogue, 8 | ProgramExecutionOracle, 9 | TurnId, 10 | UserUtterance, 11 | ) 12 | 13 | 14 | @dataclass(frozen=True, eq=True, repr=True) 15 | class UtteranceWithContext: 16 | """ 17 | A user utterance, with the dialogue history leading up to it. 18 | This is the input to the lispress prediction task. 19 | """ 20 | 21 | datum_id: TurnId 22 | user_utterance: UserUtterance 23 | context: Dialogue 24 | 25 | 26 | @dataclass(frozen=True, eq=True, repr=True) 27 | class TurnPrediction: 28 | """ 29 | A model prediction of the `lispress` for a single Turn. 30 | This is the output of the lispress prediction task. 31 | """ 32 | 33 | datum_id: TurnId 34 | user_utterance: str # redundant. just to make these files easier to read 35 | lispress: str 36 | 37 | 38 | @dataclass(frozen=True, eq=True, repr=True) 39 | class TurnAnswer: 40 | """ 41 | A model prediction of the `lispress` for a single Turn. 42 | This is the output of the lispress prediction task. 43 | """ 44 | 45 | datum_id: TurnId 46 | user_utterance: str # redundant. just to make these files easier to read 47 | lispress: str 48 | program_execution_oracle: Optional[ProgramExecutionOracle] 49 | 50 | 51 | def missing_prediction(datum_id: TurnId) -> TurnPrediction: 52 | """ 53 | A padding `TurnPrediction` that is used when a turn with 54 | `datum_id` is missing from a predictions file. 55 | """ 56 | return TurnPrediction( 57 | datum_id=datum_id, user_utterance="", lispress="", 58 | ) 59 | -------------------------------------------------------------------------------- /src/dataflow/core/utterance_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
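# Illustrative behavior of tokenize_datetime (defined below):
#
#   tokenize_datetime("4:00pm")  # -> "4 : 00 pm"
#   tokenize_datetime("5/7")     # -> "5 / 7"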
3 | import re 4 | from typing import List 5 | 6 | import spacy 7 | from spacy.language import Language 8 | 9 | from dataflow.core.constants import SpecialStrings 10 | 11 | 12 | def tokenize_datetime(text: str) -> str: 13 | """Tokenizes datetime to make it consistent with the seaweed tokens.""" 14 | # 5.10 => 5 . 10 15 | # 4:00 => 4 : 00 16 | # 5/7 => 5 / 7 17 | # 5\7 => 5 \ 7 18 | # 3-9 => 3 - 9 19 | text = re.sub(r"(\d)([.:/\\-])(\d)", r"\1 \2 \3", text) 20 | 21 | # 4pm => 4 pm 22 | text = re.sub(r"(\d+)([a-zA-Z])", r"\1 \2", text) 23 | 24 | # safe guard to avoid multiple spaces 25 | text = re.sub(r"\s+", " ", text) 26 | return text 27 | 28 | 29 | class UtteranceTokenizer: 30 | """A Spacy-based tokenizer with some heuristics for user utterances.""" 31 | 32 | def __init__(self, spacy_model_name: str = "en_core_web_md") -> None: 33 | self._spacy_nlp: Language = spacy.load(spacy_model_name) 34 | 35 | def tokenize(self, utterance_str: str) -> List[str]: 36 | """Tokenizes the utterance string and returns a list of tokens. 37 | """ 38 | if not utterance_str: 39 | return [] 40 | 41 | if utterance_str == SpecialStrings.NULL: 42 | # do not tokenize the NULL special string 43 | return [utterance_str] 44 | 45 | tokens: List[str] = sum( 46 | [ 47 | tokenize_datetime(token.text).split(" ") 48 | for token in self._spacy_nlp(utterance_str) 49 | ], 50 | [], 51 | ) 52 | return tokens 53 | -------------------------------------------------------------------------------- /src/dataflow/core/utterance_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | import re 4 | from typing import List 5 | 6 | from dataflow.core.constants import SpecialStrings 7 | from dataflow.core.dialogue import AgentUtterance, UserUtterance 8 | from dataflow.core.utterance_tokenizer import UtteranceTokenizer 9 | 10 | 11 | def clean_utterance_text(text: str) -> str: 12 | """Removes line breaking and extra spaces in the user utterance.""" 13 | # sometimes the user utterance contains line breaking and extra spaces 14 | text = re.sub(r"\s+", " ", text) 15 | # sometimes the user utterance has leading/ending spaces 16 | text = text.strip() 17 | return text 18 | 19 | 20 | def build_user_utterance( 21 | text: str, utterance_tokenizer: UtteranceTokenizer 22 | ) -> UserUtterance: 23 | text = clean_utterance_text(text) 24 | if not text: 25 | return UserUtterance( 26 | original_text=SpecialStrings.NULL, tokens=[SpecialStrings.NULL] 27 | ) 28 | return UserUtterance(original_text=text, tokens=utterance_tokenizer.tokenize(text)) 29 | 30 | 31 | def build_agent_utterance( 32 | text: str, utterance_tokenizer: UtteranceTokenizer, described_entities: List[str] 33 | ) -> AgentUtterance: 34 | text = clean_utterance_text(text) 35 | if not text: 36 | return AgentUtterance( 37 | original_text=SpecialStrings.NULL, 38 | tokens=[SpecialStrings.NULL], 39 | described_entities=described_entities, 40 | ) 41 | return AgentUtterance( 42 | original_text=text, 43 | tokens=utterance_tokenizer.tokenize(text), 44 | described_entities=described_entities, 45 | ) 46 | -------------------------------------------------------------------------------- /src/dataflow/leaderboard/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | -------------------------------------------------------------------------------- /src/dataflow/leaderboard/create_leaderboard_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | """ 4 | Semantic Machines\N{TRADE MARK SIGN} software. 5 | 6 | Converts native Calflow data to the format used by the leaderboard. 7 | """ 8 | import argparse 9 | import hashlib 10 | import random 11 | from typing import List 12 | 13 | from dataflow.core.dialogue import Dialogue, TurnId 14 | from dataflow.core.io_utils import load_jsonl_file, save_jsonl_file 15 | from dataflow.core.turn_prediction import TurnAnswer, UtteranceWithContext 16 | 17 | 18 | def main( 19 | dataflow_dialogues_jsonl: str, 20 | dialogue_id_prefix: str, 21 | contextualized_turns_file: str, 22 | turn_answers_file: str, 23 | ) -> None: 24 | contextualized_turns: List[UtteranceWithContext] = [] 25 | turn_predictions: List[TurnAnswer] = [] 26 | 27 | for dialogue in load_jsonl_file( 28 | data_jsonl=dataflow_dialogues_jsonl, cls=Dialogue, unit=" dialogues" 29 | ): 30 | for turn_index, turn in enumerate(dialogue.turns): 31 | if turn.skip: 32 | continue 33 | full_dialogue_id = ( 34 | dialogue_id_prefix 35 | + "-" 36 | + hashlib.sha1( 37 | str.encode(dialogue.dialogue_id + ":" + str(turn.turn_index)) 38 | ).hexdigest() 39 | ) 40 | datum_id = TurnId(full_dialogue_id, turn.turn_index) 41 | contextualized_turn = UtteranceWithContext( 42 | datum_id=datum_id, 43 | user_utterance=turn.user_utterance, 44 | context=Dialogue( 45 | dialogue_id=full_dialogue_id, turns=dialogue.turns[:turn_index], 46 | ), 47 | ) 48 | contextualized_turns.append(contextualized_turn) 49 | turn_predictions.append( 50 | TurnAnswer( 51 | datum_id=datum_id, 52 | user_utterance=turn.user_utterance.original_text, 53 | lispress=turn.lispress, 54 | program_execution_oracle=turn.program_execution_oracle, 55 | ) 56 | ) 57 | 58 | random.shuffle(contextualized_turns) 59 | save_jsonl_file(contextualized_turns, contextualized_turns_file) 60 | save_jsonl_file(turn_predictions, turn_answers_file) 61 | 62 | 63 | def add_arguments(argument_parser: argparse.ArgumentParser) -> None: 64 | argument_parser.add_argument( 65 | "--dialogues_jsonl", 66 | help="the jsonl file containing the dialogue data with dataflow programs", 67 | ) 68 | argument_parser.add_argument( 69 | "--contextualized_turns_file", help="the output jsonl file for contextualized turns", 70 | ) 71 | argument_parser.add_argument( 72 | "--turn_answers_file", help="the output jsonl file for turn answers", 73 | ) 74 | argument_parser.add_argument( 75 | "--dialogue_id_prefix", help="dialogue id prefix", 76 | ) 77 | 78 | 79 | if __name__ == "__main__": 80 | cmdline_parser = argparse.ArgumentParser( 81 | description=__doc__, formatter_class=argparse.RawTextHelpFormatter 82 | ) 83 | add_arguments(cmdline_parser) 84 | args = cmdline_parser.parse_args() 85 | 86 | print("Semantic Machines\N{TRADE MARK SIGN} software.") 87 | main( 88 | dataflow_dialogues_jsonl=args.dialogues_jsonl, 89 | dialogue_id_prefix=args.dialogue_id_prefix, 90 | contextualized_turns_file=args.contextualized_turns_file, 91 | turn_answers_file=args.turn_answers_file, 92 | ) 93 | -------------------------------------------------------------------------------- /src/dataflow/leaderboard/create_text_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license.
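# A sketch of the opaque dialogue-id scheme used in create_leaderboard_data.py
# above (the prefix and raw dialogue id here are hypothetical):
#
#   import hashlib
#   full_id = "test" + "-" + hashlib.sha1(str.encode("dlg-123" + ":" + "2")).hexdigest()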
3 | """ 4 | Semantic Machines\N{TRADE MARK SIGN} software. 5 | 6 | Creates text data (source-target pairs) to be used for training OpenNMT models. 7 | """ 8 | import argparse 9 | import dataclasses 10 | from typing import Dict, Iterator 11 | 12 | import jsons 13 | from tqdm import tqdm 14 | 15 | from dataflow.core.dialogue import AgentUtterance, Turn 16 | from dataflow.core.turn_prediction import UtteranceWithContext 17 | from dataflow.onmt_helpers.create_onmt_text_data import ( 18 | OnmtTextDatum, 19 | create_context_turns, 20 | create_onmt_text_datum_for_turn, 21 | ) 22 | 23 | # We assume all dialogues start from turn 0. 24 | # This is true for MultiWoZ and CalFlow datasets. 25 | _MIN_TURN_INDEX = 0 26 | 27 | 28 | def create_onmt_text_data_for_contextualized_turn( 29 | contextualized_turn: UtteranceWithContext, 30 | num_context_turns: int, 31 | min_turn_index: int, 32 | include_program: bool, 33 | include_agent_utterance: bool, 34 | include_described_entities: bool, 35 | ) -> Iterator[OnmtTextDatum]: 36 | """Yields OnmtTextDatum for a dialogue.""" 37 | turn_lookup: Dict[int, Turn] = { 38 | turn.turn_index: turn for turn in contextualized_turn.context.turns 39 | } 40 | context_turns = create_context_turns( 41 | turn_lookup=turn_lookup, 42 | curr_turn_index=contextualized_turn.datum_id.turn_index, 43 | num_context_turns=num_context_turns, 44 | min_turn_index=min_turn_index, 45 | ) 46 | onmt_text_datum = create_onmt_text_datum_for_turn( 47 | dialogue_id=contextualized_turn.datum_id.dialogue_id, 48 | curr_turn=Turn( 49 | turn_index=contextualized_turn.datum_id.turn_index, 50 | user_utterance=contextualized_turn.user_utterance, 51 | agent_utterance=AgentUtterance( 52 | original_text="", tokens=[], described_entities=[] 53 | ), 54 | lispress="()", 55 | skip=False, 56 | ), 57 | context_turns=context_turns, 58 | include_program=include_program, 59 | include_agent_utterance=include_agent_utterance, 60 | include_described_entities=include_described_entities, 61 | ) 62 | yield onmt_text_datum 63 | 64 | 65 | def main( 66 | dataflow_dialogues_jsonl: str, 67 | num_context_turns: int, 68 | min_turn_index: int, 69 | include_program: bool, 70 | include_agent_utterance: bool, 71 | include_described_entities: bool, 72 | onmt_text_data_outbase: str, 73 | ) -> None: 74 | fps = OnmtTextDatum.create_output_files(onmt_text_data_outbase) 75 | 76 | for line in tqdm(open(dataflow_dialogues_jsonl), unit=" contextualized turns"): 77 | contextualized_turn = jsons.loads(line.strip(), UtteranceWithContext) 78 | for onmt_text_datum in create_onmt_text_data_for_contextualized_turn( 79 | contextualized_turn=contextualized_turn, 80 | num_context_turns=num_context_turns, 81 | min_turn_index=min_turn_index, 82 | include_program=include_program, 83 | include_agent_utterance=include_agent_utterance, 84 | include_described_entities=include_described_entities, 85 | ): 86 | for field_name, field_value in dataclasses.asdict(onmt_text_datum).items(): 87 | fp = fps[field_name] 88 | fp.write(field_value) 89 | fp.write("\n") 90 | 91 | for _, fp in fps.items(): 92 | fp.close() 93 | 94 | 95 | def add_arguments(argument_parser: argparse.ArgumentParser) -> None: 96 | argument_parser.add_argument( 97 | "--dialogues_jsonl", 98 | help="the jsonl file containing the dialogue data with dataflow programs", 99 | ) 100 | argument_parser.add_argument( 101 | "--num_context_turns", 102 | type=int, 103 | help="number of previous turns to be included in the source sequence", 104 | ) 105 | argument_parser.add_argument( 106 | "--include_program", 107 | 
default=False, 108 | action="store_true", 109 | help="if True, include the gold program for the context turn parts", 110 | ) 111 | argument_parser.add_argument( 112 | "--include_agent_utterance", 113 | default=False, 114 | action="store_true", 115 | help="if True, include the gold agent utterance for the context turn parts", 116 | ) 117 | argument_parser.add_argument( 118 | "--include_described_entities", 119 | default=False, 120 | action="store_true", 121 | help="if True, include the described entities field for the context turn parts", 122 | ) 123 | argument_parser.add_argument( 124 | "--onmt_text_data_outbase", 125 | help="the output file basename for the extracted text data for OpenNMT", 126 | ) 127 | 128 | 129 | if __name__ == "__main__": 130 | cmdline_parser = argparse.ArgumentParser( 131 | description=__doc__, formatter_class=argparse.RawTextHelpFormatter 132 | ) 133 | add_arguments(cmdline_parser) 134 | args = cmdline_parser.parse_args() 135 | 136 | print("Semantic Machines\N{TRADE MARK SIGN} software.") 137 | main( 138 | dataflow_dialogues_jsonl=args.dialogues_jsonl, 139 | num_context_turns=args.num_context_turns, 140 | min_turn_index=_MIN_TURN_INDEX, 141 | include_program=args.include_program, 142 | include_agent_utterance=args.include_agent_utterance, 143 | include_described_entities=args.include_described_entities, 144 | onmt_text_data_outbase=args.onmt_text_data_outbase, 145 | ) 146 | -------------------------------------------------------------------------------- /src/dataflow/leaderboard/evaluate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | """ 4 | Semantic Machines\N{TRADE MARK SIGN} software. 5 | 6 | Evaluation script for the leaderboard. 
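Example invocation (hypothetical paths):

    python -m dataflow.leaderboard.evaluate \
        --predictions_jsonl predictions.jsonl \
        --gold_jsonl test.answers.jsonl \
        --scores_json scores.json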
7 | """ 8 | import argparse 9 | import json 10 | from typing import Iterable, List, Optional, Set, Tuple 11 | 12 | from dataflow.core.dialogue import TurnId 13 | from dataflow.core.io_utils import load_jsonl_file 14 | from dataflow.core.lispress import try_round_trip 15 | from dataflow.core.turn_prediction import TurnAnswer, TurnPrediction, missing_prediction 16 | 17 | 18 | def evaluate_prediction_exact_match( 19 | pred: TurnPrediction, gold: TurnAnswer 20 | ) -> Tuple[bool, bool]: 21 | assert pred.datum_id == gold.datum_id, f"mismatched data: {pred}, {gold}" 22 | pred_lispress = try_round_trip(pred.lispress) 23 | gold_lispress = try_round_trip(gold.lispress) 24 | if pred_lispress != gold_lispress: 25 | print( 26 | f"Misprediction on {gold.datum_id.dialogue_id}:{gold.datum_id.turn_index} | {gold.user_utterance}\nPred: {pred_lispress}\nGold: {gold_lispress}\n" 27 | ) 28 | elif not gold.program_execution_oracle.refer_are_correct: 29 | print( 30 | f"Example {gold.datum_id.dialogue_id}:{gold.datum_id.turn_index} can't be correct because the refer call is not correct.\n" 31 | ) 32 | return ( 33 | pred_lispress == gold_lispress 34 | and gold.program_execution_oracle.refer_are_correct, 35 | pred_lispress == gold_lispress, 36 | ) 37 | 38 | 39 | def evaluate_predictions_exact_match( 40 | preds_and_golds: Iterable[Tuple[TurnPrediction, TurnAnswer]] 41 | ) -> Tuple[float, float]: 42 | correct = 0 43 | correct_ignoring_refer = 0 44 | total = 0 45 | for pred, gold in preds_and_golds: 46 | total += 1 47 | (right, right_ignoring_refer) = evaluate_prediction_exact_match(pred, gold) 48 | correct += int(right) 49 | correct_ignoring_refer += int(right_ignoring_refer) 50 | 51 | return ( 52 | correct / total if total else 0, 53 | correct_ignoring_refer / total if total else 0, 54 | ) 55 | 56 | 57 | def collate( 58 | preds: List[TurnPrediction], 59 | golds: List[TurnAnswer], 60 | datum_ids: Optional[Set[TurnId]], 61 | ) -> List[Tuple[TurnPrediction, TurnAnswer]]: 62 | """ 63 | For each datum `gold` in `golds`, if `gold.datum_id` is in `datum_ids`, 64 | return a tuple of `(pred, gold)`, where `pred` is in `preds` and 65 | `pred.datum_id == gold.datum_id` 66 | If no such `pred` exists, `gold` is paired with a special "missing" 67 | prediction which is never correct. 
68 | """ 69 | pred_by_id = {pred.datum_id: pred for pred in preds} 70 | pred_ids = set(pred_by_id.keys()) 71 | gold_ids = {gold.datum_id for gold in golds} 72 | if datum_ids is not None: 73 | gold_ids &= datum_ids 74 | missing_ids = gold_ids - pred_ids 75 | extra_ids = pred_ids - gold_ids 76 | if missing_ids: 77 | print(f"Gold turns not predicted: {list(missing_ids)}") 78 | if extra_ids: 79 | pass 80 | return [ 81 | (pred_by_id.get(gold.datum_id, missing_prediction(gold.datum_id)), gold) 82 | for gold in golds 83 | if datum_ids is None or gold.datum_id in datum_ids 84 | ] 85 | 86 | 87 | def evaluate_prediction_file( 88 | predictions_jsonl: str, gold_jsonl: str, datum_ids_jsonl: Optional[str] 89 | ) -> Tuple[float, float]: 90 | preds = list(load_jsonl_file(predictions_jsonl, TurnPrediction, verbose=False)) 91 | golds = list(load_jsonl_file(gold_jsonl, TurnAnswer, verbose=False)) 92 | datum_ids = ( 93 | None 94 | if datum_ids_jsonl is None 95 | else set(load_jsonl_file(data_jsonl=datum_ids_jsonl, cls=TurnId, verbose=False)) 96 | ) 97 | return evaluate_predictions_exact_match(collate(preds, golds, datum_ids)) 98 | 99 | 100 | def add_arguments(argument_parser: argparse.ArgumentParser) -> None: 101 | argument_parser.add_argument( 102 | "--predictions_jsonl", help="the predictions jsonl file to evaluate", 103 | ) 104 | argument_parser.add_argument( 105 | "--gold_jsonl", help="the gold jsonl file to evaluate against", 106 | ) 107 | argument_parser.add_argument( 108 | "--datum_ids_jsonl", default=None, help="if set, only evaluate on these turns", 109 | ) 110 | argument_parser.add_argument("--scores_json", help="output scores json file") 111 | 112 | 113 | def write_accuracy_json( 114 | accuracies: Tuple[float, float], scores_json_filename: str 115 | ) -> None: 116 | (accuracy, accuracy_ignoring_refer) = accuracies 117 | with open(scores_json_filename, mode="w", encoding="utf8") as scores_json_file: 118 | scores_json_file.write( 119 | json.dumps( 120 | { 121 | "accuracy": accuracy, 122 | "accuracy_ignorning_refer": accuracy_ignoring_refer, 123 | } 124 | ) 125 | ) 126 | 127 | 128 | def main(): 129 | cmdline_parser = argparse.ArgumentParser( 130 | description=__doc__, formatter_class=argparse.RawTextHelpFormatter 131 | ) 132 | add_arguments(cmdline_parser) 133 | args = cmdline_parser.parse_args() 134 | 135 | accuracies = evaluate_prediction_file( 136 | predictions_jsonl=args.predictions_jsonl, 137 | gold_jsonl=args.gold_jsonl, 138 | datum_ids_jsonl=args.datum_ids_jsonl, 139 | ) 140 | write_accuracy_json(accuracies, args.scores_json) 141 | 142 | 143 | if __name__ == "__main__": 144 | print("Semantic Machines\N{TRADE MARK SIGN} software.") 145 | main() 146 | -------------------------------------------------------------------------------- /src/dataflow/leaderboard/predict.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | """ 4 | Semantic Machines\N{TRADE MARK SIGN} software. 5 | 6 | Creates the prediction files from onmt_translate output for the leaderboard. 
7 | """ 8 | import argparse 9 | from typing import List 10 | 11 | import jsons 12 | from more_itertools import chunked 13 | 14 | from dataflow.core.dialogue import TurnId 15 | from dataflow.core.io_utils import save_jsonl_file 16 | from dataflow.core.turn_prediction import TurnPrediction 17 | 18 | 19 | def build_prediction_report_datum( 20 | datum_id_line: str, src_line: str, nbest_lines: List[str], 21 | ) -> TurnPrediction: 22 | datum_id = jsons.loads(datum_id_line.strip(), TurnId) 23 | return TurnPrediction( 24 | datum_id=datum_id, 25 | user_utterance=src_line.strip(), 26 | lispress=nbest_lines[0].strip(), 27 | ) 28 | 29 | 30 | def create_onmt_prediction_report( 31 | datum_id_jsonl: str, src_txt: str, ref_txt: str, nbest_txt: str, nbest: int, 32 | ): 33 | prediction_report = [ 34 | build_prediction_report_datum( 35 | datum_id_line=datum_id_line, src_line=src_line, nbest_lines=nbest_lines, 36 | ) 37 | for datum_id_line, src_line, ref_line, nbest_lines in zip( 38 | open(datum_id_jsonl), 39 | open(src_txt), 40 | open(ref_txt), 41 | chunked(open(nbest_txt), nbest), 42 | ) 43 | ] 44 | save_jsonl_file(prediction_report, "predictions.jsonl") 45 | 46 | 47 | def main( 48 | datum_id_jsonl: str, src_txt: str, ref_txt: str, nbest_txt: str, nbest: int, 49 | ) -> None: 50 | """Creates 1-best predictions and saves them to files.""" 51 | create_onmt_prediction_report( 52 | datum_id_jsonl=datum_id_jsonl, 53 | src_txt=src_txt, 54 | ref_txt=ref_txt, 55 | nbest_txt=nbest_txt, 56 | nbest=nbest, 57 | ) 58 | 59 | 60 | def add_arguments(argument_parser: argparse.ArgumentParser) -> None: 61 | argument_parser.add_argument("--datum_id_jsonl", help="datum ID file") 62 | argument_parser.add_argument("--src_txt", help="source sequence file") 63 | argument_parser.add_argument("--ref_txt", help="target sequence reference file") 64 | argument_parser.add_argument("--nbest_txt", help="onmt_translate output file") 65 | argument_parser.add_argument("--nbest", type=int, help="number of hypos per datum") 66 | 67 | 68 | if __name__ == "__main__": 69 | cmdline_parser = argparse.ArgumentParser( 70 | description=__doc__, formatter_class=argparse.RawTextHelpFormatter 71 | ) 72 | add_arguments(cmdline_parser) 73 | args = cmdline_parser.parse_args() 74 | 75 | print("Semantic Machines\N{TRADE MARK SIGN} software.") 76 | main( 77 | datum_id_jsonl=args.datum_id_jsonl, 78 | src_txt=args.src_txt, 79 | ref_txt=args.ref_txt, 80 | nbest_txt=args.nbest_txt, 81 | nbest=args.nbest, 82 | ) 83 | -------------------------------------------------------------------------------- /src/dataflow/multiwoz/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | -------------------------------------------------------------------------------- /src/dataflow/multiwoz/belief_state_tracker_datum.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | from dataclasses import dataclass 4 | from typing import Dict, List, Optional 5 | 6 | 7 | @dataclass(frozen=True) 8 | class Slot: 9 | name: str 10 | value: str 11 | 12 | 13 | @dataclass(frozen=True) 14 | class BeliefState: 15 | slots_for_domain: Dict[str, List[Slot]] 16 | 17 | def __eq__(self, other: object) -> bool: 18 | if not isinstance(other, BeliefState): 19 | raise NotImplementedError() 20 | domains = set(self.slots_for_domain.keys()) 21 | if domains != set(other.slots_for_domain.keys()): 22 | return False 23 | for domain in domains: 24 | if self.slots_for_domain[domain] != other.slots_for_domain[domain]: 25 | return False 26 | return True 27 | 28 | 29 | @dataclass(frozen=True) 30 | class BeliefStateTrackerDatum: 31 | """A datum for the belief state tracker. 32 | 33 | It is used as the universal data format for both gold and hypos in the 34 | evaluation script `evaluate_belief_state_predictions.py`. 35 | See factory methods in the script `create_belief_state_tracker_data.py`. 36 | """ 37 | 38 | dialogue_id: str 39 | turn_index: int 40 | belief_state: BeliefState 41 | prev_agent_utterance: Optional[str] = None 42 | curr_user_utterance: Optional[str] = None 43 | 44 | 45 | def pretty_print_belief_state(belief_state: BeliefState) -> str: 46 | return "\n".join( 47 | [ 48 | "{}\t{}".format( 49 | domain, 50 | " | ".join(["{}={}".format(slot.name, slot.value) for slot in slots]), 51 | ) 52 | for domain, slots in sorted( 53 | belief_state.slots_for_domain.items(), key=lambda x: x[0] 54 | ) 55 | ] 56 | ) 57 | 58 | 59 | def sort_slots(slots_for_domain: Dict[str, List[Slot]]): 60 | """Sorts slots for each domain.""" 61 | for domain in slots_for_domain: 62 | slots_for_domain[domain].sort(key=lambda x: (x.name, x.value)) 63 | -------------------------------------------------------------------------------- /src/dataflow/multiwoz/create_belief_state_tracker_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | """ 4 | Semantic Machines\N{TRADE MARK SIGN} software. 5 | 6 | Creates BeliefStateTrackerDatum from different sources, e.g., TRADE-processed dialogues. 7 | """ 8 | 9 | import argparse 10 | import json 11 | from typing import Any, Dict, Iterator, List 12 | 13 | from dataflow.core.io_utils import save_jsonl_file 14 | from dataflow.multiwoz.belief_state_tracker_datum import ( 15 | BeliefState, 16 | BeliefStateTrackerDatum, 17 | Slot, 18 | sort_slots, 19 | ) 20 | from dataflow.multiwoz.ontology import DATAFLOW_SLOT_NAMES_FOR_DOMAIN 21 | from dataflow.multiwoz.trade_dst_utils import ( 22 | flatten_belief_state, 23 | get_domain_and_slot_name, 24 | ) 25 | 26 | 27 | def build_belief_state_from_belief_dict( 28 | belief_dict: Dict[str, str], strict: bool 29 | ) -> BeliefState: 30 | slots_for_domain: Dict[str, List[Slot]] = dict() 31 | for slot_fullname, slot_value in belief_dict.items(): 32 | domain, slot_name = get_domain_and_slot_name(slot_fullname) 33 | if strict: 34 | assert ( 35 | slot_name in DATAFLOW_SLOT_NAMES_FOR_DOMAIN[domain] 36 | ), 'slot "{}" is not in ontology for domain "{}"'.format(slot_name, domain) 37 | elif slot_name not in DATAFLOW_SLOT_NAMES_FOR_DOMAIN[domain]: 38 | # NOTE: We only print a warning. The slot will still be included in the 39 | # belief state for evaluation. 40 | # If we assume the Belief State Tracker knows the ontology in advance, then 41 | # we can remove the slot from the prediction.
42 | print( 43 | 'slot "{}" is not in ontology for domain "{}"'.format(slot_name, domain) 44 | ) 45 | if domain not in slots_for_domain: 46 | slots_for_domain[domain] = [] 47 | slots_for_domain[domain].append(Slot(name=slot_name, value=slot_value)) 48 | sort_slots(slots_for_domain) 49 | return BeliefState(slots_for_domain=slots_for_domain) 50 | 51 | 52 | def build_belief_state_from_trade_turn(trade_turn: Dict[str, Any]) -> BeliefState: 53 | """Returns a BeliefState object from a TRADE turn.""" 54 | # do not drop any slots or change any slot values 55 | belief_dict = flatten_belief_state( 56 | belief_state=trade_turn["belief_state"], 57 | keep_all_domains=True, 58 | remove_none=False, 59 | ) 60 | return build_belief_state_from_belief_dict(belief_dict=belief_dict, strict=True) 61 | 62 | 63 | def build_belief_state_tracker_data_from_trade_dialogue( 64 | trade_dialogue: Dict[str, Any], 65 | ) -> Iterator[BeliefStateTrackerDatum]: 66 | for trade_turn in trade_dialogue["dialogue"]: 67 | yield BeliefStateTrackerDatum( 68 | dialogue_id=trade_dialogue["dialogue_idx"], 69 | turn_index=int(trade_turn["turn_idx"]), 70 | belief_state=build_belief_state_from_trade_turn(trade_turn), 71 | prev_agent_utterance=trade_turn["system_transcript"], 72 | curr_user_utterance=trade_turn["transcript"], 73 | ) 74 | 75 | 76 | def main(trade_data_file: str, belief_state_tracker_data_file: str) -> None: 77 | with open(trade_data_file) as fp: 78 | trade_dialogues = json.loads(fp.read().strip()) 79 | belief_state_tracker_data = [ 80 | datum 81 | for trade_dialogue in trade_dialogues 82 | for datum in build_belief_state_tracker_data_from_trade_dialogue(trade_dialogue) 83 | ] 84 | 85 | save_jsonl_file( 86 | data=belief_state_tracker_data, 87 | data_jsonl=belief_state_tracker_data_file, 88 | remove_null=True, 89 | ) 90 | 91 | 92 | def add_arguments(argument_parser: argparse.ArgumentParser) -> None: 93 | argument_parser.add_argument( 94 | "--trade_data_file", help="TRADE processed dialogues file", 95 | ) 96 | argument_parser.add_argument( 97 | "--belief_state_tracker_data_file", 98 | help="output jsonl file of BeliefStateTrackerDatum", 99 | ) 100 | 101 | 102 | if __name__ == "__main__": 103 | cmdline_parser = argparse.ArgumentParser( 104 | description=__doc__, formatter_class=argparse.RawTextHelpFormatter 105 | ) 106 | add_arguments(cmdline_parser) 107 | args = cmdline_parser.parse_args() 108 | 109 | print("Semantic Machines\N{TRADE MARK SIGN} software.") 110 | main( 111 | trade_data_file=args.trade_data_file, 112 | belief_state_tracker_data_file=args.belief_state_tracker_data_file, 113 | ) 114 | -------------------------------------------------------------------------------- /src/dataflow/multiwoz/evaluate_belief_state_predictions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | """ 4 | Semantic Machines\N{TRADE MARK SIGN} software. 5 | 6 | Evaluates belief state tracking predictions. 
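A turn is counted as correct only when the entire predicted belief state matches
the gold belief state; per-slot accuracies and dialogue-level statistics (e.g.,
the number of turns before the first error) are reported alongside it.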
7 | """ 8 | 9 | import argparse 10 | from dataclasses import dataclass, field 11 | from typing import Dict, cast 12 | 13 | import jsons 14 | 15 | from dataflow.core.io_utils import load_jsonl_file_and_build_lookup 16 | from dataflow.multiwoz.create_belief_state_prediction_report import ( 17 | BeliefStatePredictionReportDatum, 18 | ) 19 | from dataflow.multiwoz.ontology import DATAFLOW_SLOT_NAMES_FOR_DOMAIN 20 | 21 | 22 | @dataclass 23 | class EvaluationStats: 24 | num_total_turns: int = 0 25 | num_correct_turns: int = 0 26 | num_correct_turns_after_first_error: int = 0 27 | num_turns_before_first_error: int = 0 28 | # key: slot name (either with domain or without domain) 29 | # value: number of correct turns 30 | num_correct_turns_for_slot: Dict[str, int] = field(default_factory=dict) 31 | num_total_dialogues: int = 0 32 | num_correct_dialogues: int = 0 33 | 34 | @property 35 | def accuracy(self) -> float: 36 | if self.num_total_turns == 0: 37 | return 0 38 | return self.num_correct_turns / self.num_total_turns 39 | 40 | @property 41 | def accuracy_for_slot(self) -> Dict[str, float]: 42 | if self.num_total_turns == 0: 43 | return { 44 | slot_name: 0 for slot_name, _ in self.num_correct_turns_for_slot.items() 45 | } 46 | return { 47 | slot_name: num_correct_turns / self.num_total_turns 48 | for slot_name, num_correct_turns in self.num_correct_turns_for_slot.items() 49 | } 50 | 51 | @property 52 | def ave_num_turns_before_first_error(self) -> float: 53 | if self.num_total_dialogues == 0: 54 | return 0 55 | return self.num_turns_before_first_error / self.num_total_dialogues 56 | 57 | @property 58 | def pct_correct_dialogues(self) -> float: 59 | if self.num_total_dialogues == 0: 60 | return 0 61 | return self.num_correct_dialogues / self.num_total_dialogues 62 | 63 | def __iadd__(self, other: object) -> "EvaluationStats": 64 | if not isinstance(other, EvaluationStats): 65 | raise ValueError() 66 | self.num_total_turns += other.num_total_turns 67 | self.num_correct_turns += other.num_correct_turns 68 | self.num_correct_turns_after_first_error += ( 69 | other.num_correct_turns_after_first_error 70 | ) 71 | self.num_turns_before_first_error += other.num_turns_before_first_error 72 | self.num_correct_turns_for_slot = cast( 73 | Dict[str, int], self.num_correct_turns_for_slot 74 | ) 75 | other.num_correct_turns_for_slot = cast( 76 | Dict[str, int], other.num_correct_turns_for_slot 77 | ) 78 | for (slot_name, num_correct_turns,) in other.num_correct_turns_for_slot.items(): 79 | if slot_name not in self.num_correct_turns_for_slot: 80 | self.num_correct_turns_for_slot[slot_name] = 0 81 | self.num_correct_turns_for_slot[slot_name] += num_correct_turns 82 | self.num_total_dialogues += other.num_total_dialogues 83 | self.num_correct_dialogues += other.num_correct_dialogues 84 | 85 | return self 86 | 87 | def __add__(self, other: object) -> "EvaluationStats": 88 | if not isinstance(other, EvaluationStats): 89 | raise ValueError() 90 | result = EvaluationStats() 91 | result += self 92 | result += other 93 | 94 | return result 95 | 96 | 97 | def evaluate_dialogue( 98 | dialogue: Dict[int, BeliefStatePredictionReportDatum], 99 | ) -> EvaluationStats: 100 | """Evaluates a dialogue.""" 101 | stats = EvaluationStats() 102 | seen_error = False 103 | prediction_report_datum: BeliefStatePredictionReportDatum 104 | for turn_index, prediction_report_datum in sorted( 105 | dialogue.items(), key=lambda x: int(x[0]) 106 | ): 107 | assert turn_index == prediction_report_datum.turn_index 108 | 109 | stats.num_total_turns 
 += 1 110 | 111 | if prediction_report_datum.is_correct: 112 | stats.num_correct_turns += 1 113 | if seen_error: 114 | stats.num_correct_turns_after_first_error += 1 115 | else: 116 | stats.num_turns_before_first_error += 1 117 | else: 118 | seen_error = True 119 | 120 | for domain, slot_names in DATAFLOW_SLOT_NAMES_FOR_DOMAIN.items(): 121 | gold_slots = prediction_report_datum.gold.slots_for_domain.get(domain, []) 122 | gold_slot_value_lookup = {slot.name: slot.value for slot in gold_slots} 123 | assert len(gold_slot_value_lookup) == len(gold_slots) 124 | 125 | hypo_slots = prediction_report_datum.prediction.slots_for_domain.get( 126 | domain, [] 127 | ) 128 | hypo_slot_value_lookup = {slot.name: slot.value for slot in hypo_slots} 129 | assert len(hypo_slot_value_lookup) == len(hypo_slots) 130 | 131 | for slot_name in slot_names: 132 | gold_slot_value = gold_slot_value_lookup.get(slot_name) 133 | hypo_slot_value = hypo_slot_value_lookup.get(slot_name) 134 | 135 | # these two values should be treated as null and the slots should not be present in the belief state 136 | assert gold_slot_value not in ["", "not mentioned"] 137 | assert hypo_slot_value not in ["", "not mentioned"] 138 | 139 | if gold_slot_value != hypo_slot_value: 140 | continue 141 | 142 | slot_fullname = "{}-{}".format(domain, slot_name) 143 | if slot_fullname not in stats.num_correct_turns_for_slot: 144 | stats.num_correct_turns_for_slot[slot_fullname] = 0 145 | stats.num_correct_turns_for_slot[slot_fullname] += 1 146 | 147 | if not seen_error: 148 | stats.num_correct_dialogues += 1 149 | stats.num_total_dialogues += 1 150 | 151 | return stats 152 | 153 | 154 | def evaluate_dataset( 155 | prediction_report_lookup: Dict[str, Dict[int, BeliefStatePredictionReportDatum]], 156 | ) -> EvaluationStats: 157 | evaluation_stats = EvaluationStats() 158 | for _dialogue_id, dialogue in prediction_report_lookup.items(): 159 | stats = evaluate_dialogue(dialogue=dialogue) 160 | evaluation_stats += stats 161 | 162 | return evaluation_stats 163 | 164 | 165 | def main(prediction_report_jsonl: str, outbase: str) -> str: 166 | prediction_report_lookup = load_jsonl_file_and_build_lookup( 167 | data_jsonl=prediction_report_jsonl, 168 | cls=BeliefStatePredictionReportDatum, 169 | primary_key_getter=lambda x: x.dialogue_id, 170 | secondary_key_getter=lambda x: x.turn_index, 171 | ) 172 | 173 | stats = evaluate_dataset(prediction_report_lookup=prediction_report_lookup) 174 | scores_file = outbase + ".scores.json" 175 | with open(scores_file, "w") as fp: 176 | fp.write(jsons.dumps(stats, {"indent": 2, "sort_keys": True})) 177 | fp.write("\n") 178 | 179 | return scores_file 180 | 181 | 182 | def add_arguments(argument_parser: argparse.ArgumentParser) -> None: 183 | argument_parser.add_argument( 184 | "--prediction_report_jsonl", help="the prediction report jsonl file" 185 | ) 186 | argument_parser.add_argument("--outbase", help="output files basename") 187 | 188 | 189 | if __name__ == "__main__": 190 | cmdline_parser = argparse.ArgumentParser( 191 | description=__doc__, formatter_class=argparse.RawTextHelpFormatter 192 | ) 193 | add_arguments(cmdline_parser) 194 | args = cmdline_parser.parse_args() 195 | 196 | print("Semantic Machines\N{TRADE MARK SIGN} software.") 197 | main(prediction_report_jsonl=args.prediction_report_jsonl, outbase=args.outbase) 198 | -------------------------------------------------------------------------------- /src/dataflow/multiwoz/ontology.py:
-------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | TRADE_SLOT_NAMES_FOR_DOMAIN = { 4 | "attraction": ["area", "name", "type"], 5 | "hotel": [ 6 | "area", 7 | "book day", 8 | "book people", 9 | "book stay", 10 | "internet", 11 | "name", 12 | "parking", 13 | "pricerange", 14 | "stars", 15 | "type", 16 | ], 17 | "restaurant": [ 18 | "area", 19 | "book day", 20 | "book people", 21 | "book time", 22 | "food", 23 | "name", 24 | "pricerange", 25 | ], 26 | "taxi": ["arriveby", "departure", "destination", "leaveat"], 27 | "train": ["arriveby", "book people", "day", "departure", "destination", "leaveat"], 28 | "bus": ["day", "departure", "destination", "leaveat"], 29 | "hospital": ["department"], 30 | } 31 | 32 | # The slot names used in dataflow. 33 | # NOTE: We cannot use the original TRADE_SLOT_NAMES_FOR_DOMAIN because space is not allowed in dataflow slot names. 34 | DATAFLOW_SLOT_NAMES_FOR_DOMAIN = { 35 | "attraction": ["area", "name", "type"], 36 | "hotel": [ 37 | "area", 38 | "book-day", 39 | "book-people", 40 | "book-stay", 41 | "internet", 42 | "name", 43 | "parking", 44 | "pricerange", 45 | "stars", 46 | "type", 47 | ], 48 | "restaurant": [ 49 | "area", 50 | "book-day", 51 | "book-people", 52 | "book-time", 53 | "food", 54 | "name", 55 | "pricerange", 56 | ], 57 | "taxi": ["arriveby", "departure", "destination", "leaveat"], 58 | "train": ["arriveby", "book-people", "day", "departure", "destination", "leaveat"], 59 | "bus": ["day", "departure", "destination", "leaveat"], 60 | "hospital": ["department"], 61 | } 62 | -------------------------------------------------------------------------------- /src/dataflow/multiwoz/patch_trade_dialogues.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | """ 4 | Semantic Machines\N{TRADE MARK SIGN} software. 5 | 6 | Patches TRADE-processed dialogues. 7 | 8 | In TRADE, there are extra steps to fix belief state labels after the data are dumped from `create_data.py`. 9 | This makes evaluation and comparison difficult because those label corrections and the evaluation are embedded in the 10 | training code rather than in separate CLI scripts. 11 | This new script applies the TRADE label corrections (`fix_general_label_error`) and re-dumps the dialogues in the same format. 12 | 13 | NOTE: This only patches the "belief_state". Other fields including "turn_label" are unchanged. Thus, there can be 14 | inconsistency between "belief_state" and "turn_label".
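Example invocation (illustrative paths):

    python -m dataflow.multiwoz.patch_trade_dialogues \
        --trade_data_file train_dials.json \
        --outbase output/train

This writes output/train_dials.json (the patched dialogues) and
output/train_need_review_turns.json (turns flagged for manual review).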
15 | """ 16 | import argparse 17 | import json 18 | from typing import Dict, List, Tuple 19 | 20 | from dataflow.multiwoz.ontology import TRADE_SLOT_NAMES_FOR_DOMAIN 21 | from dataflow.multiwoz.trade_dst_utils import ( 22 | fix_general_label_error, 23 | get_domain_and_slot_name, 24 | ) 25 | 26 | 27 | def validate_belief_dict(trade_belief_dict: Dict[str, str]) -> List[str]: 28 | """Validates the belief dict and returns the domains with all "none" values.""" 29 | slots_lookup: Dict[str, Dict[str, str]] = {} 30 | for slot_fullname, slot_value in trade_belief_dict.items(): 31 | # slot value should not be "empty" or "not mentioned" 32 | assert slot_value not in ["", "not mentioned"] 33 | 34 | domain, slot_name = get_domain_and_slot_name(slot_fullname) 35 | assert slot_name in TRADE_SLOT_NAMES_FOR_DOMAIN[domain] 36 | 37 | if domain not in slots_lookup: 38 | slots_lookup[domain] = {} 39 | slots_lookup[domain][slot_name] = slot_value 40 | 41 | # active domains should have at least one slot that is not "none" 42 | all_none_domains = [] 43 | for domain, slots in slots_lookup.items(): 44 | if all([slot_value == "none" for slot_value in slots.values()]): 45 | all_none_domains.append(domain) 46 | return all_none_domains 47 | 48 | 49 | def main(trade_data_file: str, outbase: str) -> Tuple[str, str]: 50 | trade_dialogues = json.load(open(trade_data_file, "r")) 51 | # turns that need manual review 52 | need_review_turns = [] 53 | for trade_dialogue in trade_dialogues: 54 | for trade_turn in trade_dialogue["dialogue"]: 55 | trade_belief_dict = fix_general_label_error(trade_turn["belief_state"]) 56 | for item in trade_turn["belief_state"]: 57 | assert item["act"] == "inform" 58 | all_none_domains = validate_belief_dict(trade_belief_dict) 59 | if all_none_domains: 60 | is_last_turn = int(trade_turn["turn_idx"]) + 1 == len( 61 | trade_dialogue["dialogue"] 62 | ) 63 | need_review_turns.append( 64 | { 65 | "dialogueId": trade_dialogue["dialogue_idx"], 66 | "turnIndex": trade_turn["turn_idx"], 67 | "isLastTurn": is_last_turn, 68 | "prevAgentUtterance": trade_turn["system_transcript"], 69 | "currUserUtterance": trade_turn["transcript"], 70 | "beliefDict": trade_belief_dict, 71 | "allNoneDomains": all_none_domains, 72 | } 73 | ) 74 | 75 | trade_turn["belief_state"] = [ 76 | {"slots": [[slot_fullname, slot_value]], "act": "inform"} 77 | for slot_fullname, slot_value in trade_belief_dict.items() 78 | ] 79 | 80 | patched_dials_file = outbase + "_dials.json" 81 | with open(patched_dials_file, "w") as fp: 82 | json.dump(trade_dialogues, fp, indent=4) 83 | 84 | need_review_turns_file = outbase + "_need_review_turns.json" 85 | with open(need_review_turns_file, "w") as fp: 86 | json.dump(need_review_turns, fp, indent=2) 87 | 88 | return patched_dials_file, need_review_turns_file 89 | 90 | 91 | if __name__ == "__main__": 92 | cmdline_parser = argparse.ArgumentParser( 93 | description=__doc__, formatter_class=argparse.RawTextHelpFormatter 94 | ) 95 | cmdline_parser.add_argument("--trade_data_file", help="the trade data file") 96 | cmdline_parser.add_argument("--outbase", help="output files base name") 97 | args = cmdline_parser.parse_args() 98 | 99 | print("Semantic Machines\N{TRADE MARK SIGN} software.") 100 | main(trade_data_file=args.trade_data_file, outbase=args.outbase) 101 | -------------------------------------------------------------------------------- /src/dataflow/multiwoz/salience_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 
2 | # Licensed under the MIT license. 3 | from abc import ABC, abstractmethod 4 | from dataclasses import dataclass 5 | from typing import Dict, List, Optional, Set, Tuple 6 | 7 | 8 | @dataclass(frozen=True) 9 | class PartialExecutionResult: 10 | """A partial execution result for an expression at a turn. 11 | """ 12 | 13 | # the values in the dataflow graph 14 | # - key: expression ID 15 | # - value: the underlying value, None means removing the slot from the belief state 16 | values: Dict[str, Optional[str]] 17 | # the constraints in the dataflow graph 18 | constraints: Dict[str, str] 19 | # the refer call logs 20 | # - key: the expression ID that uses the refer call 21 | # - value: the target type 22 | refer_calls: Dict[str, str] 23 | # the typed slot values 24 | # - key: the slot name (without the domain) 25 | # - value: the history of (slotValue, expressionId) for this type (slotValue==None means deletion) 26 | slot_values: Dict[str, List[Tuple[Optional[str], Optional[str]]]] 27 | 28 | 29 | @dataclass(frozen=True) 30 | class ExecutionTrace: 31 | """A complete execution trace for a dialogue up to a turn.""" 32 | 33 | # - key: slot name without the domain 34 | # - value: the history of (slotValue, expressionId) for this type (slotValue==None means deletion) 35 | slot_values: Dict[str, List[Tuple[Optional[str], Optional[str]]]] 36 | 37 | 38 | class SalienceModelBase(ABC): 39 | """An abstract class for the salience model. 40 | """ 41 | 42 | @abstractmethod 43 | def get_salient_value( 44 | self, 45 | target_type: str, 46 | execution_trace: ExecutionTrace, 47 | exclude_values: Set[str], 48 | ) -> Optional[str]: 49 | """Gets the salient mention from dialogue context. 50 | 51 | Args: 52 | target_type: the target type of the salient value 53 | execution_trace: the execution_trace up to the previous turn 54 | exclude_values: values that should not be returned 55 | Returns: 56 | The retrieved salient value for the target type. 57 | """ 58 | raise NotImplementedError() 59 | 60 | 61 | class DummySalienceModel(SalienceModelBase): 62 | """A dummy salience model which always returns None. 63 | """ 64 | 65 | def get_salient_value( 66 | self, 67 | target_type: str, 68 | execution_trace: ExecutionTrace, 69 | exclude_values: Set[str], 70 | ) -> Optional[str]: 71 | """See base class. 72 | 73 | For a dummy salience model, we always return None. 74 | """ 75 | return None 76 | 77 | 78 | class VanillaSalienceModel(SalienceModelBase): 79 | """A vanilla salience model. 80 | """ 81 | 82 | # the ontology for slot types 83 | # it records the compatible slot names for salience calls 84 | # the lookup from slot name to slot type 85 | SLOT_TYPE_ONTOLOGY: Dict[str, List[str]] = { 86 | "PLACE": ["name", "destination", "departure"], 87 | "DAY": ["day", "book-day"], 88 | "TIME": ["book-time", "arriveby", "leaveat"], 89 | "NUMBER": ["book-people", "stars", "book-stay"], 90 | } 91 | SLOT_TYPE_LOOKUP: Dict[str, str] = { 92 | slot_name: slot_type 93 | for slot_type, slot_names in SLOT_TYPE_ONTOLOGY.items() 94 | for slot_name in slot_names 95 | } 96 | SLOT_VALUE_BLOCKLIST: Set[Optional[str]] = {"none", "dontcare", None} 97 | 98 | def get_salient_value( 99 | self, 100 | target_type: str, 101 | execution_trace: ExecutionTrace, 102 | exclude_values: Set[str], 103 | ) -> Optional[str]: 104 | """See base class. 105 | 106 | Currently, this method returns the most recent occurrence of the value that is compatible with the target type.
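Example (illustrative values):

    trace = ExecutionTrace(slot_values={"destination": [("cambridge", "[1]")]})
    model.get_salient_value("destination", trace, exclude_values=set())
    # -> "cambridge"

If the target slot has no usable value, compatible slots of the same type
(e.g., "name" and "departure" for a PLACE slot) are searched via
SLOT_TYPE_ONTOLOGY before giving up and returning None.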
107 | """ 108 | for value, _ in reversed(execution_trace.slot_values.get(target_type, [])): 109 | if value in exclude_values: 110 | continue 111 | if value in self.SLOT_VALUE_BLOCKLIST: 112 | continue 113 | return value 114 | 115 | slot_type = self.SLOT_TYPE_LOOKUP.get(target_type, None) 116 | if slot_type is not None: 117 | for slot_name in self.SLOT_TYPE_ONTOLOGY[slot_type]: 118 | for value, _ in reversed( 119 | execution_trace.slot_values.get(slot_name, []) 120 | ): 121 | if value in exclude_values: 122 | continue 123 | if value in self.SLOT_VALUE_BLOCKLIST: 124 | continue 125 | return value 126 | 127 | return None 128 | -------------------------------------------------------------------------------- /src/dataflow/multiwoz/trade_dst/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | -------------------------------------------------------------------------------- /src/dataflow/multiwoz/trade_dst/mapping.pair: -------------------------------------------------------------------------------- 1 | it's it is 2 | don't do not 3 | doesn't does not 4 | didn't did not 5 | you'd you would 6 | you're you are 7 | you'll you will 8 | i'm i am 9 | they're they are 10 | that's that is 11 | what's what is 12 | couldn't could not 13 | i've i have 14 | we've we have 15 | can't cannot 16 | i'd i would 17 | i'd i would 18 | aren't are not 19 | isn't is not 20 | wasn't was not 21 | weren't were not 22 | won't will not 23 | there's there is 24 | there're there are 25 | . . . 26 | restaurants restaurant -s 27 | hotels hotel -s 28 | laptops laptop -s 29 | cheaper cheap -er 30 | dinners dinner -s 31 | lunches lunch -s 32 | breakfasts breakfast -s 33 | expensively expensive -ly 34 | moderately moderate -ly 35 | cheaply cheap -ly 36 | prices price -s 37 | places place -s 38 | venues venue -s 39 | ranges range -s 40 | meals meal -s 41 | locations location -s 42 | areas area -s 43 | policies policy -s 44 | children child -s 45 | kids kid -s 46 | kidfriendly kid friendly 47 | cards card -s 48 | upmarket expensive 49 | inpricey cheap 50 | inches inch -s 51 | uses use -s 52 | dimensions dimension -s 53 | driverange drive range 54 | includes include -s 55 | computers computer -s 56 | machines machine -s 57 | families family -s 58 | ratings rating -s 59 | constraints constraint -s 60 | pricerange price range 61 | batteryrating battery rating 62 | requirements requirement -s 63 | drives drive -s 64 | specifications specification -s 65 | weightrange weight range 66 | harddrive hard drive 67 | batterylife battery life 68 | businesses business -s 69 | hours hour -s 70 | one 1 71 | two 2 72 | three 3 73 | four 4 74 | five 5 75 | six 6 76 | seven 7 77 | eight 8 78 | nine 9 79 | ten 10 80 | eleven 11 81 | twelve 12 82 | anywhere any where 83 | good bye goodbye 84 | -------------------------------------------------------------------------------- /src/dataflow/multiwoz/trade_dst_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | import re 4 | from typing import Any, Dict, List, Tuple, Union 5 | 6 | _Slots = List[List[str]] 7 | _Act = str 8 | BeliefState = List[Dict[str, Union[_Slots, _Act]]] 9 | 10 | # From https://github.com/jasonwu0731/trade-dst/blob/master/utils/utils_multiWOZ_DST.py. 
11 | _EXPERIMENT_DOMAINS = ["hotel", "train", "restaurant", "attraction", "taxi"] 12 | 13 | 14 | def concatenate_system_and_user_transcript(turn: Dict[str, Any]) -> str: 15 | """Returns the concatenated (agentUtterance, userUtterance).""" 16 | return "{} {}".format(turn["system_transcript"].strip(), turn["transcript"].strip()) 17 | 18 | 19 | def get_domain_and_slot_name(slot_fullname: str) -> Tuple[str, str]: 20 | """Returns the domain in a slot fullname.""" 21 | units = slot_fullname.split("-") 22 | return units[0], "-".join(units[1:]) 23 | 24 | 25 | def trade_normalize_slot_name(name: str) -> str: 26 | """Normalizes the slot name as in TRADE. 27 | 28 | Extracted from get_slot_information in https://github.com/jasonwu0731/trade-dst/blob/master/utils/utils_multiWOZ_DST.py. 29 | """ 30 | if "book" not in name: 31 | return name.replace(" ", "").lower() 32 | return name.lower() 33 | 34 | 35 | def fix_general_label_error(trade_belief_state_labels: BeliefState) -> Dict[str, str]: 36 | """Fixes some label errors in MultiWoZ. 37 | 38 | Adapted from https://github.com/jasonwu0731/trade-dst/blob/master/utils/fix_label.py. 39 | - Removed the "type" argument, which is always False. 40 | - "ALL_SLOTS" are hard-coded in the method now 41 | 42 | NOTE: the trade_belief_state_labels may not share the same slot name as dataflow. 43 | 44 | Args: 45 | trade_belief_state_labels: TRADE processed original belief state 46 | Returns: 47 | the flatten belief dictionary with corrected slot values 48 | """ 49 | label_dict: Dict[str, str] = { 50 | l["slots"][0][0]: l["slots"][0][1] for l in trade_belief_state_labels 51 | } 52 | 53 | # hard-coded list of slot names extracted from `get_slot_information` using the ontology.json file 54 | all_slots: List[str] = [ 55 | "hotel-pricerange", 56 | "hotel-type", 57 | "hotel-parking", 58 | "hotel-book stay", 59 | "hotel-book day", 60 | "hotel-book people", 61 | "hotel-area", 62 | "hotel-stars", 63 | "hotel-internet", 64 | "hotel-name", 65 | "train-destination", 66 | "train-day", 67 | "train-departure", 68 | "train-arriveby", 69 | "train-book people", 70 | "train-leaveat", 71 | "restaurant-food", 72 | "restaurant-pricerange", 73 | "restaurant-area", 74 | "restaurant-name", 75 | "attraction-area", 76 | "attraction-name", 77 | "attraction-type", 78 | "taxi-leaveat", 79 | "taxi-destination", 80 | "taxi-departure", 81 | "taxi-arriveby", 82 | "restaurant-book time", 83 | "restaurant-book day", 84 | "restaurant-book people", 85 | ] 86 | 87 | general_typo = { 88 | # type 89 | "guesthouse": "guest house", 90 | "guesthouses": "guest house", 91 | "guest": "guest house", 92 | "mutiple sports": "multiple sports", 93 | "sports": "multiple sports", 94 | "mutliple sports": "multiple sports", 95 | "swimmingpool": "swimming pool", 96 | "concerthall": "concert hall", 97 | "concert": "concert hall", 98 | "pool": "swimming pool", 99 | "night club": "nightclub", 100 | "mus": "museum", 101 | "ol": "architecture", 102 | "colleges": "college", 103 | "coll": "college", 104 | "architectural": "architecture", 105 | "musuem": "museum", 106 | "churches": "church", 107 | # area 108 | "center": "centre", 109 | "center of town": "centre", 110 | "near city center": "centre", 111 | "in the north": "north", 112 | "cen": "centre", 113 | "east side": "east", 114 | "east area": "east", 115 | "west part of town": "west", 116 | "ce": "centre", 117 | "town center": "centre", 118 | "centre of cambridge": "centre", 119 | "city center": "centre", 120 | "the south": "south", 121 | "scentre": "centre", 122 | "town centre": 
"centre", 123 | "in town": "centre", 124 | "north part of town": "north", 125 | "centre of town": "centre", 126 | "cb30aq": "none", 127 | # price 128 | "mode": "moderate", 129 | "moderate -ly": "moderate", 130 | "mo": "moderate", 131 | # day 132 | "next friday": "friday", 133 | "monda": "monday", 134 | # parking 135 | "free parking": "free", 136 | # internet 137 | "free internet": "yes", 138 | # star 139 | "4 star": "4", 140 | "4 stars": "4", 141 | "0 star rarting": "none", 142 | # others 143 | "y": "yes", 144 | "any": "dontcare", 145 | "n": "no", 146 | "does not care": "dontcare", 147 | "not men": "none", 148 | "not": "none", 149 | "not mentioned": "none", 150 | "": "none", 151 | "not mendtioned": "none", 152 | "3 .": "3", 153 | "does not": "no", 154 | "fun": "none", 155 | "art": "none", 156 | } 157 | 158 | # pylint: disable=too-many-boolean-expressions 159 | for slot in all_slots: 160 | if slot in label_dict.keys(): 161 | # general typos 162 | if label_dict[slot] in general_typo.keys(): 163 | label_dict[slot] = label_dict[slot].replace( 164 | label_dict[slot], general_typo[label_dict[slot]] 165 | ) 166 | 167 | # miss match slot and value 168 | if ( 169 | slot == "hotel-type" 170 | and label_dict[slot] 171 | in [ 172 | "nigh", 173 | "moderate -ly priced", 174 | "bed and breakfast", 175 | "centre", 176 | "venetian", 177 | "intern", 178 | "a cheap -er hotel", 179 | ] 180 | or slot == "hotel-internet" 181 | and label_dict[slot] == "4" 182 | or slot == "hotel-pricerange" 183 | and label_dict[slot] == "2" 184 | or slot == "attraction-type" 185 | and label_dict[slot] 186 | in ["gastropub", "la raza", "galleria", "gallery", "science", "m"] 187 | or "area" in slot 188 | and label_dict[slot] in ["moderate"] 189 | or "day" in slot 190 | and label_dict[slot] == "t" 191 | ): 192 | label_dict[slot] = "none" 193 | elif slot == "hotel-type" and label_dict[slot] in [ 194 | "hotel with free parking and free wifi", 195 | "4", 196 | "3 star hotel", 197 | ]: 198 | label_dict[slot] = "hotel" 199 | elif slot == "hotel-star" and label_dict[slot] == "3 star hotel": 200 | label_dict[slot] = "3" 201 | elif "area" in slot: 202 | if label_dict[slot] == "no": 203 | label_dict[slot] = "north" 204 | elif label_dict[slot] == "we": 205 | label_dict[slot] = "west" 206 | elif label_dict[slot] == "cent": 207 | label_dict[slot] = "centre" 208 | elif "day" in slot: 209 | if label_dict[slot] == "we": 210 | label_dict[slot] = "wednesday" 211 | elif label_dict[slot] == "no": 212 | label_dict[slot] = "none" 213 | elif "price" in slot and label_dict[slot] == "ch": 214 | label_dict[slot] = "cheap" 215 | elif "internet" in slot and label_dict[slot] == "free": 216 | label_dict[slot] = "yes" 217 | 218 | # some out-of-define classification slot values 219 | if ( 220 | slot == "restaurant-area" 221 | and label_dict[slot] 222 | in ["stansted airport", "cambridge", "silver street"] 223 | or slot == "attraction-area" 224 | and label_dict[slot] 225 | in ["norwich", "ely", "museum", "same area as hotel"] 226 | ): 227 | label_dict[slot] = "none" 228 | 229 | return label_dict 230 | 231 | 232 | def normalize_trade_slot_name(name: str) -> str: 233 | """Normalizes the TRADE slot name to the dataflow slot name. 234 | 235 | Replace whitespace to make it easier to tokenize plans. 236 | """ 237 | return re.sub(r"(\s)+", "-", name) 238 | 239 | 240 | def flatten_belief_state( 241 | belief_state: BeliefState, keep_all_domains: bool, remove_none: bool 242 | ) -> Dict[str, str]: 243 | """Converts the belief state into a flatten dictionary. 
244 | 245 | Args: 246 | belief_state: the TRADE belief state 247 | keep_all_domains: True if we keep all domains in the belief state; False if we only keep TRADE experiment domains 248 | remove_none: True if we remove slots with "none" value from the returned belief dict 249 | Returns: 250 | the flatten belief state dictionary 251 | """ 252 | trade_belief_dict: Dict[str, str] = { 253 | item["slots"][0][0]: item["slots"][0][1] for item in belief_state 254 | } 255 | return { 256 | normalize_trade_slot_name(name=slot_fullname): slot_value 257 | for slot_fullname, slot_value in trade_belief_dict.items() 258 | if (not remove_none or slot_value != "none") 259 | and ( 260 | keep_all_domains 261 | or get_domain_and_slot_name(slot_fullname)[0] in _EXPERIMENT_DOMAINS 262 | ) 263 | } 264 | -------------------------------------------------------------------------------- /src/dataflow/onmt_helpers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | -------------------------------------------------------------------------------- /src/dataflow/onmt_helpers/compute_onmt_data_stats.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | """ 4 | Semantic Machines\N{TRADE MARK SIGN} software. 5 | 6 | Computes statistics on the data created by create_onmt_text_data. 7 | """ 8 | import argparse 9 | import json 10 | import os 11 | import typing 12 | from collections import defaultdict 13 | from typing import List 14 | 15 | import numpy as np 16 | import pandas as pd 17 | 18 | 19 | def compute_num_examples(input_txt: str) -> int: 20 | """Computes number of examples in a file. 21 | 22 | Simply count lines since each line is an example. 
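Roughly equivalent to `wc -l` on the file.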
23 | """ 24 | count = 0 25 | for _ in open(input_txt): 26 | count += 1 27 | return count 28 | 29 | 30 | def compute_ntokens_percentiles( 31 | input_txt_files: List[str], percentiles: List[int] 32 | ) -> List[int]: 33 | """Computes the percentiles of sequence lengths.""" 34 | ntokens_array: List[int] = sum( 35 | [ 36 | [len(line.strip().split()) for line in open(input_txt)] 37 | for input_txt in input_txt_files 38 | ], 39 | [], 40 | ) 41 | return np.percentile(ntokens_array, percentiles).tolist() 42 | 43 | 44 | def save_ntokens_percentiles( 45 | percentiles: List[int], values: List[int], stats_tsv: str, stats_json: str 46 | ) -> None: 47 | """Saves the sequence length percentiles to a tsv file.""" 48 | stats_df = pd.DataFrame({"percentile": percentiles, "value": values}) 49 | stats_df.to_csv(stats_tsv, sep="\t", index=False, encoding="utf-8") 50 | stats_dict = dict(zip(percentiles, values)) 51 | with open(stats_json, "w") as fp: 52 | fp.write(json.dumps(stats_dict, indent=2)) 53 | fp.write("\n") 54 | 55 | 56 | def compute_token_occurrences( 57 | input_txt_files: List[str], 58 | ) -> typing.DefaultDict[str, int]: 59 | """Computes token occurrences.""" 60 | token_counter: typing.DefaultDict[str, int] = defaultdict(int) 61 | for input_txt in input_txt_files: 62 | for line in open(input_txt): 63 | for token in line.strip().split(): 64 | token_counter[token] += 1 65 | return token_counter 66 | 67 | 68 | def save_token_occurrences( 69 | token_counter: typing.DefaultDict[str, int], counts_tsv: str 70 | ) -> None: 71 | counts_df = pd.DataFrame( 72 | [ 73 | {"token": token, "count": count} 74 | for token, count in sorted(token_counter.items(), key=lambda x: -x[1]) 75 | ] 76 | ) 77 | counts_df.to_csv(counts_tsv, sep="\t", index=False, encoding="utf-8") 78 | 79 | 80 | def main( 81 | text_data_dir: str, subsets: List[str], suffixes: List[str], outdir: str, 82 | ) -> None: 83 | if not os.path.exists(outdir): 84 | os.mkdir(outdir) 85 | 86 | # =============== 87 | # nexamples stats 88 | # =============== 89 | nexamples_lookup = { 90 | subset: compute_num_examples( 91 | os.path.join(text_data_dir, "{}.datum_id".format(subset)) 92 | ) 93 | for subset in subsets 94 | } 95 | with open(os.path.join(outdir, "nexamples.json"), "w") as fp: 96 | json.dump(nexamples_lookup, fp, indent=2) 97 | 98 | # =============== 99 | # ntokens stats 100 | # =============== 101 | percentiles = list(range(0, 101, 10)) 102 | for suffix in suffixes: 103 | for subset in subsets: 104 | input_txt = os.path.join(text_data_dir, f"{subset}.{suffix}") 105 | outbase = os.path.join(outdir, f"{subset}.{suffix}") 106 | 107 | values = compute_ntokens_percentiles( 108 | input_txt_files=[input_txt], percentiles=percentiles 109 | ) 110 | save_ntokens_percentiles( 111 | percentiles=percentiles, 112 | values=values, 113 | stats_tsv=f"{outbase}.ntokens_stats.tsv", 114 | stats_json=f"{outbase}.ntokens_stats.json", 115 | ) 116 | counter = compute_token_occurrences(input_txt_files=[input_txt]) 117 | save_token_occurrences( 118 | token_counter=counter, counts_tsv=f"{outbase}.token_count.tsv" 119 | ) 120 | 121 | if len(subsets) > 1: 122 | input_txt_files = [ 123 | os.path.join(text_data_dir, f"{subset}.{suffix}") for subset in subsets 124 | ] 125 | outbase = os.path.join(outdir, "{}.{}".format("-".join(subsets), suffix)) 126 | 127 | values = compute_ntokens_percentiles( 128 | input_txt_files=input_txt_files, percentiles=percentiles 129 | ) 130 | save_ntokens_percentiles( 131 | percentiles=percentiles, 132 | values=values, 133 | 
stats_tsv=f"{outbase}.ntokens_stats.tsv", 134 | stats_json=f"{outbase}.ntokens_stats.json", 135 | ) 136 | 137 | counter = compute_token_occurrences(input_txt_files=input_txt_files) 138 | save_token_occurrences( 139 | token_counter=counter, counts_tsv=f"{outbase}.token_count.tsv" 140 | ) 141 | 142 | 143 | def add_arguments(argument_parser: argparse.ArgumentParser) -> None: 144 | argument_parser.add_argument("--text_data_dir", help="the text data directory") 145 | argument_parser.add_argument( 146 | "--suffix", 147 | nargs="+", 148 | default=["src", "src_tok", "tgt"], 149 | help="the suffix to be analyzed", 150 | ) 151 | argument_parser.add_argument( 152 | "--subset", nargs="+", help="the subset to be analyzed" 153 | ) 154 | argument_parser.add_argument("--outdir", help="the output directory") 155 | 156 | 157 | if __name__ == "__main__": 158 | cmdline_parser = argparse.ArgumentParser( 159 | description=__doc__, formatter_class=argparse.RawTextHelpFormatter 160 | ) 161 | add_arguments(cmdline_parser) 162 | args = cmdline_parser.parse_args() 163 | 164 | print("Semantic Machines\N{TRADE MARK SIGN} software.") 165 | main( 166 | text_data_dir=args.text_data_dir, 167 | subsets=args.subset, 168 | suffixes=args.suffix, 169 | outdir=args.outdir, 170 | ) 171 | -------------------------------------------------------------------------------- /src/dataflow/onmt_helpers/create_onmt_prediction_report.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | """ 4 | Semantic Machines\N{TRADE MARK SIGN} software. 5 | 6 | Creates the prediction report from onmt_translate output. 7 | """ 8 | import argparse 9 | import dataclasses 10 | from dataclasses import dataclass 11 | from typing import Dict, Iterator, List, Union 12 | 13 | import jsons 14 | from more_itertools import chunked 15 | 16 | from dataflow.core.dialogue import ( 17 | AgentUtterance, 18 | Dialogue, 19 | ProgramExecutionOracle, 20 | Turn, 21 | TurnId, 22 | UserUtterance, 23 | ) 24 | from dataflow.core.io_utils import ( 25 | load_jsonl_file, 26 | load_jsonl_file_and_build_lookup, 27 | save_jsonl_file, 28 | ) 29 | from dataflow.core.linearize import seq_to_lispress, to_canonical_form 30 | from dataflow.core.lispress import render_compact 31 | from dataflow.core.prediction_report import ( 32 | PredictionReportDatum, 33 | save_prediction_report_tsv, 34 | save_prediction_report_txt, 35 | ) 36 | 37 | _DUMMY_USER_UTTERANCE = UserUtterance(original_text="", tokens=[]) 38 | _DUMMY_AGENT_UTTERANCE = AgentUtterance( 39 | original_text="", tokens=[], described_entities=[] 40 | ) 41 | _PARSE_ERROR_LISPRESS = '(parseError #(InvalidLispress "")' 42 | 43 | 44 | @dataclass(frozen=True) 45 | class OnmtPredictionReportDatum(PredictionReportDatum): 46 | datum_id: TurnId 47 | source: str 48 | # The tokenized gold lispress. 49 | gold: str 50 | # The tokenized predicted lispress. 
51 | prediction: str 52 | program_execution_oracle: ProgramExecutionOracle 53 | 54 | @property 55 | def gold_canonical(self) -> str: 56 | return to_canonical_form(self.gold) 57 | 58 | @property 59 | def prediction_canonical(self) -> str: 60 | try: 61 | return to_canonical_form(self.prediction) 62 | except Exception: # pylint: disable=W0703 63 | return _PARSE_ERROR_LISPRESS 64 | 65 | @property 66 | def is_correct_ignoring_refer(self) -> bool: 67 | return self.gold == self.prediction 68 | 69 | @property 70 | def is_correct(self) -> bool: 71 | return ( 72 | self.is_correct_ignoring_refer 73 | and self.program_execution_oracle.refer_are_correct 74 | ) 75 | 76 | @property 77 | def is_correct_leaderboard_ignoring_refer(self) -> bool: 78 | """Returns true if the gold and the prediction match after canonicalization, ignoring the `refer_are_correct`. 79 | 80 | Use this metric if you only care about the accuracy of the semantic parser, ignoring the execution of `refer`. 81 | """ 82 | return self.gold_canonical == self.prediction_canonical 83 | 84 | @property 85 | def is_correct_leaderboard(self) -> bool: 86 | """Returns true if the gold and the prediction match after canonicalization. 87 | 88 | This is the metric used in the leaderboard, which would be slightly higher than the one reported in the TACL2020 89 | paper, since the named arguments are sorted after canonicalization. 90 | """ 91 | return ( 92 | self.is_correct_leaderboard_ignoring_refer 93 | and self.program_execution_oracle.refer_are_correct 94 | ) 95 | 96 | def flatten_datum_id(self) -> Dict[str, Union[str, int]]: 97 | return { 98 | "dialogueId": self.datum_id.dialogue_id, 99 | "turnIndex": self.datum_id.turn_index, 100 | } 101 | 102 | def flatten(self) -> Dict[str, Union[str, int]]: 103 | flatten_datum_dict = self.flatten_datum_id() 104 | # It's fine to call update since we always return a new dict from self.flatten_datum_id().
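# Illustrative contrast between the two metrics above: when gold and
# prediction differ only in the order of named arguments, `is_correct` (exact
# match on the tokenized lispress) is False while `is_correct_leaderboard`
# (match after to_canonical_form) is True, because canonicalization sorts
# named arguments.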
105 | flatten_datum_dict.update( 106 | { 107 | "source": self.source, 108 | "gold": self.gold, 109 | "prediction": self.prediction, 110 | "goldCanonical": self.gold_canonical, 111 | "predictionCanonical": self.prediction_canonical, 112 | "oracleResolveAreCorrect": self.program_execution_oracle.refer_are_correct, 113 | "isCorrect": self.is_correct, 114 | "isCorrectIgnoringRefer": self.is_correct_ignoring_refer, 115 | "isCorrectLeaderboard": self.is_correct_leaderboard, 116 | "isCorrectLeaderboardIgnoringRefer": self.is_correct_leaderboard_ignoring_refer, 117 | } 118 | ) 119 | return flatten_datum_dict 120 | 121 | 122 | def build_prediction_report_datum( 123 | datum_lookup: Dict[str, Dict[int, Turn]], 124 | datum_id_line: str, 125 | src_line: str, 126 | ref_line: str, 127 | nbest_lines: List[str], 128 | ) -> OnmtPredictionReportDatum: 129 | datum_id = jsons.loads(datum_id_line.strip(), TurnId) 130 | datum = datum_lookup[datum_id.dialogue_id][datum_id.turn_index] 131 | return OnmtPredictionReportDatum( 132 | datum_id=datum_id, 133 | source=src_line.strip(), 134 | gold=ref_line.strip(), 135 | prediction=nbest_lines[0].strip(), 136 | program_execution_oracle=datum.program_execution_oracle, 137 | ) 138 | 139 | 140 | def create_onmt_prediction_report( 141 | datum_lookup: Dict[str, Dict[int, Turn]], 142 | datum_id_jsonl: str, 143 | src_txt: str, 144 | ref_txt: str, 145 | nbest_txt: str, 146 | nbest: int, 147 | outbase: str, 148 | ) -> str: 149 | prediction_report = [ 150 | build_prediction_report_datum( 151 | datum_lookup=datum_lookup, 152 | datum_id_line=datum_id_line, 153 | src_line=src_line, 154 | ref_line=ref_line, 155 | nbest_lines=nbest_lines, 156 | ) 157 | for datum_id_line, src_line, ref_line, nbest_lines in zip( 158 | open(datum_id_jsonl), 159 | open(src_txt), 160 | open(ref_txt), 161 | chunked(open(nbest_txt), nbest), 162 | ) 163 | ] 164 | prediction_report.sort(key=lambda x: dataclasses.astuple(x.datum_id)) 165 | predictions_jsonl = f"{outbase}.prediction_report.jsonl" 166 | save_jsonl_file(prediction_report, predictions_jsonl) 167 | save_prediction_report_tsv(prediction_report, f"{outbase}.prediction_report.tsv") 168 | save_prediction_report_txt( 169 | prediction_report=prediction_report, 170 | prediction_report_txt=f"{outbase}.prediction_report.txt", 171 | field_names=[ 172 | "dialogueId", 173 | "turnIndex", 174 | "source", 175 | "oracleResolveAreCorrect", 176 | "isCorrect", 177 | "isCorrectLeaderboard", 178 | "gold", 179 | "prediction", 180 | "goldCanonical", 181 | "predictionCanonical", 182 | ], 183 | ) 184 | return predictions_jsonl 185 | 186 | 187 | def build_dataflow_dialogue( 188 | dialogue_id: str, prediction_report_data: Dict[int, OnmtPredictionReportDatum] 189 | ) -> Dialogue: 190 | turns: List[Turn] = [] 191 | datum: OnmtPredictionReportDatum 192 | for turn_index, datum in sorted(prediction_report_data.items(), key=lambda x: x[0]): 193 | # pylint: disable=broad-except 194 | tokenized_lispress = datum.prediction.split(" ") 195 | try: 196 | lispress = render_compact(seq_to_lispress(tokenized_lispress)) 197 | except Exception as e: 198 | print(e) 199 | lispress = _PARSE_ERROR_LISPRESS 200 | 201 | turns.append( 202 | Turn( 203 | turn_index=turn_index, 204 | user_utterance=_DUMMY_USER_UTTERANCE, 205 | agent_utterance=_DUMMY_AGENT_UTTERANCE, 206 | lispress=lispress, 207 | skip=False, 208 | program_execution_oracle=None, 209 | ) 210 | ) 211 | 212 | return Dialogue(dialogue_id=dialogue_id, turns=turns) 213 | 214 | 215 | def build_dataflow_dialogues( 216 | 
prediction_report_data_lookup: Dict[str, Dict[int, OnmtPredictionReportDatum]] 217 | ) -> Iterator[Dialogue]: 218 | for dialogue_id, prediction_report_data in prediction_report_data_lookup.items(): 219 | dataflow_dialogue = build_dataflow_dialogue( 220 | dialogue_id=dialogue_id, prediction_report_data=prediction_report_data 221 | ) 222 | yield dataflow_dialogue 223 | 224 | 225 | def main( 226 | dialogues_jsonl: str, 227 | datum_id_jsonl: str, 228 | src_txt: str, 229 | ref_txt: str, 230 | nbest_txt: str, 231 | nbest: int, 232 | outbase: str, 233 | ) -> None: 234 | """Creates 1-best predictions and saves them to files.""" 235 | datum_lookup: Dict[str, Dict[int, Turn]] = { 236 | dialogue.dialogue_id: {turn.turn_index: turn for turn in dialogue.turns} 237 | for dialogue in load_jsonl_file( 238 | data_jsonl=dialogues_jsonl, cls=Dialogue, unit=" dialogues" 239 | ) 240 | } 241 | 242 | prediction_report_jsonl = create_onmt_prediction_report( 243 | datum_lookup=datum_lookup, 244 | datum_id_jsonl=datum_id_jsonl, 245 | src_txt=src_txt, 246 | ref_txt=ref_txt, 247 | nbest_txt=nbest_txt, 248 | nbest=nbest, 249 | outbase=outbase, 250 | ) 251 | 252 | predictions_lookup = load_jsonl_file_and_build_lookup( 253 | data_jsonl=prediction_report_jsonl, 254 | cls=OnmtPredictionReportDatum, 255 | primary_key_getter=lambda x: x.datum_id.dialogue_id, 256 | secondary_key_getter=lambda x: x.datum_id.turn_index, 257 | ) 258 | dataflow_dialogues = build_dataflow_dialogues(predictions_lookup) 259 | save_jsonl_file(dataflow_dialogues, f"{outbase}.dataflow_dialogues.jsonl") 260 | 261 | 262 | def add_arguments(argument_parser: argparse.ArgumentParser) -> None: 263 | argument_parser.add_argument( 264 | "--dialogues_jsonl", 265 | help="the jsonl file containing the dialogue data with dataflow programs", 266 | ) 267 | argument_parser.add_argument("--datum_id_jsonl", help="datum ID file") 268 | argument_parser.add_argument("--src_txt", help="source sequence file") 269 | argument_parser.add_argument("--ref_txt", help="target sequence reference file") 270 | argument_parser.add_argument("--nbest_txt", help="onmt_translate output file") 271 | argument_parser.add_argument("--nbest", type=int, help="number of hypos per datum") 272 | argument_parser.add_argument("--outbase", help="the basename of output files") 273 | 274 | 275 | if __name__ == "__main__": 276 | cmdline_parser = argparse.ArgumentParser( 277 | description=__doc__, formatter_class=argparse.RawTextHelpFormatter 278 | ) 279 | add_arguments(cmdline_parser) 280 | args = cmdline_parser.parse_args() 281 | 282 | print("Semantic Machines\N{TRADE MARK SIGN} software.") 283 | main( 284 | dialogues_jsonl=args.dialogues_jsonl, 285 | datum_id_jsonl=args.datum_id_jsonl, 286 | src_txt=args.src_txt, 287 | ref_txt=args.ref_txt, 288 | nbest_txt=args.nbest_txt, 289 | nbest=args.nbest, 290 | outbase=args.outbase, 291 | ) 292 | -------------------------------------------------------------------------------- /src/dataflow/onmt_helpers/embeddings_to_torch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # From OpenNMT: 3 | # 4 | # MIT License 5 | # 6 | # Copyright (c) 2017-Present OpenNMT 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 9 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation 10 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 
Software, 11 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions 14 | # of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 17 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 19 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 20 | # -*- coding: utf-8 -*- 21 | # pylint: disable=W1201,W1202,C0411 22 | """Copied from OpenNMT-Py/tools/embeddings_to_torch.py.""" 23 | from __future__ import division 24 | 25 | import argparse 26 | 27 | import six 28 | import torch 29 | from onmt.inputters.inputter import _old_style_vocab 30 | from onmt.utils.logging import init_logger, logger 31 | 32 | 33 | def get_vocabs(dict_path): 34 | fields = torch.load(dict_path) 35 | 36 | vocs = [] 37 | for side in ["src", "tgt"]: 38 | if _old_style_vocab(fields): 39 | vocab = next((v for n, v in fields if n == side), None) 40 | else: 41 | try: 42 | vocab = fields[side].base_field.vocab 43 | except AttributeError: 44 | vocab = fields[side].vocab 45 | vocs.append(vocab) 46 | # pylint:disable=W0632 47 | enc_vocab, dec_vocab = vocs 48 | 49 | logger.info("From: %s" % dict_path) 50 | logger.info("\t* source vocab: %d words" % len(enc_vocab)) 51 | logger.info("\t* target vocab: %d words" % len(dec_vocab)) 52 | 53 | return enc_vocab, dec_vocab 54 | 55 | 56 | def read_embeddings(file_enc, skip_lines=0, filter_set=None): 57 | embs = dict() 58 | total_vectors_in_file = 0 59 | with open(file_enc, "rb") as f: 60 | for i, line in enumerate(f): 61 | if i < skip_lines: 62 | continue 63 | if not line: 64 | break 65 | if len(line) == 0: 66 | # is this reachable? 
67 | continue 68 | 69 | l_split = line.decode("utf8").strip().split(" ") 70 | if len(l_split) == 2: 71 | continue 72 | total_vectors_in_file += 1 73 | if filter_set is not None and l_split[0] not in filter_set: 74 | continue 75 | embs[l_split[0]] = [float(em) for em in l_split[1:]] 76 | return embs, total_vectors_in_file 77 | 78 | 79 | def convert_to_torch_tensor(word_to_float_list_dict, vocab): 80 | dim = len(six.next(six.itervalues(word_to_float_list_dict))) 81 | tensor = torch.zeros((len(vocab), dim)) 82 | for word, values in word_to_float_list_dict.items(): 83 | tensor[vocab.stoi[word]] = torch.Tensor(values) 84 | return tensor 85 | 86 | 87 | def calc_vocab_load_stats(vocab, loaded_embed_dict): 88 | matching_count = len(set(vocab.stoi.keys()) & set(loaded_embed_dict.keys())) 89 | missing_count = len(vocab) - matching_count 90 | percent_matching = matching_count / len(vocab) * 100 91 | return matching_count, missing_count, percent_matching 92 | 93 | 94 | def main(): 95 | parser = argparse.ArgumentParser(description="embeddings_to_torch.py") 96 | parser.add_argument( 97 | "-emb_file_both", 98 | required=False, 99 | help="loads Embeddings for both source and target " "from this file.", 100 | ) 101 | parser.add_argument( 102 | "-emb_file_enc", required=False, help="source Embeddings from this file" 103 | ) 104 | parser.add_argument( 105 | "-emb_file_dec", required=False, help="target Embeddings from this file" 106 | ) 107 | parser.add_argument( 108 | "-output_file", required=True, help="Output file for the prepared data" 109 | ) 110 | parser.add_argument("-dict_file", required=True, help="Dictionary file") 111 | parser.add_argument("-verbose", action="store_true", default=False) 112 | parser.add_argument( 113 | "-skip_lines", 114 | type=int, 115 | default=0, 116 | help="Skip first lines of the embedding file", 117 | ) 118 | parser.add_argument("-type", choices=["GloVe", "word2vec"], default="GloVe") 119 | opt = parser.parse_args() 120 | 121 | enc_vocab, dec_vocab = get_vocabs(opt.dict_file) 122 | 123 | # Read in embeddings 124 | skip_lines = 1 if opt.type == "word2vec" else opt.skip_lines 125 | if opt.emb_file_both is not None: 126 | if opt.emb_file_enc is not None: 127 | raise ValueError( 128 | "If --emb_file_both is passed in, you should not" "set --emb_file_enc." 129 | ) 130 | if opt.emb_file_dec is not None: 131 | raise ValueError( 132 | "If --emb_file_both is passed in, you should not" "set --emb_file_dec." 133 | ) 134 | set_of_src_and_tgt_vocab = set(enc_vocab.stoi.keys()) | set( 135 | dec_vocab.stoi.keys() 136 | ) 137 | logger.info( 138 | "Reading encoder and decoder embeddings from {}".format(opt.emb_file_both) 139 | ) 140 | src_vectors, total_vec_count = read_embeddings( 141 | opt.emb_file_both, skip_lines, set_of_src_and_tgt_vocab 142 | ) 143 | tgt_vectors = src_vectors 144 | logger.info("\tFound {} total vectors in file".format(total_vec_count)) 145 | else: 146 | if opt.emb_file_enc is None: 147 | raise ValueError( 148 | "If --emb_file_enc not provided. Please specify " 149 | "the file with encoder embeddings, or pass in " 150 | "--emb_file_both" 151 | ) 152 | if opt.emb_file_dec is None: 153 | raise ValueError( 154 | "If --emb_file_dec not provided. 
Please specify " 155 | "the file with encoder embeddings, or pass in " 156 | "--emb_file_both" 157 | ) 158 | logger.info("Reading encoder embeddings from {}".format(opt.emb_file_enc)) 159 | src_vectors, total_vec_count = read_embeddings( 160 | opt.emb_file_enc, skip_lines, filter_set=enc_vocab.stoi 161 | ) 162 | logger.info("\tFound {} total vectors in file.".format(total_vec_count)) 163 | logger.info("Reading decoder embeddings from {}".format(opt.emb_file_dec)) 164 | tgt_vectors, total_vec_count = read_embeddings( 165 | opt.emb_file_dec, skip_lines, filter_set=dec_vocab.stoi 166 | ) 167 | logger.info("\tFound {} total vectors in file".format(total_vec_count)) 168 | logger.info("After filtering to vectors in vocab:") 169 | logger.info( 170 | "\t* enc: %d match, %d missing, (%.2f%%)" 171 | % calc_vocab_load_stats(enc_vocab, src_vectors) 172 | ) 173 | logger.info( 174 | "\t* dec: %d match, %d missing, (%.2f%%)" 175 | % calc_vocab_load_stats(dec_vocab, src_vectors) 176 | ) 177 | 178 | # Write to file 179 | enc_output_file = opt.output_file + ".enc.pt" 180 | dec_output_file = opt.output_file + ".dec.pt" 181 | logger.info( 182 | "\nSaving embedding as:\n\t* enc: %s\n\t* dec: %s" 183 | % (enc_output_file, dec_output_file) 184 | ) 185 | torch.save(convert_to_torch_tensor(src_vectors, enc_vocab), enc_output_file) 186 | torch.save(convert_to_torch_tensor(tgt_vectors, dec_vocab), dec_output_file) 187 | logger.info("\nDone.") 188 | 189 | 190 | if __name__ == "__main__": 191 | init_logger("embeddings_to_torch.log") 192 | main() 193 | -------------------------------------------------------------------------------- /src/dataflow/onmt_helpers/evaluate_onmt_predictions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | """ 4 | Semantic Machines\N{TRADE MARK SIGN} software. 5 | 6 | Evaluates 1best predictions. 7 | 8 | Computes both turn-level and dialogue-level accuracy. 
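A dialogue counts as correct only if all of its turns are correct. For example
(illustrative), a 5-turn dialogue with its only error at the third turn
contributes 4 correct turns, 2 turns before the first error, and 0 correct
dialogues.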
9 | """ 10 | 11 | import argparse 12 | import csv 13 | import json 14 | from dataclasses import dataclass 15 | from typing import List, Optional, Tuple 16 | 17 | import jsons 18 | import pandas as pd 19 | 20 | from dataflow.core.dialogue import TurnId 21 | from dataflow.core.io_utils import load_jsonl_file 22 | 23 | 24 | @dataclass 25 | class EvaluationScores: 26 | num_total_turns: int = 0 27 | num_correct_turns: int = 0 28 | num_turns_before_first_error: int = 0 29 | num_total_dialogues: int = 0 30 | num_correct_dialogues: int = 0 31 | 32 | @property 33 | def accuracy(self) -> float: 34 | if self.num_total_turns == 0: 35 | return 0 36 | return self.num_correct_turns / self.num_total_turns 37 | 38 | @property 39 | def ave_num_turns_before_first_error(self) -> float: 40 | if self.num_total_dialogues == 0: 41 | return 0 42 | return self.num_turns_before_first_error / self.num_total_dialogues 43 | 44 | @property 45 | def pct_correct_dialogues(self) -> float: 46 | if self.num_total_dialogues == 0: 47 | return 0 48 | return self.num_correct_dialogues / self.num_total_dialogues 49 | 50 | def __iadd__(self, other: object) -> "EvaluationScores": 51 | if not isinstance(other, EvaluationScores): 52 | raise ValueError() 53 | self.num_total_turns += other.num_total_turns 54 | self.num_correct_turns += other.num_correct_turns 55 | self.num_turns_before_first_error += other.num_turns_before_first_error 56 | self.num_total_dialogues += other.num_total_dialogues 57 | self.num_correct_dialogues += other.num_correct_dialogues 58 | 59 | return self 60 | 61 | def __add__(self, other: object) -> "EvaluationScores": 62 | if not isinstance(other, EvaluationScores): 63 | raise ValueError() 64 | result = EvaluationScores() 65 | result += self 66 | result += other 67 | 68 | return result 69 | 70 | 71 | def evaluate_dialogue(turns: List[Tuple[int, bool]]) -> EvaluationScores: 72 | num_correct_turns = 0 73 | dialogue_is_correct = True 74 | num_turns_before_first_error = 0 75 | seen_error = False 76 | for _turn_index, is_correct in sorted(turns, key=lambda x: x[0]): 77 | if is_correct: 78 | num_correct_turns += 1 79 | if not seen_error: 80 | num_turns_before_first_error += 1 81 | else: 82 | dialogue_is_correct = False 83 | seen_error = True 84 | 85 | return EvaluationScores( 86 | num_total_turns=len(turns), 87 | num_correct_turns=num_correct_turns, 88 | num_turns_before_first_error=num_turns_before_first_error, 89 | num_total_dialogues=1, 90 | num_correct_dialogues=1 if dialogue_is_correct else 0, 91 | ) 92 | 93 | 94 | def evaluate_dataset( 95 | prediction_report_df: pd.DataFrame, field_name: str, 96 | ) -> EvaluationScores: 97 | # pylint: disable=singleton-comparison 98 | dataset_scores = EvaluationScores() 99 | for _dialogue_id, df_for_dialogue in prediction_report_df.groupby("dialogueId"): 100 | turns = [ 101 | (int(row.get("turnIndex")), row.get(field_name)) 102 | for _, row in df_for_dialogue.iterrows() 103 | ] 104 | dialogue_scores = evaluate_dialogue(turns) 105 | dataset_scores += dialogue_scores 106 | 107 | return dataset_scores 108 | 109 | 110 | def main( 111 | prediction_report_tsv: str, 112 | datum_ids_jsonl: Optional[str], 113 | use_leaderboard_metric: bool, 114 | scores_json: str, 115 | ) -> None: 116 | prediction_report_df = pd.read_csv( 117 | prediction_report_tsv, 118 | sep="\t", 119 | encoding="utf-8", 120 | quoting=csv.QUOTE_ALL, 121 | na_values=None, 122 | keep_default_na=False, 123 | ) 124 | assert not prediction_report_df.isnull().any().any() 125 | 126 | if datum_ids_jsonl: 127 | datum_ids = set( 
128 | load_jsonl_file(data_jsonl=datum_ids_jsonl, cls=TurnId, verbose=False) 129 | ) 130 | mask_datum_id = [ 131 | TurnId(dialogue_id=row.get("dialogueId"), turn_index=row.get("turnIndex")) 132 | in datum_ids 133 | for _, row in prediction_report_df.iterrows() 134 | ] 135 | prediction_report_df = prediction_report_df.loc[mask_datum_id] 136 | 137 | if use_leaderboard_metric: 138 | scores_not_ignoring_refer = evaluate_dataset( 139 | prediction_report_df, "isCorrectLeaderboard" 140 | ) 141 | scores_ignoring_refer = evaluate_dataset( 142 | prediction_report_df, "isCorrectLeaderboardIgnoringRefer" 143 | ) 144 | else: 145 | scores_not_ignoring_refer = evaluate_dataset(prediction_report_df, "isCorrect") 146 | scores_ignoring_refer = evaluate_dataset( 147 | prediction_report_df, "isCorrectIgnoringRefer" 148 | ) 149 | 150 | scores_dict = { 151 | "notIgnoringRefer": jsons.dump(scores_not_ignoring_refer), 152 | "ignoringRefer": jsons.dump(scores_ignoring_refer), 153 | } 154 | with open(scores_json, "w") as fp: 155 | fp.write(json.dumps(scores_dict, indent=2)) 156 | fp.write("\n") 157 | 158 | 159 | def add_arguments(argument_parser: argparse.ArgumentParser) -> None: 160 | argument_parser.add_argument( 161 | "--prediction_report_tsv", help="the prediction report tsv file" 162 | ) 163 | argument_parser.add_argument( 164 | "--datum_ids_jsonl", default=None, help="if set, only evaluate on these turns", 165 | ) 166 | argument_parser.add_argument( 167 | "--use_leaderboard_metric", 168 | default=False, 169 | action="store_true", 170 | help="if set, use the isCorrectLeaderboard field instead of isCorrect field in the prediction report", 171 | ) 172 | argument_parser.add_argument("--scores_json", help="output scores json file") 173 | 174 | 175 | if __name__ == "__main__": 176 | cmdline_parser = argparse.ArgumentParser( 177 | description=__doc__, formatter_class=argparse.RawTextHelpFormatter 178 | ) 179 | add_arguments(cmdline_parser) 180 | args = cmdline_parser.parse_args() 181 | 182 | print("Semantic Machines\N{TRADE MARK SIGN} software.") 183 | if not args.use_leaderboard_metric: 184 | print( 185 | "WARNING: The flag --use_leaderboard_metric is not set." 186 | " The reported results will be consistent with the numbers" 187 | " reported in the TACL2020 paper. To report on the leaderboard evaluation metric, please use" 188 | " --use_leaderboard_metric, which canonicalizes the labels and predictions." 189 | ) 190 | main( 191 | prediction_report_tsv=args.prediction_report_tsv, 192 | datum_ids_jsonl=args.datum_ids_jsonl, 193 | use_leaderboard_metric=args.use_leaderboard_metric, 194 | scores_json=args.scores_json, 195 | ) 196 | -------------------------------------------------------------------------------- /src/dataflow/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/task_oriented_dialogue_as_dataflow_synthesis/bbd16fe69687b1052a7b5937ee7d20b641f2e642/src/dataflow/py.typed -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
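# Shared fixtures for the whole test suite: the `data_dir` fixture defined
# below resolves the tests/data directory relative to this file, so tests can
# locate the checked-in sample dialogues regardless of the working directory
# pytest is launched from.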
3 | from os.path import dirname, join 4 | 5 | import pytest 6 | 7 | 8 | @pytest.fixture(scope="module") 9 | def data_dir() -> str: 10 | base_dir = dirname(__file__) 11 | return join(base_dir, "data") 12 | -------------------------------------------------------------------------------- /tests/data/multiwoz_2_1/PMUL3470.json: -------------------------------------------------------------------------------- 1 | { 2 | "dialogue_idx": "PMUL3470.json", 3 | "domains": [ 4 | "train", 5 | "hotel" 6 | ], 7 | "dialogue": [ 8 | { 9 | "system_transcript": "", 10 | "turn_idx": 0, 11 | "belief_state": [ 12 | { 13 | "slots": [ 14 | [ 15 | "hotel-name", 16 | "express by holiday inn cambridge" 17 | ] 18 | ], 19 | "act": "inform" 20 | } 21 | ], 22 | "turn_label": [ 23 | [ 24 | "hotel-name", 25 | "express by holiday inn cambridge" 26 | ] 27 | ], 28 | "transcript": "i am looking for a specific hotel , its name is express by holiday inn cambridge", 29 | "system_acts": [], 30 | "domain": "hotel" 31 | }, 32 | { 33 | "system_transcript": "i have the express by holiday inn cambridge located on 1517 norman way , coldhams business park . their phone number is 01223866800 . would you like to know anything else ?", 34 | "turn_idx": 1, 35 | "belief_state": [ 36 | { 37 | "slots": [ 38 | [ 39 | "hotel-book people", 40 | "7" 41 | ] 42 | ], 43 | "act": "inform" 44 | }, 45 | { 46 | "slots": [ 47 | [ 48 | "hotel-name", 49 | "express by holiday inn cambridge" 50 | ] 51 | ], 52 | "act": "inform" 53 | } 54 | ], 55 | "turn_label": [ 56 | [ 57 | "hotel-book people", 58 | "7" 59 | ] 60 | ], 61 | "transcript": "yes , could you book the hotel room for me for 7 people ?", 62 | "system_acts": [ 63 | [ 64 | "name", 65 | "express by holiday inn cambridge" 66 | ], 67 | [ 68 | "addr", 69 | "1517 norman way" 70 | ], 71 | [ 72 | "addr", 73 | "coldhams business park" 74 | ], 75 | [ 76 | "phone", 77 | "01223866800" 78 | ] 79 | ], 80 | "domain": "hotel" 81 | }, 82 | { 83 | "system_transcript": "yes , of course . what day would you like to stay ?", 84 | "turn_idx": 2, 85 | "belief_state": [ 86 | { 87 | "slots": [ 88 | [ 89 | "hotel-book day", 90 | "monday" 91 | ] 92 | ], 93 | "act": "inform" 94 | }, 95 | { 96 | "slots": [ 97 | [ 98 | "hotel-book people", 99 | "7" 100 | ] 101 | ], 102 | "act": "inform" 103 | }, 104 | { 105 | "slots": [ 106 | [ 107 | "hotel-book stay", 108 | "4" 109 | ] 110 | ], 111 | "act": "inform" 112 | }, 113 | { 114 | "slots": [ 115 | [ 116 | "hotel-name", 117 | "express by holiday inn cambridge" 118 | ] 119 | ], 120 | "act": "inform" 121 | } 122 | ], 123 | "turn_label": [ 124 | [ 125 | "hotel-book day", 126 | "monday" 127 | ], 128 | [ 129 | "hotel-book stay", 130 | "4" 131 | ] 132 | ], 133 | "transcript": "monday , please . there will be 7 of us and we'd like to stay for 4 days .", 134 | "system_acts": [ 135 | "day" 136 | ], 137 | "domain": "hotel" 138 | }, 139 | { 140 | "system_transcript": "here is the booking information:booking was successful . 
reference number is : 5f8g6j1g", 141 | "turn_idx": 3, 142 | "belief_state": [ 143 | { 144 | "slots": [ 145 | [ 146 | "hotel-book day", 147 | "monday" 148 | ] 149 | ], 150 | "act": "inform" 151 | }, 152 | { 153 | "slots": [ 154 | [ 155 | "hotel-book people", 156 | "7" 157 | ] 158 | ], 159 | "act": "inform" 160 | }, 161 | { 162 | "slots": [ 163 | [ 164 | "hotel-book stay", 165 | "4" 166 | ] 167 | ], 168 | "act": "inform" 169 | }, 170 | { 171 | "slots": [ 172 | [ 173 | "hotel-name", 174 | "express by holiday inn cambridge" 175 | ] 176 | ], 177 | "act": "inform" 178 | } 179 | ], 180 | "turn_label": [], 181 | "transcript": "thank you . i would also like to book a train , please .", 182 | "system_acts": [], 183 | "domain": "train" 184 | }, 185 | { 186 | "system_transcript": "sure , which stations will you be traveling between ?", 187 | "turn_idx": 4, 188 | "belief_state": [ 189 | { 190 | "slots": [ 191 | [ 192 | "hotel-book day", 193 | "monday" 194 | ] 195 | ], 196 | "act": "inform" 197 | }, 198 | { 199 | "slots": [ 200 | [ 201 | "hotel-book people", 202 | "7" 203 | ] 204 | ], 205 | "act": "inform" 206 | }, 207 | { 208 | "slots": [ 209 | [ 210 | "hotel-book stay", 211 | "4" 212 | ] 213 | ], 214 | "act": "inform" 215 | }, 216 | { 217 | "slots": [ 218 | [ 219 | "hotel-name", 220 | "express by holiday inn cambridge" 221 | ] 222 | ], 223 | "act": "inform" 224 | }, 225 | { 226 | "slots": [ 227 | [ 228 | "train-destination", 229 | "birmingham new street" 230 | ] 231 | ], 232 | "act": "inform" 233 | }, 234 | { 235 | "slots": [ 236 | [ 237 | "train-day", 238 | "none" 239 | ] 240 | ], 241 | "act": "inform" 242 | }, 243 | { 244 | "slots": [ 245 | [ 246 | "train-departure", 247 | "cambridge" 248 | ] 249 | ], 250 | "act": "inform" 251 | } 252 | ], 253 | "turn_label": [ 254 | [ 255 | "train-destination", 256 | "birmingham new street" 257 | ], 258 | [ 259 | "train-day", 260 | "none" 261 | ], 262 | [ 263 | "train-departure", 264 | "cambridge" 265 | ] 266 | ], 267 | "transcript": "i will be going from cambridge to birmingham new street .", 268 | "system_acts": [ 269 | "dest", 270 | "depart" 271 | ], 272 | "domain": "train" 273 | }, 274 | { 275 | "system_transcript": "what time would you like to leave ? the trains depart every hour .", 276 | "turn_idx": 5, 277 | "belief_state": [ 278 | { 279 | "slots": [ 280 | [ 281 | "hotel-book day", 282 | "monday" 283 | ] 284 | ], 285 | "act": "inform" 286 | }, 287 | { 288 | "slots": [ 289 | [ 290 | "hotel-book people", 291 | "7" 292 | ] 293 | ], 294 | "act": "inform" 295 | }, 296 | { 297 | "slots": [ 298 | [ 299 | "hotel-book stay", 300 | "4" 301 | ] 302 | ], 303 | "act": "inform" 304 | }, 305 | { 306 | "slots": [ 307 | [ 308 | "hotel-name", 309 | "express by holiday inn cambridge" 310 | ] 311 | ], 312 | "act": "inform" 313 | }, 314 | { 315 | "slots": [ 316 | [ 317 | "train-destination", 318 | "birmingham new street" 319 | ] 320 | ], 321 | "act": "inform" 322 | }, 323 | { 324 | "slots": [ 325 | [ 326 | "train-day", 327 | "friday" 328 | ] 329 | ], 330 | "act": "inform" 331 | }, 332 | { 333 | "slots": [ 334 | [ 335 | "train-arriveby", 336 | "17:30" 337 | ] 338 | ], 339 | "act": "inform" 340 | }, 341 | { 342 | "slots": [ 343 | [ 344 | "train-departure", 345 | "cambridge" 346 | ] 347 | ], 348 | "act": "inform" 349 | } 350 | ], 351 | "turn_label": [ 352 | [ 353 | "train-day", 354 | "friday" 355 | ], 356 | [ 357 | "train-arriveby", 358 | "17:30" 359 | ] 360 | ], 361 | "transcript": "whenever will get me there by 17:30 . 
i do need to leave on friday and i will need the travel time please .", 362 | "system_acts": [ 363 | [ 364 | "leave", 365 | "every hour" 366 | ], 367 | "leave" 368 | ], 369 | "domain": "train" 370 | } 371 | ] 372 | } 373 | -------------------------------------------------------------------------------- /tests/data/multiwoz_2_1/PMUL4478.json: -------------------------------------------------------------------------------- 1 | { 2 | "dialogue_idx": "PMUL4478.json", 3 | "domains": [ 4 | "train", 5 | "restaurant" 6 | ], 7 | "dialogue": [ 8 | { 9 | "system_transcript": "", 10 | "turn_idx": 0, 11 | "belief_state": [], 12 | "turn_label": [], 13 | "transcript": "i am traveling to cambridge and looking forward to try local restaurant -s .", 14 | "system_acts": [], 15 | "domain": "restaurant" 16 | }, 17 | { 18 | "system_transcript": "okay , any type in mind ?", 19 | "turn_idx": 1, 20 | "belief_state": [ 21 | { 22 | "slots": [ 23 | [ 24 | "train-leaveat", 25 | "16:45" 26 | ] 27 | ], 28 | "act": "inform" 29 | }, 30 | { 31 | "slots": [ 32 | [ 33 | "train-departure", 34 | "stevenage" 35 | ] 36 | ], 37 | "act": "inform" 38 | } 39 | ], 40 | "turn_label": [ 41 | [ 42 | "train-leaveat", 43 | "16:45" 44 | ], 45 | [ 46 | "train-departure", 47 | "stevenage" 48 | ] 49 | ], 50 | "transcript": "the train i am looking for should depart from stevenage and be leaving after 16:45 . can you help ?", 51 | "system_acts": [ 52 | "food" 53 | ], 54 | "domain": "train" 55 | }, 56 | { 57 | "system_transcript": "okay . you mentioned finding a local restaurant , initially . would you like me to search for 1 ? if so , what kind of cuisine ?", 58 | "turn_idx": 2, 59 | "belief_state": [], 60 | "turn_label": [], 61 | "transcript": "i am looking for something local . what do you recommend ?", 62 | "system_acts": [ 63 | "food" 64 | ], 65 | "domain": "train" 66 | }, 67 | { 68 | "system_transcript": "it depends on what your price range is .", 69 | "turn_idx": 3, 70 | "belief_state": [ 71 | { 72 | "slots": [ 73 | [ 74 | "restaurant-food", 75 | "singaporean" 76 | ] 77 | ], 78 | "act": "inform" 79 | }, 80 | { 81 | "slots": [ 82 | [ 83 | "restaurant-pricerange", 84 | "expensive" 85 | ] 86 | ], 87 | "act": "inform" 88 | }, 89 | { 90 | "slots": [ 91 | [ 92 | "restaurant-area", 93 | "centre" 94 | ] 95 | ], 96 | "act": "inform" 97 | } 98 | ], 99 | "turn_label": [ 100 | [ 101 | "restaurant-food", 102 | "singaporean" 103 | ], 104 | [ 105 | "restaurant-pricerange", 106 | "expensive" 107 | ], 108 | [ 109 | "restaurant-area", 110 | "centre" 111 | ] 112 | ], 113 | "transcript": "actually i am craving singaporean food , price is no object . i would also like something in the centre area .", 114 | "system_acts": [ 115 | "price" 116 | ], 117 | "domain": "restaurant" 118 | }, 119 | { 120 | "system_transcript": "unfortunately , there is not a singaporean restaurant in the centre of town . 
any other preferences ?",
121 | "turn_idx": 4,
122 | "belief_state": [
123 | {
124 | "slots": [
125 | [
126 | "restaurant-food",
127 | "asian oriental"
128 | ]
129 | ],
130 | "act": "inform"
131 | },
132 | {
133 | "slots": [
134 | [
135 | "restaurant-pricerange",
136 | "expensive"
137 | ]
138 | ],
139 | "act": "inform"
140 | },
141 | {
142 | "slots": [
143 | [
144 | "restaurant-area",
145 | "centre"
146 | ]
147 | ],
148 | "act": "inform"
149 | }
150 | ],
151 | "turn_label": [
152 | [
153 | "restaurant-food",
154 | "asian oriental"
155 | ]
156 | ],
157 | "transcript": "do you have any asian expensive restaurant -s ?",
158 | "system_acts": [],
159 | "domain": "restaurant"
160 | },
161 | {
162 | "system_transcript": "yes , there s a place called kymmoy .",
163 | "turn_idx": 5,
164 | "belief_state": [
165 | {
166 | "slots": [
167 | [
168 | "restaurant-food",
169 | "asian oriental"
170 | ]
171 | ],
172 | "act": "inform"
173 | },
174 | {
175 | "slots": [
176 | [
177 | "restaurant-pricerange",
178 | "expensive"
179 | ]
180 | ],
181 | "act": "inform"
182 | },
183 | {
184 | "slots": [
185 | [
186 | "restaurant-area",
187 | "centre"
188 | ]
189 | ],
190 | "act": "inform"
191 | }
192 | ],
193 | "turn_label": [],
194 | "transcript": "thank you so much for your information .",
195 | "system_acts": [
196 | [
197 | "name",
198 | "kymmoy"
199 | ]
200 | ],
201 | "domain": "restaurant"
202 | }
203 | ]
204 | }
205 |
--------------------------------------------------------------------------------
/tests/data/multiwoz_2_1/README.md:
--------------------------------------------------------------------------------
1 | This folder contains example dialogues extracted from MultiWOZ 2.1 data processed by TRADE.
2 |
3 | To extract a dialogue from the TRADE-processed JSON file, you can run
4 | ```bash
5 | jq '.[] | select (.dialogue_idx == "MUL1626.json")' dev_dials.json
6 | ```
7 |
8 | The [MultiWOZ 2.1 dataset](https://www.repository.cam.ac.uk/handle/1810/294507) is
9 | licensed under a Creative Commons Attribution 4.0 International License.
10 |
--------------------------------------------------------------------------------
/tests/test_dataflow/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
--------------------------------------------------------------------------------
/tests/test_dataflow/core/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
--------------------------------------------------------------------------------
/tests/test_dataflow/core/test_linearize.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
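# Tests for dataflow.core.linearize: round-tripping programs between Lispress
# s-expressions and the flat, space-separated token sequences used as the
# model's target representation.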
3 | import json 4 | 5 | from dataflow.core.linearize import ( 6 | lispress_to_seq, 7 | seq_to_lispress, 8 | sexp_to_seq, 9 | to_canonical_form, 10 | ) 11 | from dataflow.core.lispress import parse_lispress, render_compact, unnest_line 12 | from dataflow.core.program import Expression, ValueOp 13 | 14 | 15 | def test_sexp_to_seq_is_invertible(): 16 | sexps = [ 17 | [], 18 | ["a"], 19 | ["a", ["b", "c"], "d", [[["e"]]]], 20 | [["a", "b"], [["c"], "d"]], 21 | ] 22 | for orig in sexps: 23 | seq = sexp_to_seq(orig) 24 | result = seq_to_lispress(seq) 25 | assert result == orig 26 | 27 | 28 | def test_unnest_line(): 29 | line = ["#", ["String", '"pizza', "hut", 'fenditton"']] 30 | expected_expression = Expression( 31 | id="[1]", 32 | op=ValueOp( 33 | value=json.dumps(dict(schema="String", underlying="pizza hut fenditton")) 34 | ), 35 | ) 36 | 37 | expressions, _, _, _ = unnest_line(line, 0, ()) 38 | assert len(expressions) == 1 39 | assert expressions[0] == expected_expression 40 | 41 | 42 | def test_linearized_roundtrip(): 43 | """Round-trip tests for s-expression formatter and deformatter.""" 44 | data = [ 45 | ('#(String "singleToken")', '# ( String " singleToken " )'), 46 | ('#(String "multiple tokens")', '# ( String " multiple tokens " )'), 47 | # real data 48 | ( 49 | '((mapGet #(Path "fare.fare_id") (clobberRevise (getSalient (actionIntensionConstraint)) (' 50 | "Constraint[Constraint[flight]]) (Constraint " 51 | ':type (?= #(String "fare"))))))', 52 | '( ( mapGet # ( Path " fare.fare_id " ) ( clobberRevise ( getSalient ( actionIntensionConstraint ) ) ' 53 | "( Constraint[Constraint[flight]] ) ( " 54 | 'Constraint :type ( ?= # ( String " fare " ) ) ) ) ) )', 55 | ), 56 | ] 57 | 58 | for raw_sexp, formatted_sexp in data: 59 | assert lispress_to_seq(parse_lispress(raw_sexp)) == formatted_sexp.split() 60 | assert render_compact(seq_to_lispress(formatted_sexp.split())) == raw_sexp 61 | 62 | 63 | def test_meta(): 64 | assert lispress_to_seq( 65 | parse_lispress("(refer (^(Dynamic) ActionIntensionConstraint))") 66 | ) == [ 67 | "(", 68 | "refer", 69 | "(", 70 | "^", 71 | "(", 72 | "Dynamic", 73 | ")", 74 | "ActionIntensionConstraint", 75 | ")", 76 | ")", 77 | ] 78 | 79 | 80 | def test_meta_to_canonical(): 81 | s = """( Yield ( Execute ( ReviseConstraint ( refer ( ^ ( Dynamic ) roleConstraint ( Path.apply "output" ) ) ) ( ^ ( Event ) ConstraintTypeIntension ) ( Event.showAs_? ( ?= ( ShowAsStatus.OutOfOffice ) ) ) ) ) )""" 82 | assert ( 83 | to_canonical_form(s) 84 | == """(Yield (Execute (ReviseConstraint (refer (^(Dynamic) roleConstraint (Path.apply "output"))) (^(Event) ConstraintTypeIntension) (Event.showAs_? 
(?= (ShowAsStatus.OutOfOffice))))))""" 85 | ) 86 | 87 | 88 | def test_quoted(): 89 | s = """( foo " bar " )""" 90 | assert to_canonical_form(s) == """(foo "bar")""" 91 | -------------------------------------------------------------------------------- /tests/test_dataflow/core/test_type_inference.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple 2 | 3 | import pytest 4 | 5 | from dataflow.core.definition import Definition 6 | from dataflow.core.lispress import lispress_to_program, parse_lispress 7 | from dataflow.core.program import Program, TypeName 8 | from dataflow.core.type_inference import TypeInferenceError, infer_types 9 | 10 | 11 | def _do_inference_test( 12 | expr: str, expected: str, library: Dict[str, Definition] 13 | ) -> Tuple[Program, Program]: 14 | lispress = parse_lispress(expr) 15 | program, _ = lispress_to_program(lispress, 0) 16 | res = infer_types(program, library) 17 | (expected_program, _) = lispress_to_program(parse_lispress(expected), 0) 18 | return expected_program, res 19 | 20 | 21 | SIMPLE_PLUS_LIBRARY = { 22 | "+": Definition( 23 | "+", ["T"], [("x", TypeName("T")), ("y", TypeName("T")),], TypeName("T") 24 | ), 25 | "plusLong": Definition( 26 | "+", [], [("x", TypeName("Long")), ("y", TypeName("Long")),], TypeName("Long") 27 | ), 28 | "single_element_list": Definition( 29 | "single_element_list", 30 | ["T"], 31 | [("e", TypeName("T"))], 32 | TypeName("List", (TypeName("T"),)), 33 | ), 34 | "NamedArgs": Definition( 35 | "HasNamedArgs", 36 | [], 37 | [ 38 | ("arg1", TypeName("Long")), 39 | ("arg2", TypeName("Long")), 40 | ("arg3", TypeName("String")), 41 | ], 42 | TypeName("Long"), 43 | ), 44 | } 45 | 46 | 47 | def test_simple(): 48 | expected_program, res = _do_inference_test( 49 | "(+ 1 2)", "^Number (^(Number) + ^Number 1 ^Number 2)", SIMPLE_PLUS_LIBRARY, # 50 | ) 51 | assert res == expected_program 52 | 53 | expected_program, res = _do_inference_test( 54 | "(+ (plusLong 3L 1L) 2L)", 55 | "^Long (^(Long) + ^Long (plusLong ^Long 3L ^Long 1L) ^Long 2L)", 56 | SIMPLE_PLUS_LIBRARY, 57 | ) 58 | assert res == expected_program 59 | 60 | 61 | def test_named_args(): 62 | expected_program, res = _do_inference_test( 63 | '(NamedArgs 1L :arg3 "2")', 64 | '^Long (NamedArgs ^Long 1L :arg3 ^String "2")', 65 | SIMPLE_PLUS_LIBRARY, 66 | ) 67 | assert res == expected_program 68 | 69 | 70 | def test_types_disagree(): 71 | with pytest.raises(TypeInferenceError): 72 | _do_inference_test( 73 | "^Number (plusLong 3L 1)", "^Number (plusLong 3L 1)", SIMPLE_PLUS_LIBRARY, 74 | ) 75 | 76 | 77 | def test_ascription_disagrees(): 78 | with pytest.raises(TypeInferenceError): 79 | _do_inference_test( 80 | "^Number (plusLong 3L 1L)", "^Number (plusLong 3L 1L)", SIMPLE_PLUS_LIBRARY, 81 | ) 82 | 83 | 84 | def test_let(): 85 | expected_program, res = _do_inference_test( 86 | "(let (x (+ 1L 2L)) (+ x x))", 87 | "(let (x ^Long (^(Long) + ^Long 1L ^Long 2L)) ^Long (^(Long) + x x))", 88 | SIMPLE_PLUS_LIBRARY, 89 | ) 90 | assert res == expected_program 91 | 92 | 93 | def test_multi_let(): 94 | expected_program, res = _do_inference_test( 95 | "(let (a 1L b 2L x (+ a b)) (+ x x))", 96 | "(let (a ^Long 1L b ^Long 2L x ^Long (^(Long) + a b)) ^Long (^(Long) + x x))", 97 | SIMPLE_PLUS_LIBRARY, 98 | ) 99 | assert res == expected_program 100 | 101 | 102 | def test_parameterized(): 103 | expected_program, res = _do_inference_test( 104 | "(single_element_list 1)", 105 | "^(List Number) (^(Number) single_element_list ^Number 1)", 106 | 
SIMPLE_PLUS_LIBRARY, 107 | ) 108 | assert res == expected_program 109 | 110 | expected_program, res = _do_inference_test( 111 | '(single_element_list "5")', 112 | '^(List String) (^(String) single_element_list ^String "5")', 113 | SIMPLE_PLUS_LIBRARY, 114 | ) 115 | assert res == expected_program 116 | -------------------------------------------------------------------------------- /tests/test_dataflow/core/test_utterance_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | from dataflow.core.utterance_tokenizer import UtteranceTokenizer, tokenize_datetime 4 | 5 | 6 | def test_tokenize_datetime(): 7 | data = [ 8 | ("5.10", "5 . 10"), 9 | ("4:00", "4 : 00"), 10 | ("5/7", "5 / 7"), 11 | ("5\\7", "5 \\ 7"), 12 | ("3-9", "3 - 9"), 13 | ("3pm", "3 pm"), 14 | ] 15 | for text, expected in data: 16 | assert tokenize_datetime(text) == expected 17 | 18 | 19 | def test_tokenize_utterance(): 20 | utterance_tokenizer = UtteranceTokenizer() 21 | 22 | data = [ 23 | ( 24 | "Reschedule meeting with Barack Obama to 5/30/2019 at 3:00pm", 25 | [ 26 | "Reschedule", 27 | "meeting", 28 | "with", 29 | "Barack", 30 | "Obama", 31 | "to", 32 | "5", 33 | "/", 34 | "30", 35 | "/", 36 | "2019", 37 | "at", 38 | "3", 39 | ":", 40 | "00", 41 | "pm", 42 | ], 43 | ), 44 | ( 45 | "Can you also add icecream birthday tomorrow at 6PM?", 46 | [ 47 | "Can", 48 | "you", 49 | "also", 50 | "add", 51 | "icecream", 52 | "birthday", 53 | "tomorrow", 54 | "at", 55 | "6", 56 | "PM", 57 | "?", 58 | ], 59 | ), 60 | ] 61 | for text, expected in data: 62 | assert utterance_tokenizer.tokenize(text) == expected 63 | -------------------------------------------------------------------------------- /tests/test_dataflow/multiwoz/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | -------------------------------------------------------------------------------- /tests/test_dataflow/multiwoz/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | from typing import Any, Dict, List, Tuple 4 | 5 | import pytest 6 | 7 | from dataflow.multiwoz.trade_dst_utils import BeliefState 8 | 9 | 10 | def convert_belief_dict_to_belief_state(belief_dict: Dict[str, str]) -> BeliefState: 11 | belief_state: BeliefState = [] 12 | for slot_fullname, slot_value in sorted(belief_dict.items()): 13 | belief_state.append({"slots": [[slot_fullname, slot_value]]}) 14 | return belief_state 15 | 16 | 17 | def build_trade_dialogue( 18 | dialogue_id: str, turns: List[Tuple[str, str, Dict[str, str]]] 19 | ) -> Dict[str, Any]: 20 | trade_dialogue = { 21 | "dialogue_idx": dialogue_id, 22 | "dialogue": [ 23 | { 24 | # Our mock dialogues here use 1-based turn indices. 25 | # In real MultiWOZ/TRADE dialogues, turn index starts from 0. 
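# (enumerate below is 0-based, hence the "+ 1" when building "turn_idx")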
26 | "turn_idx": turn_idx + 1, 27 | "system_transcript": agent_utt, 28 | "transcript": user_utt, 29 | "belief_state": convert_belief_dict_to_belief_state(belief_dict), 30 | } 31 | for turn_idx, (agent_utt, user_utt, belief_dict) in enumerate(turns) 32 | ], 33 | } 34 | return trade_dialogue 35 | 36 | 37 | @pytest.fixture 38 | def trade_dialogue_1() -> Dict[str, Any]: 39 | return build_trade_dialogue( 40 | dialogue_id="dummy_1", 41 | turns=[ 42 | # turn 1 43 | # activate a domain without constraint, the plan should call "Find" with "EqualityConstraint" 44 | # we intentionally to only put two "none" slots in the belief state to match the MultiWoZ annotation style 45 | ( 46 | "", 47 | "i want to book a hotel", 48 | {"hotel-name": "none", "hotel-type": "none"}, 49 | ), 50 | # turn 2 51 | # add constraints, the plan should call "Revise" with "EqualityConstraint" 52 | ( 53 | "ok what type", 54 | "guest house and cheap, probably hilton", 55 | { 56 | "hotel-name": "hilton", 57 | "hotel-pricerange": "cheap", 58 | "hotel-type": "guest house", 59 | }, 60 | ), 61 | # turn 3 62 | # drop a constraint (but the domain is still active), the plan should call "Revise" with "EqualityConstraint" 63 | ( 64 | "no results", 65 | "ok try another hotel", 66 | { 67 | "hotel-name": "none", 68 | "hotel-pricerange": "cheap", 69 | "hotel-type": "guest house", 70 | }, 71 | ), 72 | # turn 4 73 | # drop the domain 74 | ("failed", "ok never mind", {}), 75 | # turn 5 76 | # activate the domain again 77 | ("sure", "can you find a hotel in west", {"hotel-area": "west"}), 78 | # turn 6 79 | # activate a new domain and use a refer call 80 | ( 81 | "how about this", 82 | "ok can you find a restaurant in the same area", 83 | {"hotel-area": "west", "restaurant-area": "west"}, 84 | ), 85 | # turn 7 86 | # use a refer call to get a value from a dead domain 87 | # the salience model should find the first valid refer value (skips "none") 88 | ( 89 | "how about this", 90 | "use the same price range as the hotel", 91 | { 92 | "hotel-area": "west", 93 | "restaurant-area": "west", 94 | "restaurant-pricerange": "cheap", 95 | }, 96 | ), 97 | # turn 8 98 | # do not change belief state 99 | ( 100 | "ok", 101 | "give me the address", 102 | { 103 | "hotel-area": "west", 104 | "restaurant-area": "west", 105 | "restaurant-pricerange": "cheap", 106 | }, 107 | ), 108 | # turn 9 109 | # a new domain 110 | ( 111 | "ok", 112 | "book a taxi now", 113 | { 114 | "hotel-area": "west", 115 | "restaurant-area": "west", 116 | "restaurant-pricerange": "cheap", 117 | "taxi-departure": "none", 118 | }, 119 | ), 120 | # turn 10 121 | # do not change belief state (make sure the plan is "Revise" not "Find") 122 | ( 123 | "ok", 124 | "ok", 125 | { 126 | "hotel-area": "west", 127 | "restaurant-area": "west", 128 | "restaurant-pricerange": "cheap", 129 | "taxi-departure": "none", 130 | }, 131 | ), 132 | ], 133 | ) 134 | -------------------------------------------------------------------------------- /tests/test_dataflow/multiwoz/test_cli_workflow.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | import json 4 | import os 5 | 6 | import jsons 7 | 8 | from dataflow.multiwoz.create_belief_state_prediction_report import ( 9 | main as create_belief_state_prediction_report, 10 | ) 11 | from dataflow.multiwoz.create_belief_state_tracker_data import ( 12 | main as create_belief_state_tracker_data, 13 | ) 14 | from dataflow.multiwoz.create_programs import main as create_programs 15 | from dataflow.multiwoz.evaluate_belief_state_predictions import EvaluationStats 16 | from dataflow.multiwoz.evaluate_belief_state_predictions import ( 17 | main as evaluate_belief_state_predictions, 18 | ) 19 | from dataflow.multiwoz.execute_programs import main as execute_programs 20 | from dataflow.multiwoz.patch_trade_dialogues import main as patch_trade_dialogues 21 | from test_dataflow.multiwoz.test_create_programs import load_test_trade_dialogues 22 | 23 | 24 | def test_cli_workflow(data_dir: str, tmp_path: str): 25 | """An end-to-end test on the CLI workflow. 26 | 27 | This test involves multiple CLI steps. 28 | 1. patch_trade_dialogues 29 | 2. create_programs 30 | 3. execute_programs 31 | 4. create_belief_state_tracker_data 32 | 5. evaluate_belief_state_predictions 33 | 34 | It does not test all corner cases but it should catch some common errors in the workflow. 35 | """ 36 | # ============ 37 | # merges all test dialogues 38 | # ============ 39 | trade_data_file = os.path.join(tmp_path, "s010.merged_trade_dials.jsonl") 40 | trade_dialogues = list(load_test_trade_dialogues(data_dir)) 41 | with open(trade_data_file, "w") as fp: 42 | fp.write(json.dumps(trade_dialogues, indent=2)) 43 | fp.write("\n") 44 | 45 | # ============ 46 | # patches TRADE dialogues 47 | # ============ 48 | patched_dials_file, _ = patch_trade_dialogues( 49 | trade_data_file=trade_data_file, outbase=os.path.join(tmp_path, "s020.merged") 50 | ) 51 | 52 | # ============ 53 | # create programs for TRADE dialogues 54 | # ============ 55 | dataflow_dialogues = create_programs( 56 | trade_data_file=patched_dials_file, 57 | keep_all_domains=True, 58 | remove_none=False, 59 | fill_none=False, 60 | no_refer=False, 61 | no_revise=False, 62 | avoid_empty_plan=False, 63 | outbase=os.path.join(tmp_path, "s030.merged"), 64 | ) 65 | 66 | # ============ 67 | # execute programs 68 | # ============ 69 | complete_execution_results_file, _, _ = execute_programs( 70 | dialogues_file=dataflow_dialogues, 71 | no_revise=False, 72 | no_refer=False, 73 | cheating_mode="never", 74 | cheating_execution_results_file=None, 75 | outbase=os.path.join(tmp_path, "s040.merged"), 76 | ) 77 | another_complete_execution_results_file, _, _ = execute_programs( 78 | dialogues_file=dataflow_dialogues, 79 | no_revise=False, 80 | no_refer=False, 81 | cheating_mode="always", 82 | cheating_execution_results_file=complete_execution_results_file, 83 | outbase=os.path.join(tmp_path, "s040.merged"), 84 | ) 85 | # because we use the execution results from the dataflow_dialogues itself as the cheating_execution_results_file, 86 | # the outcome should be identical 87 | for actual, expected in zip( 88 | open(another_complete_execution_results_file), 89 | open(complete_execution_results_file), 90 | ): 91 | assert actual == expected 92 | 93 | # ============ 94 | # creates belief state tracker data 95 | # ============ 96 | gold_data_file = os.path.join(tmp_path, "s050.merged_gold.data.jsonl") 97 | create_belief_state_tracker_data( 98 | trade_data_file=patched_dials_file, 99 | belief_state_tracker_data_file=gold_data_file, 100 | ) 101 | prediction_report_jsonl = 
create_belief_state_prediction_report( 102 | input_data_file=complete_execution_results_file, 103 | file_format="dataflow", 104 | remove_none=False, 105 | gold_data_file=gold_data_file, 106 | outbase=os.path.join(tmp_path, "s050.merged_hypo"), 107 | ) 108 | 109 | # ============ 110 | # computes the accuracy 111 | # ============ 112 | scores_file = evaluate_belief_state_predictions( 113 | prediction_report_jsonl=prediction_report_jsonl, 114 | outbase=os.path.join(tmp_path, "s060.merged"), 115 | ) 116 | stats = jsons.loads(open(scores_file).read(), EvaluationStats) 117 | assert stats.accuracy == 1.0 118 | for _, accuracy in stats.accuracy_for_slot.items(): 119 | assert accuracy == 1.0 120 | -------------------------------------------------------------------------------- /tests/test_dataflow/multiwoz/test_evaluate_belief_state_predictions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | from dataflow.multiwoz.evaluate_belief_state_predictions import EvaluationStats 4 | 5 | 6 | def test_evaluation_stats(): 7 | stats = EvaluationStats() 8 | assert stats.accuracy == 0 9 | assert not stats.accuracy_for_slot 10 | 11 | another_stats = EvaluationStats( 12 | num_total_turns=5, 13 | num_correct_turns=2, 14 | num_correct_turns_for_slot={"a": 1, "b": 2}, 15 | num_total_dialogues=2, 16 | num_correct_dialogues=1, 17 | ) 18 | stats += another_stats 19 | assert stats == another_stats 20 | assert stats.accuracy == 0.4 21 | assert set(stats.accuracy_for_slot.keys()) == {"a", "b"} 22 | assert stats.accuracy_for_slot["a"] == 0.2 23 | assert stats.accuracy_for_slot["b"] == 0.4 24 | -------------------------------------------------------------------------------- /tests/test_dataflow/multiwoz/test_execute_programs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
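# Exercises execute_programs_for_dialogue under its three cheating modes:
# "never" executes the predicted programs as-is, "always" substitutes the gold
# execution results at every turn, and "dynamic" falls back to gold only at
# turns whose predicted belief state disagrees with the gold one.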
3 | from typing import Any, Dict 4 | 5 | from dataflow.core.utterance_tokenizer import UtteranceTokenizer 6 | from dataflow.multiwoz.create_belief_state_tracker_data import ( 7 | build_belief_state_from_belief_dict, 8 | build_belief_state_from_trade_turn, 9 | ) 10 | from dataflow.multiwoz.create_programs import create_programs_for_trade_dialogue 11 | from dataflow.multiwoz.execute_programs import execute_programs_for_dialogue 12 | from dataflow.multiwoz.salience_model import VanillaSalienceModel 13 | from test_dataflow.multiwoz.conftest import build_trade_dialogue 14 | 15 | 16 | def test_execute_programs(trade_dialogue_1: Dict[str, Any]): 17 | utterance_tokenizer = UtteranceTokenizer() 18 | salience_model = VanillaSalienceModel() 19 | 20 | # ============================ 21 | # get cheating execution results 22 | # ============================ 23 | dataflow_dialogue, _, _ = create_programs_for_trade_dialogue( 24 | trade_dialogue=trade_dialogue_1, 25 | keep_all_domains=True, 26 | remove_none=False, 27 | fill_none=False, 28 | salience_model=salience_model, 29 | no_revise=False, 30 | avoid_empty_plan=False, 31 | utterance_tokenizer=utterance_tokenizer, 32 | ) 33 | complete_execution_results, cheating_turn_indices = execute_programs_for_dialogue( 34 | dialogue=dataflow_dialogue, 35 | salience_model=salience_model, 36 | no_revise=False, 37 | cheating_mode="never", 38 | cheating_execution_results=None, 39 | ) 40 | assert not cheating_turn_indices 41 | for trade_turn, complete_execution_result in zip( 42 | trade_dialogue_1["dialogue"], complete_execution_results 43 | ): 44 | assert build_belief_state_from_trade_turn( 45 | trade_turn 46 | ) == build_belief_state_from_belief_dict( 47 | complete_execution_result.belief_dict, strict=True 48 | ) 49 | # pylint: disable=no-member 50 | cheating_execution_results = { 51 | turn.turn_index: complete_execution_result 52 | for turn, complete_execution_result in zip( 53 | dataflow_dialogue.turns, complete_execution_results 54 | ) 55 | } 56 | 57 | # ============================ 58 | # mock the belief state predictions 59 | # ============================ 60 | mock_belief_states = [ 61 | # turn 1: correct 62 | {"hotel-name": "none", "hotel-type": "none"}, 63 | # turn 2: correct 64 | { 65 | "hotel-name": "hilton", 66 | "hotel-pricerange": "cheap", 67 | "hotel-type": "guest house", 68 | }, 69 | # turn 3: change a slot 70 | { 71 | "hotel-name": "none", 72 | "hotel-pricerange": "_cheap", 73 | "hotel-type": "guest house", 74 | }, 75 | # turn 4: add a slot 76 | {"hotel-type": "_added"}, 77 | # turn 5: correct 78 | {"hotel-area": "west"}, 79 | # turn 6: correct 80 | {"hotel-area": "west", "restaurant-area": "west"}, 81 | # turn 7: correct 82 | { 83 | "hotel-area": "west", 84 | "restaurant-area": "west", 85 | "restaurant-pricerange": "cheap", 86 | }, 87 | # turn 8: change two slots 88 | { 89 | "hotel-area": "_west", 90 | "restaurant-area": "_west", 91 | "restaurant-pricerange": "cheap", 92 | }, 93 | # turn 9: correct 94 | { 95 | "hotel-area": "west", 96 | "restaurant-area": "west", 97 | "restaurant-pricerange": "cheap", 98 | "taxi-departure": "none", 99 | }, 100 | # turn 10: drop a slot 101 | { 102 | "hotel-area": "west", 103 | "restaurant-area": "west", 104 | "restaurant-pricerange": "cheap", 105 | }, 106 | ] 107 | mock_trade_dialogue = build_trade_dialogue( 108 | dialogue_id="mock", 109 | turns=[("", "", belief_state) for belief_state in mock_belief_states], 110 | ) 111 | 112 | mock_dataflow_dialogue, _, _ = create_programs_for_trade_dialogue( 113 | 
trade_dialogue=mock_trade_dialogue, 114 | keep_all_domains=True, 115 | remove_none=False, 116 | fill_none=False, 117 | salience_model=salience_model, 118 | no_revise=False, 119 | avoid_empty_plan=False, 120 | utterance_tokenizer=utterance_tokenizer, 121 | ) 122 | _, mock_cheating_turn_indices = execute_programs_for_dialogue( 123 | dialogue=mock_dataflow_dialogue, 124 | salience_model=salience_model, 125 | no_revise=False, 126 | cheating_mode="always", 127 | cheating_execution_results=cheating_execution_results, 128 | ) 129 | assert mock_cheating_turn_indices == [ 130 | turn.turn_index for turn in dataflow_dialogue.turns 131 | ] 132 | 133 | _, mock_cheating_turn_indices = execute_programs_for_dialogue( 134 | dialogue=mock_dataflow_dialogue, 135 | salience_model=salience_model, 136 | no_revise=False, 137 | cheating_mode="dynamic", 138 | cheating_execution_results=cheating_execution_results, 139 | ) 140 | assert mock_cheating_turn_indices == [3, 4, 8, 10] 141 | -------------------------------------------------------------------------------- /tests/test_dataflow/multiwoz/test_ontology.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | from typing import Dict, List 4 | 5 | from dataflow.multiwoz.ontology import DATAFLOW_SLOT_NAMES_FOR_DOMAIN 6 | from dataflow.multiwoz.trade_dst_utils import ( 7 | get_domain_and_slot_name, 8 | normalize_trade_slot_name, 9 | trade_normalize_slot_name, 10 | ) 11 | 12 | 13 | def expected_dataflow_slot_names_for_domain() -> Dict[str, List[str]]: 14 | # extracted from MultiWoZ-2.1 ontology.json file 15 | # $ jq 'keys' ontology.json 16 | raw_slot_fullnames: List[str] = [ 17 | "attraction-area", 18 | "attraction-name", 19 | "attraction-type", 20 | "hotel-area", 21 | "hotel-book day", 22 | "hotel-book people", 23 | "hotel-book stay", 24 | "hotel-internet", 25 | "hotel-name", 26 | "hotel-parking", 27 | "hotel-price range", 28 | "hotel-stars", 29 | "hotel-type", 30 | "restaurant-area", 31 | "restaurant-book day", 32 | "restaurant-book people", 33 | "restaurant-book time", 34 | "restaurant-food", 35 | "restaurant-name", 36 | "restaurant-price range", 37 | "taxi-arrive by", 38 | "taxi-departure", 39 | "taxi-destination", 40 | "taxi-leave at", 41 | "train-arrive by", 42 | "train-book people", 43 | "train-day", 44 | "train-departure", 45 | "train-destination", 46 | "train-leave at", 47 | "bus-day", 48 | "bus-departure", 49 | "bus-destination", 50 | "bus-leaveAt", 51 | "hospital-department", 52 | ] 53 | 54 | dataflow_slot_fullnames_for_domain: Dict[str, List[str]] = dict() 55 | for slot_fullname in sorted(raw_slot_fullnames): 56 | slot_fullname = normalize_trade_slot_name( 57 | name=trade_normalize_slot_name(name=slot_fullname) 58 | ) 59 | domain, slot_name = get_domain_and_slot_name(slot_fullname=slot_fullname) 60 | if domain not in dataflow_slot_fullnames_for_domain: 61 | dataflow_slot_fullnames_for_domain[domain] = [] 62 | dataflow_slot_fullnames_for_domain[domain].append(slot_name) 63 | return dataflow_slot_fullnames_for_domain 64 | 65 | 66 | def test_dataflow_slot_fullnames_for_domain(): 67 | assert DATAFLOW_SLOT_NAMES_FOR_DOMAIN == expected_dataflow_slot_names_for_domain() 68 | --------------------------------------------------------------------------------