├── .gitignore ├── .travis.yml ├── AUTHORS.rst ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.rst ├── LICENSE ├── MANIFEST.in ├── README.rst ├── samples ├── cross_val_dataset.json ├── test_dataset.json └── train_dataset.json ├── setup.py ├── snips_nlu_metrics ├── __init__.py ├── __version__ ├── engine.py ├── metrics.py ├── tests │ ├── __init__.py │ ├── conftest.py │ ├── mock_engine.py │ ├── resources │ │ ├── beverage_dataset.json │ │ └── keyword_matching_dataset.json │ ├── test_dataset_utils.py │ ├── test_exception.py │ ├── test_metrics.py │ └── test_metrics_utils.py └── utils │ ├── __init__.py │ ├── constants.py │ ├── dataset_utils.py │ ├── exception.py │ ├── metrics_utils.py │ └── temp_utils.py └── tox.ini /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | venv2/ 89 | venv3/ 90 | ENV/ 91 | 92 | # Spyder project settings 93 | .spyderproject 94 | .spyproject 95 | 96 | # Rope project settings 97 | .ropeproject 98 | 99 | # mkdocs documentation 100 | /site 101 | 102 | # mypy 103 | .mypy_cache/ 104 | .idea/ 105 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | 3 | matrix: 4 | include: 5 | - python: 3.5 6 | env: TOXENV=py35 7 | - python: 3.6 8 | env: TOXENV=py36 9 | - python: 3.7 10 | env: TOXENV=py37 11 | - python: 3.8 12 | env: TOXENV=py38 13 | 14 | install: pip install --upgrade --pre tox 15 | 16 | script: tox -------------------------------------------------------------------------------- /AUTHORS.rst: -------------------------------------------------------------------------------- 1 | Snips NLU is written and maintained by Snips. 
2 | 3 | Development Lead 4 | ================ 5 | 6 | * `Adrien Ball `_ 7 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | All notable changes to this project will be documented in this file. 3 | 4 | ## [0.15.0] - 2020-03-10 5 | ### Changed 6 | - Remove support for Python2.7 [#123](https://github.com/snipsco/snips-nlu-metrics/pull/123) 7 | 8 | ## [0.14.6] - 2020-01-14 9 | ### Added 10 | - Support for python3.8 [#121](https://github.com/snipsco/snips-nlu-metrics/pull/121) 11 | 12 | ## [0.14.5] - 2019-08-20 13 | ### Fixed 14 | - Fix issue with intents filter when dataset does not have enough data [#118](https://github.com/snipsco/snips-nlu-metrics/pull/118) 15 | 16 | ## [0.14.4] - 2019-06-18 17 | ### Fixed 18 | - Update dependencies 19 | 20 | ## [0.14.3] - 2019-05-10 21 | ### Added 22 | - Add optional parameter `intents_filter` to metrics APIs [#115](https://github.com/snipsco/snips-nlu-metrics/pull/115) 23 | 24 | ## [0.14.2] - 2019-03-21 25 | ### Added 26 | - Number of exact parsings 27 | - Possibility to provide out-of-domain utterances 28 | - Logging 29 | 30 | ### Fixed 31 | - Hanging issue when using multiple workers, when one job returns a non-zero exit code 32 | 33 | ## [0.14.1] - 2019-01-07 34 | ### Added 35 | - Support for new NLU output format 36 | 37 | ### Fixed 38 | - Bug with None intent when computing average metrics 39 | 40 | ## [0.14.0] - 2018-11-13 41 | ### Added 42 | - Possibility to use parallel workers 43 | - Seed parameter for reproducibility 44 | - Average metrics for intent classification and slot filling 45 | 46 | ## [0.13.0] - 2018-07-25 47 | ### Fixed 48 | - Crash while computing metrics when either actual or predicted intent is unknown 49 | 50 | ### Removed 51 | - APIs depending implicitly on Snips NLU: 52 | - `compute_cross_val_nlu_metrics` 53 | - `compute_train_test_nlu_metrics` 54 | 55 | ### 
Changed 56 | - Use flexible version specifiers for dependencies 57 | 58 | 59 | ## [0.12.0] - 2018-03-29 60 | ### Added 61 | - F1 scores for intent classification and entity extraction 62 | - Confusion matrix 63 | - New option to exclude slot metrics in the output 64 | - Samples 65 | 66 | 67 | [0.15.0]: https://github.com/snipsco/snips-nlu-metrics/compare/0.14.6...0.15.0 68 | [0.14.6]: https://github.com/snipsco/snips-nlu-metrics/compare/0.14.5...0.14.6 69 | [0.14.5]: https://github.com/snipsco/snips-nlu-metrics/compare/0.14.4...0.14.5 70 | [0.14.4]: https://github.com/snipsco/snips-nlu-metrics/compare/0.14.3...0.14.4 71 | [0.14.3]: https://github.com/snipsco/snips-nlu-metrics/compare/0.14.2...0.14.3 72 | [0.14.2]: https://github.com/snipsco/snips-nlu-metrics/compare/0.14.1...0.14.2 73 | [0.14.1]: https://github.com/snipsco/snips-nlu-metrics/compare/0.14.0...0.14.1 74 | [0.14.0]: https://github.com/snipsco/snips-nlu-metrics/compare/0.13.0...0.14.0 75 | [0.13.0]: https://github.com/snipsco/snips-nlu-metrics/compare/0.12.0...0.13.0 76 | [0.12.0]: https://github.com/snipsco/snips-nlu-metrics/compare/0.11.1...0.12.0 -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation. 
6 | 7 | ## Our Standards 8 | 9 | Examples of behavior that contributes to creating a positive environment include: 10 | 11 | * Using welcoming and inclusive language 12 | * Being respectful of differing viewpoints and experiences 13 | * Gracefully accepting constructive criticism 14 | * Focusing on what is best for the community 15 | * Showing empathy towards other community members 16 | 17 | Examples of unacceptable behavior by participants include: 18 | 19 | * The use of sexualized language or imagery and unwelcome sexual attention or advances 20 | * Trolling, insulting/derogatory comments, and personal or political attacks 21 | * Public or private harassment 22 | * Publishing others' private information, such as a physical or electronic address, without explicit permission 23 | * Other conduct which could reasonably be considered inappropriate in a professional setting 24 | 25 | ## Our Responsibilities 26 | 27 | Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. 28 | 29 | Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. 30 | 31 | ## Scope 32 | 33 | This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. 
34 | 35 | ## Enforcement 36 | 37 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at adrien.ball@snips.ai or clement.doumouro@snips.ai. The project team will review and investigate all complaints, and will respond in a way that it deems appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. 38 | 39 | Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. 40 | 41 | ## Attribution 42 | 43 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at [http://contributor-covenant.org/version/1/4][version] 44 | 45 | [homepage]: http://contributor-covenant.org 46 | [version]: http://contributor-covenant.org/version/1/4/ 47 | -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | How to Contribute 2 | ================= 3 | 4 | Contributions are welcome! Not familiar with the codebase yet? No problem! 5 | There are many ways to contribute to open source projects: reporting bugs, 6 | helping with the documentation, spreading the word and of course, adding 7 | new features and patches. 8 | 9 | Getting Started 10 | --------------- 11 | * Make sure you have a GitHub account. 12 | * Open a `new issue `_, assuming one does not already exist. 13 | * Clearly describe the issue including steps to reproduce when it is a bug. 14 | 15 | Making Changes 16 | -------------- 17 | * Fork this repository. 18 | * Create a feature branch from where you want to base your work. 
19 | * Make commits of logical units (if needed rebase your feature branch before 20 | submitting it). 21 | * Check that your changes are `PEP8 `_ compliant (for instance using Pycharm or pylint). 22 | * Make sure your commit messages are well formatted. 23 | * If your commit fixes an open issue, reference it in the commit message (f.e. ``#15``). 24 | * Run all the tests (if existing) to assure nothing else was accidentally broken. 25 | 26 | These guidelines also apply when helping with documentation. 27 | 28 | Submitting Changes 29 | ------------------ 30 | * Push your changes to a feature branch in your fork of the repository. 31 | * Submit a ``Pull Request`` on the develop branch. 32 | * Wait for maintainer feedback. 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include samples * 2 | include snips_nlu_metrics/__version__ 3 | include README.rst LICENSE CHANGELOG.md 4 | global-exclude __pycache__ *.py[cod] 5 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | Snips NLU Metrics 2 | ================= 3 | 4 | .. image:: https://travis-ci.org/snipsco/snips-nlu-metrics.svg?branch=master 5 | :target: https://travis-ci.org/snipsco/snips-nlu-metrics 6 | 7 | .. image:: https://img.shields.io/pypi/v/snips-nlu-metrics.svg?branch=master 8 | :target: https://pypi.python.org/pypi/snips-nlu-metrics 9 | 10 | .. 
image:: https://img.shields.io/pypi/pyversions/snips-nlu-metrics.svg?branch=master 11 | :target: https://pypi.python.org/pypi/snips-nlu-metrics 12 | 13 | 14 | This tool is a python library for computing `cross-validation`_ and 15 | `train/test`_ metrics on an NLU parsing pipeline such as the `Snips NLU`_ one. 16 | 17 | Its purpose is to help evaluate and iterate on the tested intent parsing 18 | pipeline. 19 | 20 | Install 21 | ------- 22 | 23 | .. code-block:: console 24 | 25 | $ pip install snips_nlu_metrics 26 | 27 | 28 | NLU Metrics API 29 | --------------- 30 | 31 | Snips NLU metrics API consists of the following functions: 32 | 33 | * ``compute_train_test_metrics`` to compute `train/test`_ metrics 34 | * ``compute_cross_val_metrics`` to compute `cross-validation`_ metrics 35 | 36 | The metrics output (json) provides detailed information about: 37 | 38 | * `confusion matrix`_ 39 | * `precision, recall and f1 scores`_ of intent classification 40 | * precision, recall and f1 scores of entity extraction 41 | * parsing errors 42 | 43 | Data 44 | ---- 45 | 46 | Some sample datasets, which can be used to compute metrics, are available 47 | `here `_. Alternatively, you can create your own dataset either by 48 | using ``snips-nlu``'s `dataset generation tool`_ or by going on the 49 | `Snips console`_. 50 | 51 | Examples 52 | -------- 53 | 54 | The Snips NLU metrics library can be used with any NLU pipeline which satisfies 55 | the ``Engine`` API: 56 | 57 | .. code-block:: python 58 | 59 | class Engine: 60 | def fit(self, dataset): 61 | # Perform training ... 62 | return self 63 | 64 | def parse(self, text): 65 | # extract intent and slots ... 66 | return { 67 | "input": text, 68 | "intent": { 69 | "intentName": intent_name, 70 | "probability": probability 71 | }, 72 | "slots": slots 73 | } 74 | 75 | 76 | ---------------- 77 | Snips NLU Engine 78 | ---------------- 79 | 80 | This library can be used to benchmark NLU solutions such as `Snips NLU`_. 
To 81 | install the ``snips-nlu`` python library, and fetch the language resources for 82 | english, run the following commands: 83 | 84 | .. code-block:: bash 85 | 86 | $ pip install snips-nlu 87 | $ snips-nlu download en 88 | 89 | 90 | Then, you can compute metrics for the ``snips-nlu`` pipeline using the metrics 91 | API as follows: 92 | 93 | .. code-block:: python 94 | 95 | from snips_nlu import SnipsNLUEngine 96 | from snips_nlu_metrics import compute_train_test_metrics, compute_cross_val_metrics 97 | 98 | tt_metrics = compute_train_test_metrics(train_dataset="samples/train_dataset.json", 99 | test_dataset="samples/test_dataset.json", 100 | engine_class=SnipsNLUEngine) 101 | 102 | cv_metrics = compute_cross_val_metrics(dataset="samples/cross_val_dataset.json", 103 | engine_class=SnipsNLUEngine, 104 | nb_folds=5) 105 | 106 | ----------------- 107 | Custom NLU Engine 108 | ----------------- 109 | 110 | You can also compute metrics on a custom NLU engine, here is a simple example: 111 | 112 | .. code-block:: python 113 | 114 | import random 115 | 116 | from snips_nlu_metrics import compute_train_test_metrics 117 | 118 | class MyNLUEngine: 119 | def fit(self, dataset): 120 | self.intent_list = list(dataset["intents"]) 121 | return self 122 | 123 | def parse(self, text): 124 | return { 125 | "input": text, 126 | "intent": { 127 | "intentName": random.choice(self.intent_list), 128 | "probability": 0.5 129 | }, 130 | "slots": [] 131 | } 132 | 133 | compute_train_test_metrics(train_dataset="samples/train_dataset.json", 134 | test_dataset="samples/test_dataset.json", 135 | engine_class=MyNLUEngine) 136 | 137 | Links 138 | ----- 139 | * `Changelog `__ 140 | * `Bug tracker `__ 141 | * `Snips NLU `__ 142 | * `Snips NLU Rust `__: Rust inference pipeline implementation and bindings (C, Swift, Kotlin, Python) 143 | * `Snips `__ 144 | 145 | Contributing 146 | ------------ 147 | Please see the `Contribution Guidelines `_. 
148 | 149 | Copyright 150 | --------- 151 | This library is provided by `Snips `_ as Open Source software. See `LICENSE `_ for more information. 152 | 153 | .. _cross-validation: https://en.wikipedia.org/wiki/Cross-validation_(statistics) 154 | .. _train/test: https://en.wikipedia.org/wiki/Training,_test,_and_validation_sets 155 | .. _Snips NLU: https://github.com/snipsco/snips-nlu 156 | .. _precision, recall and f1 scores: https://en.wikipedia.org/wiki/Precision_and_recall 157 | .. _confusion matrix: https://en.wikipedia.org/wiki/Confusion_matrix 158 | .. _dataset generation tool: http://snips-nlu.readthedocs.io/en/latest/tutorial.html#snips-dataset-format 159 | .. _Snips console: https://console.snips.ai -------------------------------------------------------------------------------- /samples/test_dataset.json: -------------------------------------------------------------------------------- 1 | { 2 | "intents": { 3 | "user_H1WQIL0U4e__SearchWeatherForecast": { 4 | "utterances": [ 5 | { 6 | "data": [ 7 | { 8 | "text": "what's the weather in " 9 | }, 10 | { 11 | "slot_name": "location", 12 | "text": "wagga wagga", 13 | "entity": "weather_location" 14 | }, 15 | { 16 | "text": "?" 17 | } 18 | ] 19 | }, 20 | { 21 | "data": [ 22 | { 23 | "text": "weather for " 24 | }, 25 | { 26 | "slot_name": "location", 27 | "text": "clicquot oregon", 28 | "entity": "weather_location" 29 | }, 30 | { 31 | "text": "?" 32 | } 33 | ] 34 | }, 35 | { 36 | "data": [ 37 | { 38 | "text": "what is the weather forecast for " 39 | }, 40 | { 41 | "slot_name": "location", 42 | "text": "greenfield", 43 | "entity": "weather_location" 44 | }, 45 | { 46 | "text": " " 47 | }, 48 | { 49 | "text": "on november the 4th, 2035", 50 | "slot_name": "datetime", 51 | "entity": "snips/datetime" 52 | }, 53 | { 54 | "text": "?" 
55 | } 56 | ] 57 | }, 58 | { 59 | "data": [ 60 | { 61 | "text": "look for the weather " 62 | }, 63 | { 64 | "text": "at 10:21:35 am", 65 | "slot_name": "datetime", 66 | "entity": "snips/datetime" 67 | } 68 | ] 69 | }, 70 | { 71 | "data": [ 72 | { 73 | "text": "get the " 74 | }, 75 | { 76 | "slot_name": "location", 77 | "text": "jambughoda wildlife sanctuary", 78 | "entity": "weather_location" 79 | }, 80 | { 81 | "text": " weather forecast?" 82 | } 83 | ] 84 | }, 85 | { 86 | "data": [ 87 | { 88 | "text": "can i get the weather for " 89 | }, 90 | { 91 | "text": "Paris, Texas", 92 | "slot_name": "location", 93 | "entity": "weather_location" 94 | } 95 | ] 96 | }, 97 | { 98 | "data": [ 99 | { 100 | "text": "weather in " 101 | }, 102 | { 103 | "slot_name": "location", 104 | "text": "afi mountain wildlife sanctuary", 105 | "entity": "weather_location" 106 | }, 107 | { 108 | "text": "? " 109 | } 110 | ] 111 | }, 112 | { 113 | "data": [ 114 | { 115 | "text": "give me the " 116 | }, 117 | { 118 | "text": "feb. 1", 119 | "slot_name": "datetime", 120 | "entity": "snips/datetime" 121 | }, 122 | { 123 | "text": " forecast at " 124 | }, 125 | { 126 | "slot_name": "location", 127 | "text": "laguna mountain recreation area.", 128 | "entity": "weather_location" 129 | } 130 | ] 131 | }, 132 | { 133 | "data": [ 134 | { 135 | "text": "i want the weather in " 136 | }, 137 | { 138 | "slot_name": "location", 139 | "text": "piedras negras", 140 | "entity": "weather_location" 141 | }, 142 | { 143 | "text": "?" 144 | } 145 | ] 146 | }, 147 | { 148 | "data": [ 149 | { 150 | "text": "what will the weather be like " 151 | }, 152 | { 153 | "text": "tonight", 154 | "slot_name": "datetime", 155 | "entity": "snips/datetime" 156 | }, 157 | { 158 | "text": " in " 159 | }, 160 | { 161 | "text": "kassopaia", 162 | "slot_name": "location", 163 | "entity": "weather_location" 164 | }, 165 | { 166 | "text": "?" 
167 | } 168 | ] 169 | }, 170 | { 171 | "data": [ 172 | { 173 | "text": "forecast " 174 | }, 175 | { 176 | "text": "at 4 p.m.", 177 | "slot_name": "datetime", 178 | "entity": "snips/datetime" 179 | }, 180 | { 181 | "text": " on " 182 | }, 183 | { 184 | "slot_name": "location", 185 | "text": "the banke national park", 186 | "entity": "weather_location" 187 | }, 188 | { 189 | "text": "?" 190 | } 191 | ] 192 | }, 193 | { 194 | "data": [ 195 | { 196 | "text": "weather forecast in " 197 | }, 198 | { 199 | "slot_name": "location", 200 | "text": "china", 201 | "entity": "weather_location" 202 | }, 203 | { 204 | "text": " " 205 | }, 206 | { 207 | "text": "on feb. the 21st, 2031", 208 | "slot_name": "datetime", 209 | "entity": "snips/datetime" 210 | }, 211 | { 212 | "text": "? " 213 | } 214 | ] 215 | }, 216 | { 217 | "data": [ 218 | { 219 | "text": "what's the forecast for " 220 | }, 221 | { 222 | "slot_name": "location", 223 | "text": "gifford", 224 | "entity": "weather_location" 225 | }, 226 | { 227 | "text": "? " 228 | } 229 | ] 230 | }, 231 | { 232 | "data": [ 233 | { 234 | "text": "show me a forecast for " 235 | }, 236 | { 237 | "slot_name": "location", 238 | "text": "lutts", 239 | "entity": "weather_location" 240 | }, 241 | { 242 | "text": " " 243 | }, 244 | { 245 | "text": "in 23 weeks", 246 | "slot_name": "datetime", 247 | "entity": "snips/datetime" 248 | }, 249 | { 250 | "text": " " 251 | } 252 | ] 253 | }, 254 | { 255 | "data": [ 256 | { 257 | "text": "get the weather for " 258 | }, 259 | { 260 | "text": "october the third, 2023", 261 | "slot_name": "datetime", 262 | "entity": "snips/datetime" 263 | }, 264 | { 265 | "text": "? 
" 266 | } 267 | ] 268 | }, 269 | { 270 | "data": [ 271 | { 272 | "text": "what will be the weather in " 273 | }, 274 | { 275 | "slot_name": "location", 276 | "text": "lannon", 277 | "entity": "weather_location" 278 | }, 279 | { 280 | "text": " " 281 | }, 282 | { 283 | "text": "at 2pm", 284 | "slot_name": "datetime", 285 | "entity": "snips/datetime" 286 | }, 287 | { 288 | "text": "?" 289 | } 290 | ] 291 | }, 292 | { 293 | "data": [ 294 | { 295 | "text": "get the " 296 | }, 297 | { 298 | "slot_name": "location", 299 | "text": "joshua tree national park", 300 | "entity": "weather_location" 301 | }, 302 | { 303 | "text": " weather forecast " 304 | }, 305 | { 306 | "slot_name": "datetime", 307 | "text": "at elevenses", 308 | "entity": "snips/datetime" 309 | }, 310 | { 311 | "text": "?" 312 | } 313 | ] 314 | } 315 | ] 316 | }, 317 | "user_H1WQIL0U4e__GetWeatherForecastCondition": { 318 | "utterances": [ 319 | { 320 | "data": [ 321 | { 322 | "text": "will there be a " 323 | }, 324 | { 325 | "slot_name": "weather_condition", 326 | "text": "depression", 327 | "entity": "weather_condition" 328 | }, 329 | { 330 | "text": " " 331 | }, 332 | { 333 | "text": "at 11 a.m.", 334 | "slot_name": "datetime", 335 | "entity": "snips/datetime" 336 | }, 337 | { 338 | "text": " in " 339 | }, 340 | { 341 | "slot_name": "location", 342 | "text": "arkansas", 343 | "entity": "weather_location" 344 | } 345 | ] 346 | }, 347 | { 348 | "data": [ 349 | { 350 | "text": "is it going to be " 351 | }, 352 | { 353 | "slot_name": "weather_condition", 354 | "text": "snowy", 355 | "entity": "weather_condition" 356 | }, 357 | { 358 | "text": " in " 359 | }, 360 | { 361 | "slot_name": "location", 362 | "text": "rio negrinho", 363 | "entity": "weather_location" 364 | }, 365 | { 366 | "text": " " 367 | }, 368 | { 369 | "text": "this saturday", 370 | "slot_name": "datetime", 371 | "entity": "snips/datetime" 372 | }, 373 | { 374 | "text": "?" 
375 | } 376 | ] 377 | }, 378 | { 379 | "data": [ 380 | { 381 | "text": "was there " 382 | }, 383 | { 384 | "slot_name": "weather_condition", 385 | "text": "snowfall", 386 | "entity": "weather_condition" 387 | }, 388 | { 389 | "text": " " 390 | }, 391 | { 392 | "text": "last february", 393 | "slot_name": "datetime", 394 | "entity": "snips/datetime" 395 | }, 396 | { 397 | "text": "?" 398 | } 399 | ] 400 | }, 401 | { 402 | "data": [ 403 | { 404 | "text": "when will it " 405 | }, 406 | { 407 | "slot_name": "weather_condition", 408 | "text": "snow", 409 | "entity": "weather_condition" 410 | }, 411 | { 412 | "text": " in " 413 | }, 414 | { 415 | "slot_name": "location", 416 | "text": "bistum srikakulam", 417 | "entity": "weather_location" 418 | }, 419 | { 420 | "text": "?" 421 | } 422 | ] 423 | }, 424 | { 425 | "data": [ 426 | { 427 | "text": "will it be " 428 | }, 429 | { 430 | "slot_name": "weather_condition", 431 | "text": "stormy", 432 | "entity": "weather_condition" 433 | }, 434 | { 435 | "text": " in " 436 | }, 437 | { 438 | "slot_name": "location", 439 | "text": "mp", 440 | "entity": "weather_location" 441 | }, 442 | { 443 | "text": "?" 444 | } 445 | ] 446 | }, 447 | { 448 | "data": [ 449 | { 450 | "text": "will it be " 451 | }, 452 | { 453 | "slot_name": "weather_condition", 454 | "text": "stormy", 455 | "entity": "weather_condition" 456 | }, 457 | { 458 | "text": " " 459 | }, 460 | { 461 | "text": "at 18:07:49", 462 | "slot_name": "datetime", 463 | "entity": "snips/datetime" 464 | }, 465 | { 466 | "text": "?" 
467 | } 468 | ] 469 | }, 470 | { 471 | "data": [ 472 | { 473 | "text": "will there be " 474 | }, 475 | { 476 | "slot_name": "weather_condition", 477 | "text": "sun", 478 | "entity": "weather_condition" 479 | }, 480 | { 481 | "text": " " 482 | }, 483 | { 484 | "slot_name": "datetime", 485 | "text": "at seven am", 486 | "entity": "snips/datetime" 487 | }, 488 | { 489 | "text": " in " 490 | }, 491 | { 492 | "slot_name": "location", 493 | "text": "the grand duchy of finland", 494 | "entity": "weather_location" 495 | }, 496 | { 497 | "text": "? " 498 | } 499 | ] 500 | }, 501 | { 502 | "data": [ 503 | { 504 | "text": "will it be high " 505 | }, 506 | { 507 | "slot_name": "weather_condition", 508 | "text": "humidity", 509 | "entity": "weather_condition" 510 | }, 511 | { 512 | "text": " " 513 | }, 514 | { 515 | "slot_name": "datetime", 516 | "text": "last mar.", 517 | "entity": "snips/datetime" 518 | }, 519 | { 520 | "text": " in " 521 | }, 522 | { 523 | "slot_name": "location", 524 | "text": "ga", 525 | "entity": "weather_location" 526 | }, 527 | { 528 | "text": "?" 529 | } 530 | ] 531 | }, 532 | { 533 | "data": [ 534 | { 535 | "text": "is it " 536 | }, 537 | { 538 | "slot_name": "weather_condition", 539 | "text": "foggy", 540 | "entity": "weather_condition" 541 | }, 542 | { 543 | "text": " " 544 | }, 545 | { 546 | "text": "now", 547 | "slot_name": "datetime", 548 | "entity": "snips/datetime" 549 | }, 550 | { 551 | "text": " in " 552 | }, 553 | { 554 | "slot_name": "location", 555 | "text": "kingdom of mapungubwe", 556 | "entity": "weather_location" 557 | }, 558 | { 559 | "text": "?" 560 | } 561 | ] 562 | }, 563 | { 564 | "data": [ 565 | { 566 | "text": "will there be " 567 | }, 568 | { 569 | "slot_name": "weather_condition", 570 | "text": "wind", 571 | "entity": "weather_condition" 572 | }, 573 | { 574 | "text": " blowing " 575 | }, 576 | { 577 | "slot_name": "location", 578 | "text": "around me", 579 | "entity": "weather_location" 580 | }, 581 | { 582 | "text": "?" 
583 | } 584 | ] 585 | }, 586 | { 587 | "data": [ 588 | { 589 | "text": "will it be " 590 | }, 591 | { 592 | "slot_name": "weather_condition", 593 | "text": "windy", 594 | "entity": "weather_condition" 595 | }, 596 | { 597 | "text": " " 598 | }, 599 | { 600 | "text": "at 18:16", 601 | "slot_name": "datetime", 602 | "entity": "snips/datetime" 603 | }, 604 | { 605 | "text": " in " 606 | }, 607 | { 608 | "slot_name": "location", 609 | "text": "the upland island wilderness", 610 | "entity": "weather_location" 611 | }, 612 | { 613 | "text": "? " 614 | } 615 | ] 616 | }, 617 | { 618 | "data": [ 619 | { 620 | "text": "will it be " 621 | }, 622 | { 623 | "slot_name": "weather_condition", 624 | "text": "sunny", 625 | "entity": "weather_condition" 626 | }, 627 | { 628 | "text": " in " 629 | }, 630 | { 631 | "slot_name": "location", 632 | "text": "the ne", 633 | "entity": "weather_location" 634 | }, 635 | { 636 | "text": " during " 637 | }, 638 | { 639 | "slot_name": "datetime", 640 | "text": "brunch", 641 | "entity": "snips/datetime" 642 | }, 643 | { 644 | "text": "? " 645 | } 646 | ] 647 | }, 648 | { 649 | "data": [ 650 | { 651 | "text": "will there be a " 652 | }, 653 | { 654 | "slot_name": "weather_condition", 655 | "text": "storm", 656 | "entity": "weather_condition" 657 | }, 658 | { 659 | "text": " in " 660 | }, 661 | { 662 | "slot_name": "location", 663 | "text": "alacalufes national reserve", 664 | "entity": "weather_location" 665 | }, 666 | { 667 | "text": "?" 
668 | } 669 | ] 670 | }, 671 | { 672 | "data": [ 673 | { 674 | "text": "will there be " 675 | }, 676 | { 677 | "slot_name": "weather_condition", 678 | "text": "fog", 679 | "entity": "weather_condition" 680 | }, 681 | { 682 | "text": " in " 683 | }, 684 | { 685 | "slot_name": "location", 686 | "text": "kingdom of sanwi", 687 | "entity": "weather_location" 688 | }, 689 | { 690 | "text": " " 691 | } 692 | ] 693 | }, 694 | { 695 | "data": [ 696 | { 697 | "text": "is it " 698 | }, 699 | { 700 | "slot_name": "weather_condition", 701 | "text": "snowy", 702 | "entity": "weather_condition" 703 | }, 704 | { 705 | "text": " in " 706 | }, 707 | { 708 | "slot_name": "location", 709 | "text": "pawling nature reserve", 710 | "entity": "weather_location" 711 | } 712 | ] 713 | }, 714 | { 715 | "data": [ 716 | { 717 | "text": "will the weather be " 718 | }, 719 | { 720 | "slot_name": "weather_condition", 721 | "text": "overcast", 722 | "entity": "weather_condition" 723 | }, 724 | { 725 | "text": " at " 726 | }, 727 | { 728 | "slot_name": "datetime", 729 | "text": "brunch", 730 | "entity": "snips/datetime" 731 | }, 732 | { 733 | "text": " in " 734 | }, 735 | { 736 | "slot_name": "location", 737 | "text": "il", 738 | "entity": "weather_location" 739 | }, 740 | { 741 | "text": "?" 742 | } 743 | ] 744 | } 745 | ] 746 | }, 747 | "user_H1WQIL0U4e__GetWeatherTemperature": { 748 | "utterances": [ 749 | { 750 | "data": [ 751 | { 752 | "text": "is it cold in " 753 | }, 754 | { 755 | "slot_name": "location", 756 | "text": "british cameroons", 757 | "entity": "weather_location" 758 | }, 759 | { 760 | "text": "?" 
761 | } 762 | ] 763 | }, 764 | { 765 | "data": [ 766 | { 767 | "text": "will it be temperate in " 768 | }, 769 | { 770 | "slot_name": "location", 771 | "text": "fort myers villas", 772 | "entity": "weather_location" 773 | } 774 | ] 775 | }, 776 | { 777 | "data": [ 778 | { 779 | "text": "what will the temperature be in " 780 | }, 781 | { 782 | "slot_name": "location", 783 | "text": "bangladesh", 784 | "entity": "weather_location" 785 | }, 786 | { 787 | "text": "?" 788 | } 789 | ] 790 | }, 791 | { 792 | "data": [ 793 | { 794 | "text": "what will the temperature by " 795 | }, 796 | { 797 | "slot_name": "datetime", 798 | "text": "on 1/12/2039", 799 | "entity": "snips/datetime" 800 | }, 801 | { 802 | "text": " " 803 | } 804 | ] 805 | }, 806 | { 807 | "data": [ 808 | { 809 | "text": "will it be cold in " 810 | }, 811 | { 812 | "slot_name": "location", 813 | "text": "russian gulch state park", 814 | "entity": "weather_location" 815 | } 816 | ] 817 | }, 818 | { 819 | "data": [ 820 | { 821 | "text": "what will the temperature be " 822 | }, 823 | { 824 | "text": "at three pm", 825 | "slot_name": "datetime", 826 | "entity": "snips/datetime" 827 | }, 828 | { 829 | "text": " in " 830 | }, 831 | { 832 | "slot_name": "location", 833 | "text": "clontarf virginia", 834 | "entity": "weather_location" 835 | } 836 | ] 837 | }, 838 | { 839 | "data": [ 840 | { 841 | "text": "will it be cold in " 842 | }, 843 | { 844 | "slot_name": "location", 845 | "text": "marstrand free port", 846 | "entity": "weather_location" 847 | }, 848 | { 849 | "text": "?" 850 | } 851 | ] 852 | }, 853 | { 854 | "data": [ 855 | { 856 | "text": "will it be cold in " 857 | }, 858 | { 859 | "slot_name": "location", 860 | "text": "tula oblast", 861 | "entity": "weather_location" 862 | }, 863 | { 864 | "text": "?" 
865 | } 866 | ] 867 | }, 868 | { 869 | "data": [ 870 | { 871 | "text": "will the temperature be high in " 872 | }, 873 | { 874 | "slot_name": "location", 875 | "text": "british togoland", 876 | "entity": "weather_location" 877 | }, 878 | { 879 | "text": "?" 880 | } 881 | ] 882 | }, 883 | { 884 | "data": [ 885 | { 886 | "text": "will it be hot " 887 | }, 888 | { 889 | "text": "at 3 am", 890 | "slot_name": "datetime", 891 | "entity": "snips/datetime" 892 | }, 893 | { 894 | "text": " in " 895 | }, 896 | { 897 | "slot_name": "location", 898 | "text": "missouri mines state historic site", 899 | "entity": "weather_location" 900 | }, 901 | { 902 | "text": "?" 903 | } 904 | ] 905 | }, 906 | { 907 | "data": [ 908 | { 909 | "text": "is the weather in the " 910 | }, 911 | { 912 | "slot_name": "location", 913 | "text": "captaincy general of the philippines", 914 | "entity": "weather_location" 915 | }, 916 | { 917 | "text": " warm? " 918 | } 919 | ] 920 | }, 921 | { 922 | "data": [ 923 | { 924 | "slot_name": "location", 925 | "text": "muddus national park", 926 | "entity": "weather_location" 927 | }, 928 | { 929 | "text": " " 930 | }, 931 | { 932 | "text": "on november 23rd", 933 | "slot_name": "datetime", 934 | "entity": "snips/datetime" 935 | }, 936 | { 937 | "text": " whats the temperature " 938 | } 939 | ] 940 | }, 941 | { 942 | "data": [ 943 | { 944 | "text": "will it be hot " 945 | }, 946 | { 947 | "text": "on the 18", 948 | "slot_name": "datetime", 949 | "entity": "snips/datetime" 950 | }, 951 | { 952 | "text": " in " 953 | }, 954 | { 955 | "slot_name": "location", 956 | "text": "florida", 957 | "entity": "weather_location" 958 | }, 959 | { 960 | "text": "? 
" 961 | } 962 | ] 963 | }, 964 | { 965 | "data": [ 966 | { 967 | "text": "will it be cold " 968 | }, 969 | { 970 | "text": "at 07:09:43", 971 | "slot_name": "datetime", 972 | "entity": "snips/datetime" 973 | }, 974 | { 975 | "text": " " 976 | } 977 | ] 978 | }, 979 | { 980 | "data": [ 981 | { 982 | "text": "what will the temperature be " 983 | }, 984 | { 985 | "text": "at 01:46 a.m.", 986 | "slot_name": "datetime", 987 | "entity": "snips/datetime" 988 | }, 989 | { 990 | "text": " in " 991 | }, 992 | { 993 | "slot_name": "location", 994 | "text": "nationalpark vesuv", 995 | "entity": "weather_location" 996 | }, 997 | { 998 | "text": "? " 999 | } 1000 | ] 1001 | }, 1002 | { 1003 | "data": [ 1004 | { 1005 | "text": "will it be chilly " 1006 | }, 1007 | { 1008 | "text": "at 12 pm", 1009 | "slot_name": "datetime", 1010 | "entity": "snips/datetime" 1011 | }, 1012 | { 1013 | "text": " in " 1014 | }, 1015 | { 1016 | "slot_name": "location", 1017 | "text": "montana", 1018 | "entity": "weather_location" 1019 | }, 1020 | { 1021 | "text": " " 1022 | } 1023 | ] 1024 | }, 1025 | { 1026 | "data": [ 1027 | { 1028 | "text": "will it be colder " 1029 | }, 1030 | { 1031 | "text": "19 seconds from now", 1032 | "slot_name": "datetime", 1033 | "entity": "snips/datetime" 1034 | }, 1035 | { 1036 | "text": " " 1037 | } 1038 | ] 1039 | } 1040 | ] 1041 | } 1042 | }, 1043 | "entities": { 1044 | "weather_location": { 1045 | "use_synonyms": false, 1046 | "data": [], 1047 | "automatically_extensible": true 1048 | }, 1049 | "snips/datetime": {}, 1050 | "weather_condition": { 1051 | "use_synonyms": true, 1052 | "data": [ 1053 | { 1054 | "value": "wind", 1055 | "synonyms": [ 1056 | "windy" 1057 | ] 1058 | }, 1059 | { 1060 | "value": "rain", 1061 | "synonyms": [ 1062 | "rainy", 1063 | "umbrella", 1064 | "shower", 1065 | "rainfall" 1066 | ] 1067 | }, 1068 | { 1069 | "value": "storm", 1070 | "synonyms": [ 1071 | "stormy" 1072 | ] 1073 | }, 1074 | { 1075 | "value": "sun", 1076 | "synonyms": [ 1077 | 
"sunny" 1078 | ] 1079 | }, 1080 | { 1081 | "value": "snow", 1082 | "synonyms": [ 1083 | "snowy", 1084 | "snowstorm", 1085 | "snowfall" 1086 | ] 1087 | }, 1088 | { 1089 | "value": "thunder", 1090 | "synonyms": [ 1091 | "lightnings", 1092 | "thunderstorm" 1093 | ] 1094 | }, 1095 | { 1096 | "value": "cloud", 1097 | "synonyms": [ 1098 | "cloudy" 1099 | ] 1100 | } 1101 | ], 1102 | "automatically_extensible": true 1103 | } 1104 | }, 1105 | "language": "en" 1106 | } -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from setuptools import setup, find_packages 4 | 5 | packages = [p for p in find_packages() if "tests" not in p] 6 | 7 | PACKAGE_NAME = "snips_nlu_metrics" 8 | ROOT_PATH = Path(__file__).resolve().parent 9 | PACKAGE_PATH = ROOT_PATH / PACKAGE_NAME 10 | README = ROOT_PATH / "README.rst" 11 | VERSION = "__version__" 12 | 13 | with (PACKAGE_PATH / VERSION).open() as f: 14 | version = f.readline().strip() 15 | 16 | with README.open(encoding="utf8") as f: 17 | readme = f.read() 18 | 19 | install_requires = [ 20 | "numpy>=1.7,<2.0", 21 | "scipy>=1.0,<2.0", 22 | "scikit-learn>=0.21.0,<0.23; python_version>='3.5'", 23 | "joblib>=0.13,<0.15", 24 | ] 25 | 26 | extras_require = {"test": ["mock>=2.0,<3.0", "pytest>=5.3.1,<6",]} 27 | 28 | setup( 29 | name=PACKAGE_NAME, 30 | description="Python package to compute NLU metrics", 31 | long_description=readme, 32 | version=version, 33 | author="Adrien Ball", 34 | author_email="adrien.ball@snips.ai", 35 | license="Apache 2.0", 36 | url="https://github.com/snipsco/snips-nlu-metrics", 37 | classifiers=[ 38 | "Programming Language :: Python :: 3", 39 | "Programming Language :: Python :: 3.5", 40 | "Programming Language :: Python :: 3.6", 41 | "Programming Language :: Python :: 3.7", 42 | "Programming Language :: Python :: 3.8", 43 | ], 44 | keywords="metrics nlu nlp intent 
slots entity parsing", 45 | extras_require=extras_require, 46 | install_requires=install_requires, 47 | packages=packages, 48 | include_package_data=True, 49 | zip_safe=False, 50 | ) 51 | -------------------------------------------------------------------------------- /snips_nlu_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from snips_nlu_metrics.engine import Engine 2 | from snips_nlu_metrics.metrics import ( 3 | compute_train_test_metrics, 4 | compute_cross_val_metrics, 5 | ) 6 | -------------------------------------------------------------------------------- /snips_nlu_metrics/__version__: -------------------------------------------------------------------------------- 1 | 0.15.0 -------------------------------------------------------------------------------- /snips_nlu_metrics/engine.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | 4 | class Engine(metaclass=ABCMeta): 5 | """Abstract class which represents an engine that can be used in the 6 | metrics API. All engine classes must inherit from `Engine`. 
7 | """ 8 | 9 | @abstractmethod 10 | def fit(self, dataset): 11 | pass 12 | 13 | @abstractmethod 14 | def parse(self, text): 15 | pass 16 | -------------------------------------------------------------------------------- /snips_nlu_metrics/metrics.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from pathlib import Path 4 | 5 | from joblib import Parallel, delayed 6 | 7 | from snips_nlu_metrics.utils.constants import ( 8 | AVERAGE_METRICS, 9 | CONFUSION_MATRIX, 10 | INTENTS, 11 | INTENT_UTTERANCES, 12 | METRICS, 13 | PARSING_ERRORS, 14 | UTTERANCES, 15 | ) 16 | from snips_nlu_metrics.utils.exception import NotEnoughDataError 17 | from snips_nlu_metrics.utils.metrics_utils import ( 18 | aggregate_matrices, 19 | aggregate_metrics, 20 | compute_average_metrics, 21 | compute_engine_metrics, 22 | compute_precision_recall_f1, 23 | compute_split_metrics, 24 | create_shuffle_stratified_splits, 25 | ) 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | def compute_cross_val_metrics( 31 | dataset, 32 | engine_class, 33 | nb_folds=5, 34 | train_size_ratio=1.0, 35 | drop_entities=False, 36 | include_slot_metrics=True, 37 | slot_matching_lambda=None, 38 | progression_handler=None, 39 | num_workers=1, 40 | seed=None, 41 | out_of_domain_utterances=None, 42 | intents_filter=None, 43 | ): 44 | """Compute end-to-end metrics on the dataset using cross validation 45 | 46 | Args: 47 | dataset (dict or str): dataset or path to dataset 48 | engine_class: python class to use for training and inference, this 49 | class must inherit from `Engine` 50 | nb_folds (int, optional): number of folds to use for cross validation 51 | (default=5) 52 | train_size_ratio (float, optional): ratio of intent utterances to use 53 | for training (default=1.0) 54 | drop_entities (bool, optional): specify whether or not all entity 55 | values should be removed from training data (default=False) 56 | include_slot_metrics (bool, 
optional): if false, the slots metrics and 57 | the slots parsing errors will not be reported (default=True) 58 | slot_matching_lambda (lambda, optional): 59 | lambda expected_slot, actual_slot -> bool, 60 | if defined, this function will be use to match slots when computing 61 | metrics, otherwise exact match will be used. 62 | `expected_slot` corresponds to the slot as defined in the dataset, 63 | and `actual_slot` corresponds to the slot as returned by the NLU 64 | default(None) 65 | progression_handler (lambda, optional): handler called at each 66 | progression (%) step (default=None) 67 | num_workers (int, optional): number of workers to use. Each worker 68 | is assigned a certain number of splits (default=1) 69 | seed (int, optional): seed for the split creation 70 | out_of_domain_utterances (list, optional): if defined, list of 71 | out-of-domain utterances to be added to the pool of test utterances 72 | in each split 73 | intents_filter (list of str, optional): if defined, at inference times 74 | test utterances will be restricted to the ones belonging to this 75 | filter. Moreover, if the parsing API allows it, the inference will 76 | be made using this intents filter. 
77 | 78 | Returns: 79 | dict: Metrics results containing the following data 80 | 81 | - "metrics": the computed metrics 82 | - "parsing_errors": the list of parsing errors 83 | - "confusion_matrix": the computed confusion matrix 84 | - "average_metrics": the metrics averaged over all intents 85 | """ 86 | 87 | if isinstance(dataset, (str, Path)): 88 | with Path(dataset).open(encoding="utf8") as f: 89 | dataset = json.load(f) 90 | 91 | try: 92 | splits = create_shuffle_stratified_splits( 93 | dataset, 94 | nb_folds, 95 | train_size_ratio, 96 | drop_entities, 97 | seed, 98 | out_of_domain_utterances, 99 | intents_filter, 100 | ) 101 | except NotEnoughDataError as e: 102 | logger.warning("Not enough data, skipping metrics computation: %r", e) 103 | return { 104 | AVERAGE_METRICS: None, 105 | CONFUSION_MATRIX: None, 106 | METRICS: None, 107 | PARSING_ERRORS: [], 108 | } 109 | 110 | intent_list = sorted(list(dataset["intents"])) 111 | global_metrics = dict() 112 | global_confusion_matrix = None 113 | global_errors = [] 114 | total_splits = len(splits) 115 | 116 | def compute_metrics(split_): 117 | logger.info("Computing metrics for dataset split ...") 118 | return compute_split_metrics( 119 | engine_class, 120 | split_, 121 | intent_list, 122 | include_slot_metrics, 123 | slot_matching_lambda, 124 | intents_filter, 125 | ) 126 | 127 | effective_num_workers = min(num_workers, len(splits)) 128 | if effective_num_workers > 1: 129 | parallel = Parallel(n_jobs=effective_num_workers) 130 | results = parallel(delayed(compute_metrics)(split) for split in splits) 131 | else: 132 | results = [compute_metrics(s) for s in splits] 133 | 134 | for result in enumerate(results): 135 | split_index, (split_metrics, errors, confusion_matrix) = result 136 | global_metrics = aggregate_metrics( 137 | global_metrics, split_metrics, include_slot_metrics 138 | ) 139 | global_confusion_matrix = aggregate_matrices( 140 | global_confusion_matrix, confusion_matrix 141 | ) 142 | global_errors += 
errors 143 | logger.info("Done computing %d/%d splits" % (split_index + 1, total_splits)) 144 | 145 | if progression_handler is not None: 146 | progression_handler(float(split_index + 1) / float(total_splits)) 147 | 148 | global_metrics = compute_precision_recall_f1(global_metrics) 149 | 150 | average_metrics = compute_average_metrics( 151 | global_metrics, 152 | ignore_none_intent=True if out_of_domain_utterances is None else False, 153 | ) 154 | 155 | nb_utterances = { 156 | intent: len(data[UTTERANCES]) for intent, data in dataset[INTENTS].items() 157 | } 158 | for intent, metrics in global_metrics.items(): 159 | metrics[INTENT_UTTERANCES] = nb_utterances.get(intent, 0) 160 | 161 | return { 162 | CONFUSION_MATRIX: global_confusion_matrix, 163 | AVERAGE_METRICS: average_metrics, 164 | METRICS: global_metrics, 165 | PARSING_ERRORS: global_errors, 166 | } 167 | 168 | 169 | def compute_train_test_metrics( 170 | train_dataset, 171 | test_dataset, 172 | engine_class, 173 | include_slot_metrics=True, 174 | slot_matching_lambda=None, 175 | intents_filter=None, 176 | ): 177 | """Compute end-to-end metrics on `test_dataset` after having trained on 178 | `train_dataset` 179 | 180 | Args: 181 | train_dataset (dict or str): Dataset or path to dataset used for 182 | training 183 | test_dataset (dict or str): dataset or path to dataset used for testing 184 | engine_class: Python class to use for training and inference, this 185 | class must inherit from `Engine` 186 | include_slot_metrics (bool, true): If false, the slots metrics and the 187 | slots parsing errors will not be reported. 188 | slot_matching_lambda (lambda, optional): 189 | lambda expected_slot, actual_slot -> bool, 190 | if defined, this function will be use to match slots when computing 191 | metrics, otherwise exact match will be used. 
192 | `expected_slot` corresponds to the slot as defined in the dataset, 193 | and `actual_slot` corresponds to the slot as returned by the NLU 194 | intents_filter (list of str, optional): if defined, at inference times 195 | test utterances will be restricted to the ones belonging to this 196 | filter. Moreover, if the parsing API allows it, the inference will 197 | be made using this intents filter. 198 | 199 | Returns 200 | dict: Metrics results containing the following data 201 | 202 | - "metrics": the computed metrics 203 | - "parsing_errors": the list of parsing errors 204 | - "confusion_matrix": the computed confusion matrix 205 | - "average_metrics": the metrics averaged over all intents 206 | """ 207 | 208 | if isinstance(train_dataset, (str, Path)): 209 | with Path(train_dataset).open(encoding="utf8") as f: 210 | train_dataset = json.load(f) 211 | 212 | if isinstance(test_dataset, (str, Path)): 213 | with Path(test_dataset).open(encoding="utf8") as f: 214 | test_dataset = json.load(f) 215 | 216 | intent_list = set(train_dataset["intents"]) 217 | intent_list.update(test_dataset["intents"]) 218 | intent_list = sorted(intent_list) 219 | 220 | logger.info("Training engine...") 221 | engine = engine_class() 222 | engine.fit(train_dataset) 223 | test_utterances = [ 224 | (intent_name, utterance) 225 | for intent_name, intent_data in test_dataset[INTENTS].items() 226 | for utterance in intent_data[UTTERANCES] 227 | if intents_filter is None or intent_name in intents_filter 228 | ] 229 | 230 | logger.info("Computing metrics...") 231 | metrics, errors, confusion_matrix = compute_engine_metrics( 232 | engine, 233 | test_utterances, 234 | intent_list, 235 | include_slot_metrics, 236 | slot_matching_lambda, 237 | intents_filter, 238 | ) 239 | metrics = compute_precision_recall_f1(metrics) 240 | average_metrics = compute_average_metrics(metrics) 241 | nb_utterances = { 242 | intent: len(data[UTTERANCES]) for intent, data in train_dataset[INTENTS].items() 243 | } 244 
| for intent, intent_metrics in metrics.items(): 245 | intent_metrics[INTENT_UTTERANCES] = nb_utterances.get(intent, 0) 246 | return { 247 | CONFUSION_MATRIX: confusion_matrix, 248 | AVERAGE_METRICS: average_metrics, 249 | METRICS: metrics, 250 | PARSING_ERRORS: errors, 251 | } 252 | -------------------------------------------------------------------------------- /snips_nlu_metrics/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snipsco/snips-nlu-metrics/34a6d0baea9f1adc739423249762be62ba4e27cc/snips_nlu_metrics/tests/__init__.py -------------------------------------------------------------------------------- /snips_nlu_metrics/tests/conftest.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import sys 4 | from pathlib import Path 5 | 6 | import pytest 7 | 8 | 9 | @pytest.fixture(scope="session") 10 | def resources_path(): 11 | return Path(__file__).resolve().parent / "resources" 12 | 13 | 14 | @pytest.fixture(scope="session") 15 | def beverage_dataset_path(resources_path): 16 | return resources_path / "beverage_dataset.json" 17 | 18 | 19 | @pytest.fixture(scope="session") 20 | def beverage_dataset(resources_path): 21 | with (resources_path / "beverage_dataset.json").open(encoding="utf8") as f: 22 | return json.load(f) 23 | 24 | 25 | @pytest.fixture(scope="session") 26 | def keyword_matching_dataset(resources_path): 27 | with (resources_path / "keyword_matching_dataset.json").open(encoding="utf8") as f: 28 | return json.load(f) 29 | 30 | 31 | @pytest.fixture(scope="module") 32 | def logger(): 33 | logger = logging.getLogger("snips_nlu_metrics") 34 | formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") 35 | handler = logging.StreamHandler(sys.stdout) 36 | handler.setLevel(logging.DEBUG) 37 | handler.setFormatter(formatter) 38 | logger.addHandler(handler) 39 | 
logger.setLevel(logging.DEBUG) 40 | return logger 41 | -------------------------------------------------------------------------------- /snips_nlu_metrics/tests/mock_engine.py: -------------------------------------------------------------------------------- 1 | from snips_nlu_metrics import Engine 2 | 3 | 4 | def dummy_parsing_result(text, intent_name=None): 5 | return { 6 | "input": text, 7 | "intent": {"intentName": intent_name, "probability": 0.5}, 8 | "slots": [], 9 | } 10 | 11 | 12 | class MockEngine(Engine): 13 | def __init__(self): 14 | self.fitted = False 15 | 16 | def fit(self, dataset): 17 | self.fitted = True 18 | 19 | def parse(self, text): 20 | return dummy_parsing_result(text) 21 | 22 | 23 | class KeyWordMatchingEngine(Engine): 24 | def __init__(self): 25 | self.fitted = False 26 | self.intents_list = [] 27 | 28 | def fit(self, dataset): 29 | self.fitted = True 30 | self.intents_list = sorted(dataset["intents"]) 31 | 32 | def parse(self, text, intents_filter=None): 33 | intent = None 34 | for intent_name in self.intents_list: 35 | if intent_name in text: 36 | intent = intent_name 37 | break 38 | if intents_filter is not None and intent not in intents_filter: 39 | intent = None 40 | return dummy_parsing_result(text, intent) 41 | 42 | 43 | class MockEngineSegfault(Engine): 44 | def __init__(self): 45 | self.fitted = False 46 | 47 | def fit(self, dataset): 48 | self.fitted = True 49 | 50 | def parse(self, text): 51 | # Simulate a segmentation fault 52 | exit(139) 53 | -------------------------------------------------------------------------------- /snips_nlu_metrics/tests/resources/beverage_dataset.json: -------------------------------------------------------------------------------- 1 | { 2 | "language": "en", 3 | "intents": { 4 | "MakeCoffee": { 5 | "utterances": [ 6 | { 7 | "data": [ 8 | { 9 | "text": "Make me " 10 | }, 11 | { 12 | "text": "five thousands and one", 13 | "slot_name": "number_of_cups", 14 | "entity": "snips/number" 15 | }, 16 | { 17 | 
"text": " cups of coffee" 18 | } 19 | ] 20 | }, 21 | { 22 | "data": [ 23 | { 24 | "text": "Brew me " 25 | }, 26 | { 27 | "text": "two hundreds", 28 | "slot_name": "number_of_cups", 29 | "entity": "snips/number" 30 | }, 31 | { 32 | "text": " coffee cups" 33 | } 34 | ] 35 | }, 36 | { 37 | "data": [ 38 | { 39 | "text": "make me " 40 | }, 41 | { 42 | "text": "twenty three", 43 | "slot_name": "number_of_cups", 44 | "entity": "snips/number" 45 | }, 46 | { 47 | "text": " cups of coffee" 48 | } 49 | ] 50 | }, 51 | { 52 | "data": [ 53 | { 54 | "text": "Brew " 55 | }, 56 | { 57 | "text": "four", 58 | "slot_name": "number_of_cups", 59 | "entity": "snips/number" 60 | }, 61 | { 62 | "text": " coffee cups" 63 | } 64 | ] 65 | }, 66 | { 67 | "data": [ 68 | { 69 | "text": "Make me " 70 | }, 71 | { 72 | "text": "12", 73 | "slot_name": "number_of_cups", 74 | "entity": "snips/number" 75 | }, 76 | { 77 | "text": " cups of coffee" 78 | } 79 | ] 80 | }, 81 | { 82 | "data": [ 83 | { 84 | "text": "give me " 85 | }, 86 | { 87 | "text": "six", 88 | "slot_name": "number_of_cups", 89 | "entity": "snips/number" 90 | }, 91 | { 92 | "text": " coffees please" 93 | } 94 | ] 95 | }, 96 | { 97 | "data": [ 98 | { 99 | "text": "Make me " 100 | }, 101 | { 102 | "text": "one", 103 | "slot_name": "number_of_cups", 104 | "entity": "snips/number" 105 | }, 106 | { 107 | "text": " cup of coffee" 108 | } 109 | ] 110 | } 111 | ] 112 | }, 113 | "MakeTea": { 114 | "utterances": [ 115 | { 116 | "data": [ 117 | { 118 | "text": "Make " 119 | }, 120 | { 121 | "text": "twenty two", 122 | "slot_name": "number_of_cups", 123 | "entity": "snips/number" 124 | }, 125 | { 126 | "text": " " 127 | }, 128 | { 129 | "text": "iced", 130 | "slot_name": "beverage_temperature", 131 | "entity": "Temperature" 132 | }, 133 | { 134 | "text": " teas" 135 | } 136 | ] 137 | }, 138 | { 139 | "data": [ 140 | { 141 | "text": "Please, can I get " 142 | }, 143 | { 144 | "text": "five", 145 | "slot_name": "number_of_cups", 146 | "entity": 
"snips/number" 147 | }, 148 | { 149 | "text": " " 150 | }, 151 | { 152 | "text": "cold", 153 | "slot_name": "beverage_temperature", 154 | "entity": "Temperature" 155 | }, 156 | { 157 | "text": " teas ?" 158 | } 159 | ] 160 | }, 161 | { 162 | "data": [ 163 | { 164 | "text": "Prepare " 165 | }, 166 | { 167 | "text": "three", 168 | "slot_name": "number_of_cups", 169 | "entity": "snips/number" 170 | }, 171 | { 172 | "text": " cups of " 173 | }, 174 | { 175 | "text": "hot", 176 | "slot_name": "beverage_temperature", 177 | "entity": "Temperature" 178 | }, 179 | { 180 | "text": " tea" 181 | } 182 | ] 183 | }, 184 | { 185 | "data": [ 186 | { 187 | "text": "Make me " 188 | }, 189 | { 190 | "text": "one", 191 | "slot_name": "number_of_cups", 192 | "entity": "snips/number" 193 | }, 194 | { 195 | "text": " tea" 196 | } 197 | ] 198 | } 199 | ] 200 | } 201 | }, 202 | "entities": { 203 | "snips/number": {}, 204 | "Temperature": { 205 | "use_synonyms": true, 206 | "automatically_extensible": true, 207 | "data": [ 208 | { 209 | "value": "cold", 210 | "synonyms": [ 211 | "cold", 212 | "iced" 213 | ] 214 | }, 215 | { 216 | "value": "hot", 217 | "synonyms": [ 218 | "hot", 219 | "boiling" 220 | ] 221 | } 222 | ] 223 | } 224 | } 225 | } -------------------------------------------------------------------------------- /snips_nlu_metrics/tests/resources/keyword_matching_dataset.json: -------------------------------------------------------------------------------- 1 | { 2 | "language": "en", 3 | "intents": { 4 | "intent1": { 5 | "utterances": [ 6 | { 7 | "data": [ 8 | { 9 | "text": "intent1 utterance" 10 | } 11 | ] 12 | }, 13 | { 14 | "data": [ 15 | { 16 | "text": "intent1 utterance" 17 | } 18 | ] 19 | }, 20 | { 21 | "data": [ 22 | { 23 | "text": "intent1 utterance" 24 | } 25 | ] 26 | }, 27 | { 28 | "data": [ 29 | { 30 | "text": "intent1 utterance" 31 | } 32 | ] 33 | } 34 | ] 35 | }, 36 | "intent2": { 37 | "utterances": [ 38 | { 39 | "data": [ 40 | { 41 | "text": "intent2 utterance" 42 | } 
from snips_nlu_metrics.utils.dataset_utils import (
    get_utterances_subset,
    update_entities_with_utterances,
)


def _utterance(intent, text):
    """Shorthand for an (intent, utterance) pair with a single text chunk."""
    return intent, {"data": [{"text": text}]}


def test_get_utterances_subset_should_work():
    # Given
    utterances = (
        [_utterance("intent1", "text%s" % i) for i in range(1, 5)]
        + [_utterance("intent2", "text%s" % i) for i in range(1, 3)]
        + [_utterance("intent3", "text%s" % i) for i in range(1, 7)]
    )

    # When
    utterances_subset = get_utterances_subset(utterances, ratio=0.5)
    # Sort on "<intent><text>" so the comparison below is order independent
    utterances_subset = sorted(
        utterances_subset, key=lambda u: u[0] + u[1]["data"][0]["text"]
    )

    # Then: half of each intent's utterances are kept
    expected_utterances = (
        [_utterance("intent1", "text%s" % i) for i in range(1, 3)]
        + [_utterance("intent2", "text1")]
        + [_utterance("intent3", "text%s" % i) for i in range(1, 4)]
    )
    assert expected_utterances == utterances_subset


def test_update_entities_with_utterances():
    # Given
    dataset = {
        "intents": {
            "intent_1": {
                "utterances": [
                    {
                        "data": [
                            {"text": "aa", "entity": "entity_2"},
                            {"text": "bb", "entity": "entity_2"},
                        ]
                    }
                ]
            },
            "intent_2": {
                "utterances": [{"data": [{"text": "cccc", "entity": "entity_1"}]}]
            },
        },
        "entities": {
            "entity_1": {"data": [], "use_synonyms": False},
            "entity_2": {
                "data": [{"value": "a", "synonyms": ["aa"]}],
                "use_synonyms": True,
            },
        },
    }

    # When
    updated_dataset = update_entities_with_utterances(dataset)

    # Then: entity values seen in the utterances are added to the entities
    expected_dataset = {
        "intents": {
            "intent_1": {
                "utterances": [
                    {
                        "data": [
                            {"text": "aa", "entity": "entity_2"},
                            {"text": "bb", "entity": "entity_2"},
                        ]
                    }
                ]
            },
            "intent_2": {
                "utterances": [{"data": [{"text": "cccc", "entity": "entity_1"}]}]
            },
        },
        "entities": {
            "entity_1": {
                "data": [{"value": "cccc", "synonyms": []}],
                "use_synonyms": False,
            },
            "entity_2": {
                "data": [
                    {"value": "a", "synonyms": ["aa"]},
                    {"value": "bb", "synonyms": []},
                ],
                "use_synonyms": True,
            },
        },
    }
    assert expected_dataset == updated_dataset
7 | "intents": { 8 | "intents_1": {"utterances": 5 * [{"data": [{"text": "foobar"}]}]}, 9 | "intents_2": {"utterances": 7 * [{"data": [{"text": "foobar"}]}]}, 10 | }, 11 | "entities": dict(), 12 | "language": "en", 13 | } 14 | error = NotEnoughDataError(dataset=dataset, nb_folds=4, train_size_ratio=0.5) 15 | 16 | # When 17 | repr_error = repr(error) 18 | error_message = error.message 19 | 20 | # Then 21 | assert ( 22 | "nb folds = 4, train size ratio = 0.5, " 23 | "intents details = [intents_1 -> 5 utterances, " 24 | "intents_2 -> 7 utterances]" == repr_error 25 | ) 26 | assert ( 27 | "Not enough data: nb folds = 4, " 28 | "train size ratio = 0.5, intents details = " 29 | "[intents_1 -> 5 utterances, " 30 | "intents_2 -> 7 utterances]" == error_message 31 | ) 32 | -------------------------------------------------------------------------------- /snips_nlu_metrics/tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from snips_nlu_metrics.metrics import ( 4 | compute_cross_val_metrics, 5 | compute_train_test_metrics, 6 | ) 7 | from snips_nlu_metrics.tests.mock_engine import ( 8 | MockEngine, 9 | MockEngineSegfault, 10 | KeyWordMatchingEngine, 11 | ) 12 | from snips_nlu_metrics.utils.constants import ( 13 | METRICS, 14 | PARSING_ERRORS, 15 | CONFUSION_MATRIX, 16 | AVERAGE_METRICS, 17 | ) 18 | 19 | 20 | def test_compute_cross_val_metrics(logger, beverage_dataset): 21 | # When/Then 22 | try: 23 | res = compute_cross_val_metrics( 24 | dataset=beverage_dataset, engine_class=MockEngine, nb_folds=2 25 | ) 26 | except Exception as e: 27 | raise AssertionError(e.args[0]) 28 | 29 | expected_metrics = { 30 | "null": { 31 | "intent": { 32 | "true_positive": 0, 33 | "false_positive": 11, 34 | "false_negative": 0, 35 | "precision": 0.0, 36 | "recall": 0.0, 37 | "f1": 0.0, 38 | }, 39 | "exact_parsings": 0, 40 | "slots": {}, 41 | "intent_utterances": 0, 42 | }, 43 | "MakeCoffee": { 44 | "intent": { 45 | 
"true_positive": 0, 46 | "false_positive": 0, 47 | "false_negative": 7, 48 | "precision": 0.0, 49 | "recall": 0.0, 50 | "f1": 0.0, 51 | }, 52 | "exact_parsings": 0, 53 | "slots": { 54 | "number_of_cups": { 55 | "true_positive": 0, 56 | "false_positive": 0, 57 | "false_negative": 0, 58 | "precision": 0.0, 59 | "recall": 0.0, 60 | "f1": 0.0, 61 | } 62 | }, 63 | "intent_utterances": 7, 64 | }, 65 | "MakeTea": { 66 | "intent": { 67 | "true_positive": 0, 68 | "false_positive": 0, 69 | "false_negative": 4, 70 | "precision": 0.0, 71 | "recall": 0.0, 72 | "f1": 0.0, 73 | }, 74 | "exact_parsings": 0, 75 | "slots": { 76 | "number_of_cups": { 77 | "true_positive": 0, 78 | "false_positive": 0, 79 | "false_negative": 0, 80 | "precision": 0.0, 81 | "recall": 0.0, 82 | "f1": 0.0, 83 | }, 84 | "beverage_temperature": { 85 | "true_positive": 0, 86 | "false_positive": 0, 87 | "false_negative": 0, 88 | "precision": 0.0, 89 | "recall": 0.0, 90 | "f1": 0.0, 91 | }, 92 | }, 93 | "intent_utterances": 4, 94 | }, 95 | } 96 | 97 | assert expected_metrics, res["metrics"] 98 | 99 | 100 | def test_compute_cross_val_metrics_with_intents_filter( 101 | logger, keyword_matching_dataset 102 | ): 103 | # When/Then 104 | res = compute_cross_val_metrics( 105 | dataset=keyword_matching_dataset, 106 | engine_class=KeyWordMatchingEngine, 107 | nb_folds=2, 108 | intents_filter=["intent2", "intent3"], 109 | include_slot_metrics=False, 110 | seed=42, 111 | ) 112 | 113 | expected_metrics = { 114 | "null": { 115 | "intent": { 116 | "true_positive": 0, 117 | "false_positive": 2, 118 | "false_negative": 0, 119 | "precision": 0.0, 120 | "recall": 0.0, 121 | "f1": 0.0, 122 | }, 123 | "exact_parsings": 0, 124 | "intent_utterances": 0, 125 | }, 126 | "intent2": { 127 | "intent": { 128 | "true_positive": 3, 129 | "false_positive": 0, 130 | "false_negative": 1, 131 | "precision": 1.0, 132 | "recall": 3.0 / 4.0, 133 | "f1": 0.8571428571428571, 134 | }, 135 | "exact_parsings": 3, 136 | "intent_utterances": 4, 137 | }, 
138 | "intent3": { 139 | "intent": { 140 | "true_positive": 2, 141 | "false_positive": 0, 142 | "false_negative": 1, 143 | "precision": 1.0, 144 | "recall": 2.0 / 3.0, 145 | "f1": 0.8, 146 | }, 147 | "exact_parsings": 2, 148 | "intent_utterances": 3, 149 | }, 150 | } 151 | 152 | assert expected_metrics, res["metrics"] 153 | 154 | 155 | def test_compute_cross_val_metrics_with_multiple_workers(logger, beverage_dataset): 156 | # When/Then 157 | expected_metrics = { 158 | "null": { 159 | "intent": { 160 | "true_positive": 0, 161 | "false_positive": 11, 162 | "false_negative": 0, 163 | "precision": 0.0, 164 | "recall": 0.0, 165 | "f1": 0.0, 166 | }, 167 | "exact_parsings": 0, 168 | "slots": {}, 169 | "intent_utterances": 0, 170 | }, 171 | "MakeCoffee": { 172 | "intent": { 173 | "true_positive": 0, 174 | "false_positive": 0, 175 | "false_negative": 7, 176 | "precision": 0.0, 177 | "recall": 0.0, 178 | "f1": 0.0, 179 | }, 180 | "exact_parsings": 0, 181 | "slots": { 182 | "number_of_cups": { 183 | "true_positive": 0, 184 | "false_positive": 0, 185 | "false_negative": 0, 186 | "precision": 0.0, 187 | "recall": 0.0, 188 | "f1": 0.0, 189 | } 190 | }, 191 | "intent_utterances": 7, 192 | }, 193 | "MakeTea": { 194 | "intent": { 195 | "true_positive": 0, 196 | "false_positive": 0, 197 | "false_negative": 4, 198 | "precision": 0.0, 199 | "recall": 0.0, 200 | "f1": 0.0, 201 | }, 202 | "exact_parsings": 0, 203 | "slots": { 204 | "number_of_cups": { 205 | "true_positive": 0, 206 | "false_positive": 0, 207 | "false_negative": 0, 208 | "precision": 0.0, 209 | "recall": 0.0, 210 | "f1": 0.0, 211 | }, 212 | "beverage_temperature": { 213 | "true_positive": 0, 214 | "false_positive": 0, 215 | "false_negative": 0, 216 | "precision": 0.0, 217 | "recall": 0.0, 218 | "f1": 0.0, 219 | }, 220 | }, 221 | "intent_utterances": 4, 222 | }, 223 | } 224 | try: 225 | res = compute_cross_val_metrics( 226 | dataset=beverage_dataset, engine_class=MockEngine, nb_folds=2, num_workers=4 227 | ) 228 | except 
Exception as e: 229 | raise AssertionError(e.args[0]) 230 | assert expected_metrics, res["metrics"] 231 | 232 | 233 | def test_should_raise_when_non_zero_exit(logger, beverage_dataset): 234 | # When/Then 235 | with pytest.raises(SystemExit): 236 | compute_cross_val_metrics( 237 | dataset=beverage_dataset, 238 | engine_class=MockEngineSegfault, 239 | nb_folds=4, 240 | num_workers=4, 241 | ) 242 | 243 | 244 | def test_compute_cross_val_metrics_without_slot_metrics(logger, beverage_dataset): 245 | # When/Then 246 | try: 247 | res = compute_cross_val_metrics( 248 | dataset=beverage_dataset, 249 | engine_class=MockEngine, 250 | nb_folds=2, 251 | include_slot_metrics=False, 252 | ) 253 | except Exception as e: 254 | raise AssertionError(e.args[0]) 255 | 256 | expected_metrics = { 257 | "null": { 258 | "intent": { 259 | "true_positive": 0, 260 | "false_positive": 11, 261 | "false_negative": 0, 262 | "precision": 0.0, 263 | "recall": 0.0, 264 | "f1": 0.0, 265 | }, 266 | "intent_utterances": 0, 267 | "exact_parsings": 0, 268 | }, 269 | "MakeCoffee": { 270 | "intent": { 271 | "true_positive": 0, 272 | "false_positive": 0, 273 | "false_negative": 7, 274 | "precision": 0.0, 275 | "recall": 0.0, 276 | "f1": 0.0, 277 | }, 278 | "intent_utterances": 7, 279 | "exact_parsings": 0, 280 | }, 281 | "MakeTea": { 282 | "intent": { 283 | "true_positive": 0, 284 | "false_positive": 0, 285 | "false_negative": 4, 286 | "precision": 0.0, 287 | "recall": 0.0, 288 | "f1": 0.0, 289 | }, 290 | "intent_utterances": 4, 291 | "exact_parsings": 0, 292 | }, 293 | } 294 | 295 | assert expected_metrics, res["metrics"] 296 | 297 | 298 | def test_cross_val_metrics_should_skip_when_not_enough_data( 299 | logger, beverage_dataset_path 300 | ): 301 | # When 302 | result = compute_cross_val_metrics( 303 | dataset=beverage_dataset_path, engine_class=MockEngine, nb_folds=11 304 | ) 305 | 306 | # Then 307 | expected_result = { 308 | AVERAGE_METRICS: None, 309 | CONFUSION_MATRIX: None, 310 | METRICS: None, 311 | 
PARSING_ERRORS: [], 312 | } 313 | assert expected_result, result 314 | 315 | 316 | def test_compute_train_test_metrics(logger, beverage_dataset): 317 | # When/Then 318 | try: 319 | res = compute_train_test_metrics( 320 | train_dataset=beverage_dataset, 321 | test_dataset=beverage_dataset, 322 | engine_class=MockEngine, 323 | ) 324 | except Exception as e: 325 | raise AssertionError(e.args[0]) 326 | 327 | expected_metrics = { 328 | "MakeCoffee": { 329 | "intent": { 330 | "true_positive": 0, 331 | "false_positive": 0, 332 | "false_negative": 7, 333 | "precision": 0.0, 334 | "recall": 0.0, 335 | "f1": 0.0, 336 | }, 337 | "slots": { 338 | "number_of_cups": { 339 | "true_positive": 0, 340 | "false_positive": 0, 341 | "false_negative": 0, 342 | "precision": 0.0, 343 | "recall": 0.0, 344 | "f1": 0.0, 345 | } 346 | }, 347 | "intent_utterances": 7, 348 | "exact_parsings": 0, 349 | }, 350 | "null": { 351 | "intent": { 352 | "true_positive": 0, 353 | "false_positive": 11, 354 | "false_negative": 0, 355 | "precision": 0.0, 356 | "recall": 0.0, 357 | "f1": 0.0, 358 | }, 359 | "slots": {}, 360 | "intent_utterances": 0, 361 | "exact_parsings": 0, 362 | }, 363 | "MakeTea": { 364 | "intent": { 365 | "true_positive": 0, 366 | "false_positive": 0, 367 | "false_negative": 4, 368 | "precision": 0.0, 369 | "recall": 0.0, 370 | "f1": 0.0, 371 | }, 372 | "slots": { 373 | "number_of_cups": { 374 | "true_positive": 0, 375 | "false_positive": 0, 376 | "false_negative": 0, 377 | "precision": 0.0, 378 | "recall": 0.0, 379 | "f1": 0.0, 380 | }, 381 | "beverage_temperature": { 382 | "true_positive": 0, 383 | "false_positive": 0, 384 | "false_negative": 0, 385 | "precision": 0.0, 386 | "recall": 0.0, 387 | "f1": 0.0, 388 | }, 389 | }, 390 | "intent_utterances": 4, 391 | "exact_parsings": 0, 392 | }, 393 | } 394 | 395 | assert expected_metrics, res["metrics"] 396 | 397 | 398 | def test_compute_train_test_metrics_without_slots_metrics(logger, beverage_dataset): 399 | # When/Then 400 | try: 401 | 
res = compute_train_test_metrics( 402 | train_dataset=beverage_dataset, 403 | test_dataset=beverage_dataset, 404 | engine_class=MockEngine, 405 | include_slot_metrics=False, 406 | ) 407 | except Exception as e: 408 | raise AssertionError(e.args[0]) 409 | 410 | expected_metrics = { 411 | "MakeCoffee": { 412 | "intent": { 413 | "true_positive": 0, 414 | "false_positive": 0, 415 | "false_negative": 7, 416 | "precision": 0.0, 417 | "recall": 0.0, 418 | "f1": 0.0, 419 | }, 420 | "intent_utterances": 7, 421 | "exact_parsings": 0, 422 | }, 423 | "null": { 424 | "intent": { 425 | "true_positive": 0, 426 | "false_positive": 11, 427 | "false_negative": 0, 428 | "precision": 0.0, 429 | "recall": 0.0, 430 | "f1": 0.0, 431 | }, 432 | "intent_utterances": 0, 433 | "exact_parsings": 0, 434 | }, 435 | "MakeTea": { 436 | "intent": { 437 | "true_positive": 0, 438 | "false_positive": 0, 439 | "false_negative": 4, 440 | "precision": 0.0, 441 | "recall": 0.0, 442 | "f1": 0.0, 443 | }, 444 | "intent_utterances": 4, 445 | "exact_parsings": 0, 446 | }, 447 | } 448 | 449 | assert expected_metrics, res["metrics"] 450 | -------------------------------------------------------------------------------- /snips_nlu_metrics/tests/test_metrics_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from snips_nlu_metrics.utils.constants import ( 4 | TRUE_POSITIVE, 5 | FALSE_POSITIVE, 6 | FALSE_NEGATIVE, 7 | TEXT, 8 | ) 9 | from snips_nlu_metrics.utils.exception import NotEnoughDataError 10 | from snips_nlu_metrics.utils.metrics_utils import ( 11 | aggregate_metrics, 12 | compute_utterance_metrics, 13 | compute_precision_recall_f1, 14 | exact_match, 15 | contains_errors, 16 | compute_engine_metrics, 17 | aggregate_matrices, 18 | create_shuffle_stratified_splits, 19 | ) 20 | 21 | 22 | def test_should_compute_engine_metrics(): 23 | # Given 24 | def create_utterance(intent_name, slot_name, slot_value): 25 | utterance = { 26 | "data": [ 27 | 
{"text": "this is an utterance with ",}, 28 | {"text": slot_value, "slot_name": slot_name, "entity": slot_name}, 29 | ] 30 | } 31 | return intent_name, utterance 32 | 33 | def create_parsing_output(intent_name, slot_name, slot_value): 34 | return { 35 | "text": "this is an utterance with %s" % slot_value, 36 | "intent": {"intentName": intent_name, "probability": 1.0}, 37 | "slots": [ 38 | { 39 | "rawValue": slot_value, 40 | "range": {"start": 26, "end": 26 + len(slot_value)}, 41 | "entity": slot_name, 42 | "slotName": slot_name, 43 | } 44 | ], 45 | } 46 | 47 | utterances = [ 48 | create_utterance("intent1", "slot1", "value1"), 49 | create_utterance("intent1", "slot1", "value2"), 50 | create_utterance("intent1", "slot2", "value3"), 51 | create_utterance("intent2", "slot3", "value4"), 52 | create_utterance("intent2", "slot3", "value5"), 53 | ] 54 | 55 | class TestEngine: 56 | def __init__(self): 57 | self.utterance_index = 0 58 | 59 | def parse(self, text): 60 | res = None 61 | if self.utterance_index == 0: 62 | res = create_parsing_output("intent1", "slot1", "value1") 63 | if self.utterance_index == 1: 64 | res = create_parsing_output("intent2", "slot3", "value4") 65 | if self.utterance_index == 2: 66 | res = create_parsing_output("intent1", "slot1", "value1") 67 | if self.utterance_index == 3: 68 | res = create_parsing_output("intent2", "slot3", "value4") 69 | if self.utterance_index == 4: 70 | res = create_parsing_output("intent2", "slot3", "value4") 71 | self.utterance_index += 1 72 | return res 73 | 74 | engine = TestEngine() 75 | 76 | def slots_match(lhs, rhs): 77 | return lhs[TEXT] == rhs["rawValue"] 78 | 79 | # When 80 | metrics, errors, confusion_matrix = compute_engine_metrics( 81 | engine=engine, 82 | test_utterances=utterances, 83 | intent_list=["intent1", "intent2"], 84 | include_slot_metrics=True, 85 | slot_matching_lambda=slots_match, 86 | intents_filter=None, 87 | ) 88 | 89 | # Then 90 | expected_metrics = { 91 | "intent1": { 92 | "exact_parsings": 1, 
93 | "slots": { 94 | "slot1": {"false_positive": 1, "true_positive": 1, "false_negative": 0}, 95 | "slot2": {"false_positive": 0, "true_positive": 0, "false_negative": 1}, 96 | }, 97 | "intent": {"false_positive": 0, "true_positive": 2, "false_negative": 1}, 98 | }, 99 | "intent2": { 100 | "exact_parsings": 1, 101 | "slots": { 102 | "slot3": {"false_positive": 1, "true_positive": 1, "false_negative": 1} 103 | }, 104 | "intent": {"false_positive": 1, "true_positive": 2, "false_negative": 0}, 105 | }, 106 | } 107 | expected_errors = [ 108 | { 109 | "expected_output": { 110 | "input": "this is an utterance with value2", 111 | "slots": [ 112 | { 113 | "range": {"start": 26, "end": 32}, 114 | "slotName": "slot1", 115 | "rawValue": "value2", 116 | "entity": "slot1", 117 | } 118 | ], 119 | "intent": {"intentName": "intent1", "probability": 1.0}, 120 | }, 121 | "nlu_output": { 122 | "text": "this is an utterance with value4", 123 | "slots": [ 124 | { 125 | "slotName": "slot3", 126 | "range": {"start": 26, "end": 32}, 127 | "rawValue": "value4", 128 | "entity": "slot3", 129 | } 130 | ], 131 | "intent": {"intentName": "intent2", "probability": 1.0}, 132 | }, 133 | }, 134 | { 135 | "expected_output": { 136 | "input": "this is an utterance with value3", 137 | "slots": [ 138 | { 139 | "range": {"start": 26, "end": 32}, 140 | "slotName": "slot2", 141 | "rawValue": "value3", 142 | "entity": "slot2", 143 | } 144 | ], 145 | "intent": {"intentName": "intent1", "probability": 1.0}, 146 | }, 147 | "nlu_output": { 148 | "text": "this is an utterance with value1", 149 | "slots": [ 150 | { 151 | "slotName": "slot1", 152 | "range": {"start": 26, "end": 32}, 153 | "rawValue": "value1", 154 | "entity": "slot1", 155 | } 156 | ], 157 | "intent": {"intentName": "intent1", "probability": 1.0}, 158 | }, 159 | }, 160 | { 161 | "expected_output": { 162 | "input": "this is an utterance with value5", 163 | "slots": [ 164 | { 165 | "range": {"start": 26, "end": 32}, 166 | "slotName": "slot3", 167 | 
"rawValue": "value5", 168 | "entity": "slot3", 169 | } 170 | ], 171 | "intent": {"intentName": "intent2", "probability": 1.0}, 172 | }, 173 | "nlu_output": { 174 | "text": "this is an utterance with value4", 175 | "slots": [ 176 | { 177 | "slotName": "slot3", 178 | "range": {"start": 26, "end": 32}, 179 | "rawValue": "value4", 180 | "entity": "slot3", 181 | } 182 | ], 183 | "intent": {"intentName": "intent2", "probability": 1.0}, 184 | }, 185 | }, 186 | ] 187 | 188 | expected_confusion_matrix = { 189 | "intents": ["intent1", "intent2", "null"], 190 | "matrix": [[2, 1, 0], [0, 2, 0], [0, 0, 0],], 191 | } 192 | 193 | assert expected_metrics == metrics 194 | assert expected_errors == errors 195 | assert expected_confusion_matrix == confusion_matrix 196 | 197 | 198 | def test_should_compute_engine_metrics_with_intents_filter(): 199 | # Given 200 | def create_utterance(intent_name, text): 201 | return intent_name, {"data": [{"text": text}]} 202 | 203 | def create_parsing_output(intent_name, text): 204 | return { 205 | "text": text, 206 | "intent": {"intentName": intent_name, "probability": 1.0}, 207 | "slots": [], 208 | } 209 | 210 | utterances = [ 211 | create_utterance("intent1", "first utterance intent1"), 212 | create_utterance("intent1", "second utterance intent1"), 213 | create_utterance("intent1", "third utterance intent1"), 214 | create_utterance("intent1", "ambiguous utterance intent1 and intent3"), 215 | create_utterance("intent2", "first utterance intent2"), 216 | create_utterance("intent2", "second utterance intent2"), 217 | create_utterance("intent2", "ambiguous utterance intent2 and intent3"), 218 | ] 219 | 220 | class EngineWithFilterAPI: 221 | def parse(self, text, intents_filter=None): 222 | intent = None 223 | for intent_name in ["intent3", "intent1", "intent2"]: 224 | if intent_name in text: 225 | intent = intent_name 226 | break 227 | 228 | if intents_filter is not None and intent not in intents_filter: 229 | intent = None 230 | return 
create_parsing_output(intent, text) 231 | 232 | class EngineWithFilterProp: 233 | def __init__(self): 234 | self.intents_filter = ["intent1", "intent2"] 235 | 236 | def parse(self, text): 237 | intent = None 238 | for intent_name in ["intent3", "intent1", "intent2"]: 239 | if intent_name in text: 240 | intent = intent_name 241 | break 242 | 243 | if self.intents_filter is not None and intent not in self.intents_filter: 244 | intent = None 245 | return create_parsing_output(intent, text) 246 | 247 | engine_with_filter_api = EngineWithFilterAPI() 248 | engine_with_filter_prop = EngineWithFilterProp() 249 | 250 | # When 251 | metrics1, _, _ = compute_engine_metrics( 252 | engine=engine_with_filter_api, 253 | test_utterances=utterances, 254 | intent_list=["intent1", "intent2", "intent3"], 255 | include_slot_metrics=False, 256 | intents_filter=["intent1", "intent2"], 257 | ) 258 | metrics2, _, _ = compute_engine_metrics( 259 | engine=engine_with_filter_prop, 260 | test_utterances=utterances, 261 | intent_list=["intent1", "intent2", "intent3"], 262 | include_slot_metrics=False, 263 | intents_filter=["intent1", "intent2"], 264 | ) 265 | 266 | # Then 267 | expected_metrics = { 268 | "intent1": { 269 | "exact_parsings": 3, 270 | "intent": {"false_positive": 0, "true_positive": 3, "false_negative": 1}, 271 | }, 272 | "intent2": { 273 | "exact_parsings": 2, 274 | "intent": {"false_positive": 0, "true_positive": 2, "false_negative": 1,}, 275 | }, 276 | "null": { 277 | "exact_parsings": 0, 278 | "intent": {"false_positive": 2, "true_positive": 0, "false_negative": 0,}, 279 | }, 280 | } 281 | 282 | assert expected_metrics == metrics1 283 | assert expected_metrics == metrics2 284 | 285 | 286 | def test_should_compute_utterance_metrics_when_wrong_intent(): 287 | # Given 288 | actual_intent = "intent1" 289 | actual_slots = [] 290 | predicted_intent = "intent2" 291 | predicted_slots = [ 292 | { 293 | "rawValue": "utterance", 294 | "value": {"kind": "Custom", "value": "utterance"}, 
295 | "range": {"start": 0, "end": 9}, 296 | "entity": "erroneous_entity", 297 | "slotName": "erroneous_slot", 298 | } 299 | ] 300 | 301 | # When 302 | metrics = compute_utterance_metrics( 303 | predicted_intent, 304 | predicted_slots, 305 | actual_intent, 306 | actual_slots, 307 | True, 308 | exact_match, 309 | ) 310 | # Then 311 | expected_metrics = { 312 | "intent1": { 313 | "intent": {"false_negative": 1, "false_positive": 0, "true_positive": 0}, 314 | "slots": {}, 315 | }, 316 | "intent2": { 317 | "intent": {"false_negative": 0, "false_positive": 1, "true_positive": 0}, 318 | "slots": { 319 | "erroneous_slot": { 320 | "false_negative": 0, 321 | "false_positive": 0, 322 | "true_positive": 0, 323 | } 324 | }, 325 | }, 326 | } 327 | 328 | assert expected_metrics == metrics 329 | 330 | 331 | def test_should_compute_utterance_metrics_when_correct_intent(): 332 | # Given 333 | actual_intent = "intent1" 334 | actual_slots = [ 335 | {"text": "slot1_value", "entity": "entity1", "slot_name": "slot1"}, 336 | {"text": "slot2_value", "entity": "entity2", "slot_name": "slot2"}, 337 | ] 338 | predicted_intent = actual_intent 339 | predicted_slots = [ 340 | { 341 | "rawValue": "slot1_value", 342 | "value": {"kind": "Custom", "value": "slot1_value"}, 343 | "range": {"start": 21, "end": 32}, 344 | "entity": "entity1", 345 | "slotName": "slot1", 346 | } 347 | ] 348 | 349 | # When 350 | metrics = compute_utterance_metrics( 351 | predicted_intent, 352 | predicted_slots, 353 | actual_intent, 354 | actual_slots, 355 | True, 356 | exact_match, 357 | ) 358 | # Then 359 | expected_metrics = { 360 | "intent1": { 361 | "intent": {"false_negative": 0, "false_positive": 0, "true_positive": 1}, 362 | "slots": { 363 | "slot1": {"false_negative": 0, "false_positive": 0, "true_positive": 1}, 364 | "slot2": {"false_negative": 1, "false_positive": 0, "true_positive": 0}, 365 | }, 366 | } 367 | } 368 | 369 | assert expected_metrics == metrics 370 | 371 | 372 | def 
test_should_exclude_slot_metrics_when_specified(): 373 | # Given 374 | actual_intent = "intent1" 375 | actual_slots = [ 376 | {"text": "slot1_value", "entity": "entity1", "slot_name": "slot1"}, 377 | {"text": "slot2_value", "entity": "entity2", "slot_name": "slot2"}, 378 | ] 379 | predicted_intent = actual_intent 380 | predicted_slots = [ 381 | { 382 | "rawValue": "slot1_value", 383 | "value": {"kind": "Custom", "value": "slot1_value"}, 384 | "range": {"start": 21, "end": 32}, 385 | "entity": "entity1", 386 | "slotName": "slot1", 387 | } 388 | ] 389 | 390 | # When 391 | include_slot_metrics = False 392 | metrics = compute_utterance_metrics( 393 | predicted_intent, 394 | predicted_slots, 395 | actual_intent, 396 | actual_slots, 397 | include_slot_metrics, 398 | exact_match, 399 | ) 400 | # Then 401 | expected_metrics = { 402 | "intent1": { 403 | "intent": {"false_negative": 0, "false_positive": 0, "true_positive": 1} 404 | } 405 | } 406 | 407 | assert expected_metrics == metrics 408 | 409 | 410 | def test_should_use_slot_matching_lambda_to_compute_metrics(): 411 | # Given 412 | actual_intent = "intent1" 413 | actual_slots = [ 414 | {"text": "slot1_value2", "entity": "entity1", "slot_name": "slot1"}, 415 | {"text": "slot2_value", "entity": "entity2", "slot_name": "slot2"}, 416 | ] 417 | predicted_intent = actual_intent 418 | predicted_slots = [ 419 | { 420 | "rawValue": "slot1_value", 421 | "value": {"kind": "Custom", "value": "slot1_value"}, 422 | "range": {"start": 21, "end": 32}, 423 | "entity": "entity1", 424 | "slotName": "slot1", 425 | } 426 | ] 427 | 428 | def slot_matching_lambda(l, r): 429 | return l[TEXT].split("_")[0] == r["rawValue"].split("_")[0] 430 | 431 | # When 432 | metrics = compute_utterance_metrics( 433 | predicted_intent, 434 | predicted_slots, 435 | actual_intent, 436 | actual_slots, 437 | True, 438 | slot_matching_lambda, 439 | ) 440 | # Then 441 | expected_metrics = { 442 | "intent1": { 443 | "intent": {"false_negative": 0, "false_positive": 
0, "true_positive": 1}, 444 | "slots": { 445 | "slot1": {"false_negative": 0, "false_positive": 0, "true_positive": 1}, 446 | "slot2": {"false_negative": 1, "false_positive": 0, "true_positive": 0}, 447 | }, 448 | } 449 | } 450 | 451 | assert expected_metrics == metrics 452 | 453 | 454 | def test_aggregate_utils_should_work(): 455 | # Given 456 | lhs_metrics = { 457 | "intent1": { 458 | "exact_parsings": 2, 459 | "intent": {"false_positive": 4, "true_positive": 6, "false_negative": 9}, 460 | "slots": { 461 | "slot1": {"false_positive": 1, "true_positive": 2, "false_negative": 3}, 462 | }, 463 | }, 464 | "intent2": { 465 | "exact_parsings": 1, 466 | "intent": {"false_positive": 3, "true_positive": 2, "false_negative": 5}, 467 | "slots": { 468 | "slot2": {"false_positive": 4, "true_positive": 2, "false_negative": 2}, 469 | }, 470 | }, 471 | } 472 | 473 | rhs_metrics = { 474 | "intent1": { 475 | "exact_parsings": 3, 476 | "intent": {"false_positive": 3, "true_positive": 3, "false_negative": 3}, 477 | "slots": { 478 | "slot1": {"false_positive": 2, "true_positive": 3, "false_negative": 1}, 479 | }, 480 | }, 481 | "intent2": { 482 | "exact_parsings": 5, 483 | "intent": {"false_positive": 4, "true_positive": 5, "false_negative": 6}, 484 | "slots": {}, 485 | }, 486 | "intent3": { 487 | "exact_parsings": 0, 488 | "intent": {"false_positive": 1, "true_positive": 7, "false_negative": 2}, 489 | "slots": {}, 490 | }, 491 | } 492 | 493 | # When 494 | aggregated_metrics = aggregate_metrics(lhs_metrics, rhs_metrics, True) 495 | 496 | # Then 497 | expected_metrics = { 498 | "intent1": { 499 | "exact_parsings": 5, 500 | "intent": {"false_positive": 7, "true_positive": 9, "false_negative": 12,}, 501 | "slots": { 502 | "slot1": {"false_positive": 3, "true_positive": 5, "false_negative": 4}, 503 | }, 504 | }, 505 | "intent2": { 506 | "exact_parsings": 6, 507 | "intent": {"false_positive": 7, "true_positive": 7, "false_negative": 11,}, 508 | "slots": { 509 | "slot2": {"false_positive": 
4, "true_positive": 2, "false_negative": 2}, 510 | }, 511 | }, 512 | "intent3": { 513 | "exact_parsings": 0, 514 | "intent": {"false_positive": 1, "true_positive": 7, "false_negative": 2}, 515 | "slots": {}, 516 | }, 517 | } 518 | 519 | assert expected_metrics == aggregated_metrics 520 | 521 | 522 | def test_should_compute_precision_and_recall_and_f1(): 523 | # Given 524 | metrics = { 525 | "intent1": { 526 | "intent": {"false_positive": 7, "true_positive": 9, "false_negative": 12,}, 527 | "slots": { 528 | "slot1": {"false_positive": 3, "true_positive": 5, "false_negative": 4}, 529 | }, 530 | }, 531 | "intent2": { 532 | "intent": {"false_positive": 7, "true_positive": 7, "false_negative": 11,}, 533 | "slots": { 534 | "slot2": {"false_positive": 4, "true_positive": 2, "false_negative": 2}, 535 | }, 536 | }, 537 | } 538 | 539 | # When 540 | augmented_metrics = compute_precision_recall_f1(metrics) 541 | 542 | # Then 543 | expected_metrics = { 544 | "intent1": { 545 | "intent": { 546 | "false_positive": 7, 547 | "true_positive": 9, 548 | "false_negative": 12, 549 | "precision": 9.0 / (7.0 + 9.0), 550 | "recall": 9.0 / (12.0 + 9.0), 551 | "f1": 2 552 | * (9.0 / (7.0 + 9.0)) 553 | * (9.0 / (12.0 + 9.0)) 554 | / (9.0 / (7.0 + 9.0) + 9.0 / (12.0 + 9.0)), 555 | }, 556 | "slots": { 557 | "slot1": { 558 | "false_positive": 3, 559 | "true_positive": 5, 560 | "false_negative": 4, 561 | "precision": 5.0 / (5.0 + 3.0), 562 | "recall": 5.0 / (5.0 + 4.0), 563 | "f1": 2 564 | * (5.0 / (5.0 + 3.0)) 565 | * (5.0 / (5.0 + 4.0)) 566 | / (5.0 / (5.0 + 3.0) + 5.0 / (5.0 + 4.0)), 567 | }, 568 | }, 569 | }, 570 | "intent2": { 571 | "intent": { 572 | "false_positive": 7, 573 | "true_positive": 7, 574 | "false_negative": 11, 575 | "precision": 7.0 / (7.0 + 7.0), 576 | "recall": 7.0 / (7.0 + 11.0), 577 | "f1": 2 578 | * (7.0 / (7.0 + 7.0)) 579 | * (7.0 / (7.0 + 11.0)) 580 | / (7.0 / (7.0 + 7.0) + 7.0 / (7.0 + 11.0)), 581 | }, 582 | "slots": { 583 | "slot2": { 584 | "false_positive": 4, 585 | 
"true_positive": 2, 586 | "false_negative": 2, 587 | "precision": 2.0 / (2.0 + 4.0), 588 | "recall": 2.0 / (2.0 + 2.0), 589 | "f1": 2 590 | * (2.0 / (2.0 + 4.0)) 591 | * (2.0 / (2.0 + 2.0)) 592 | / (2.0 / (2.0 + 4.0) + 2.0 / (2.0 + 2.0)), 593 | }, 594 | }, 595 | }, 596 | } 597 | assert expected_metrics == augmented_metrics 598 | 599 | 600 | def test_contains_errors_should_work_when_errors_in_intent(): 601 | # Given 602 | utterance_metrics = { 603 | "intent1": { 604 | "intent": {TRUE_POSITIVE: 5, FALSE_POSITIVE: 1, FALSE_NEGATIVE: 0}, 605 | "slots": { 606 | "slot1": {TRUE_POSITIVE: 3, FALSE_POSITIVE: 0, FALSE_NEGATIVE: 0} 607 | }, 608 | }, 609 | "intent2": { 610 | "intent": {TRUE_POSITIVE: 20, FALSE_POSITIVE: 0, FALSE_NEGATIVE: 0}, 611 | "slots": {}, 612 | }, 613 | } 614 | 615 | # When 616 | res = contains_errors(utterance_metrics, True) 617 | 618 | # Then 619 | assert res 620 | 621 | 622 | def test_contains_errors_should_work_when_errors_in_slots(): 623 | # Given 624 | utterance_metrics = { 625 | "intent1": { 626 | "intent": {TRUE_POSITIVE: 5, FALSE_POSITIVE: 0, FALSE_NEGATIVE: 0}, 627 | "slots": { 628 | "slot1": {TRUE_POSITIVE: 3, FALSE_POSITIVE: 0, FALSE_NEGATIVE: 0}, 629 | "slot2": {TRUE_POSITIVE: 3, FALSE_POSITIVE: 0, FALSE_NEGATIVE: 2}, 630 | }, 631 | }, 632 | "intent2": { 633 | "intent": {TRUE_POSITIVE: 20, FALSE_POSITIVE: 0, FALSE_NEGATIVE: 0}, 634 | "slots": {}, 635 | }, 636 | } 637 | 638 | # When 639 | res = contains_errors(utterance_metrics, True) 640 | 641 | # Then 642 | assert res 643 | 644 | 645 | def test_contains_errors_should_work_when_no_errors(): 646 | # Given 647 | utterance_metrics = { 648 | "intent1": { 649 | "intent": {TRUE_POSITIVE: 5, FALSE_POSITIVE: 0, FALSE_NEGATIVE: 0}, 650 | "slots": { 651 | "slot1": {TRUE_POSITIVE: 3, FALSE_POSITIVE: 0, FALSE_NEGATIVE: 0}, 652 | "slot2": {TRUE_POSITIVE: 3, FALSE_POSITIVE: 0, FALSE_NEGATIVE: 0}, 653 | }, 654 | }, 655 | "intent2": { 656 | "intent": {TRUE_POSITIVE: 20, FALSE_POSITIVE: 0, FALSE_NEGATIVE: 
0}, 657 | "slots": {}, 658 | }, 659 | } 660 | 661 | # When 662 | res = contains_errors(utterance_metrics, True) 663 | 664 | # Then 665 | assert not res 666 | 667 | 668 | def test_contains_errors_should_not_check_slots_when_specified(): 669 | # Given 670 | utterance_metrics = { 671 | "intent1": { 672 | "intent": {TRUE_POSITIVE: 5, FALSE_POSITIVE: 0, FALSE_NEGATIVE: 0}, 673 | "slots": { 674 | "slot1": {TRUE_POSITIVE: 3, FALSE_POSITIVE: 0, FALSE_NEGATIVE: 0}, 675 | "slot2": {TRUE_POSITIVE: 3, FALSE_POSITIVE: 0, FALSE_NEGATIVE: 2}, 676 | }, 677 | } 678 | } 679 | 680 | # When 681 | check_slots = False 682 | res = contains_errors(utterance_metrics, check_slots) 683 | 684 | # Then 685 | assert not res 686 | 687 | 688 | def test_aggregate_matrix(): 689 | # Given 690 | lhs_confusion_matrix = { 691 | "intents": ["intent1", "intent2", "intent3"], 692 | "matrix": [[1, 10, 5], [3, 0, 4], [7, 8, 1]], 693 | } 694 | 695 | rhs_confusion_matrix = { 696 | "intents": ["intent1", "intent2", "intent3"], 697 | "matrix": [[3, 3, 3], [2, 7, 1], [9, 0, 1]], 698 | } 699 | 700 | # When 701 | acc_matrix = aggregate_matrices(lhs_confusion_matrix, rhs_confusion_matrix) 702 | 703 | # Then 704 | expected_confusion_matrix = { 705 | "intents": ["intent1", "intent2", "intent3"], 706 | "matrix": [[4, 13, 8], [5, 7, 5], [16, 8, 2]], 707 | } 708 | 709 | assert expected_confusion_matrix == acc_matrix 710 | 711 | 712 | def test_should_create_splits_when_enough_data(): 713 | # Given 714 | dataset = { 715 | "intents": { 716 | "intents_1": {"utterances": 10 * [{"data": [{"text": "foobar"}]}]}, 717 | "intents_2": {"utterances": 12 * [{"data": [{"text": "foobar"}]}]}, 718 | }, 719 | "entities": dict(), 720 | "language": "en", 721 | } 722 | 723 | # When / Then 724 | create_shuffle_stratified_splits(dataset=dataset, n_splits=5, train_size_ratio=0.5) 725 | 726 | 727 | def test_should_not_create_splits_when_not_enough_data(): 728 | # Given 729 | dataset = { 730 | "intents": { 731 | "intents_1": {"utterances": 10 * 
[{"data": [{"text": "foobar"}]}]}, 732 | "intents_2": {"utterances": 12 * [{"data": [{"text": "foobar"}]}]}, 733 | }, 734 | "entities": dict(), 735 | "language": "en", 736 | } 737 | 738 | # When / Then 739 | with pytest.raises(NotEnoughDataError): 740 | create_shuffle_stratified_splits( 741 | dataset=dataset, n_splits=6, train_size_ratio=0.5 742 | ) 743 | -------------------------------------------------------------------------------- /snips_nlu_metrics/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snipsco/snips-nlu-metrics/34a6d0baea9f1adc739423249762be62ba4e27cc/snips_nlu_metrics/utils/__init__.py -------------------------------------------------------------------------------- /snips_nlu_metrics/utils/constants.py: -------------------------------------------------------------------------------- 1 | INTENT = "intent" 2 | INTENTS = "intents" 3 | ENTITIES = "entities" 4 | UTTERANCES = "utterances" 5 | USE_SYNONYMS = "use_synonyms" 6 | SYNONYMS = "synonyms" 7 | DATA = "data" 8 | VALUE = "value" 9 | TEXT = "text" 10 | ENTITY = "entity" 11 | SLOT_NAME = "slot_name" 12 | TRUE_POSITIVE = "true_positive" 13 | FALSE_POSITIVE = "false_positive" 14 | FALSE_NEGATIVE = "false_negative" 15 | INTENT_UTTERANCES = "intent_utterances" 16 | PARSING_ERRORS = "parsing_errors" 17 | METRICS = "metrics" 18 | AVERAGE_METRICS = "average_metrics" 19 | CONFUSION_MATRIX = "confusion_matrix" 20 | NONE_INTENT_NAME = "null" 21 | EXACT_PARSINGS = "exact_parsings" 22 | -------------------------------------------------------------------------------- /snips_nlu_metrics/utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from copy import deepcopy 3 | 4 | from snips_nlu_metrics.utils.constants import ( 5 | DATA, 6 | ENTITIES, 7 | ENTITY, 8 | INTENTS, 9 | SYNONYMS, 10 | TEXT, 11 | USE_SYNONYMS, 12 | UTTERANCES, 13 | 
VALUE, 14 | ) 15 | 16 | 17 | def input_string_from_chunks(chunks): 18 | return "".join(chunk[TEXT] for chunk in chunks) 19 | 20 | 21 | def get_utterances_subset(utterances, ratio): 22 | utterances_dict = dict() 23 | for (intent_name, utterance) in utterances: 24 | if intent_name not in utterances_dict: 25 | utterances_dict[intent_name] = [] 26 | utterances_dict[intent_name].append(deepcopy(utterance)) 27 | 28 | utterances_subset = [] 29 | for (intent_name, utterances) in utterances_dict.items(): 30 | nb_utterances = int(ratio * len(utterances)) 31 | utterances_subset += [(intent_name, u) for u in utterances[:nb_utterances]] 32 | return utterances_subset 33 | 34 | 35 | def is_builtin_entity(entity_name): 36 | return entity_name.startswith("snips/") 37 | 38 | 39 | def get_declared_entities_values(dataset): 40 | existing_entities = dict() 41 | for entity_name, entity in dataset[ENTITIES].items(): 42 | if is_builtin_entity(entity_name): 43 | continue 44 | ents = set() 45 | for data in entity[DATA]: 46 | ents.add(data[VALUE]) 47 | if entity[USE_SYNONYMS]: 48 | for s in data[SYNONYMS]: 49 | ents.add(s) 50 | existing_entities[entity_name] = ents 51 | return existing_entities 52 | 53 | 54 | def get_intent_utterances_entities_value(dataset): 55 | existing_entities = defaultdict(set) 56 | for intent in dataset[INTENTS].values(): 57 | for u in intent[UTTERANCES]: 58 | for chunk in u[DATA]: 59 | if ENTITY not in chunk or is_builtin_entity(chunk[ENTITY]): 60 | continue 61 | existing_entities[chunk[ENTITY]].add(chunk[TEXT]) 62 | return existing_entities 63 | 64 | 65 | def make_entity(value, synonyms): 66 | return {"value": value, "synonyms": synonyms} 67 | 68 | 69 | def update_entities_with_utterances(dataset): 70 | dataset = deepcopy(dataset) 71 | 72 | declared_entities = get_declared_entities_values(dataset) 73 | intent_entities = get_intent_utterances_entities_value(dataset) 74 | 75 | for entity_name, existing_entities in declared_entities.items(): 76 | for entity_value in 
intent_entities.get(entity_name, []): 77 | if entity_value not in existing_entities: 78 | dataset[ENTITIES][entity_name][DATA].append( 79 | make_entity(entity_value, []) 80 | ) 81 | 82 | return dataset 83 | -------------------------------------------------------------------------------- /snips_nlu_metrics/utils/exception.py: -------------------------------------------------------------------------------- 1 | from snips_nlu_metrics.utils.constants import UTTERANCES, INTENTS 2 | 3 | 4 | class NotEnoughDataError(Exception): 5 | def __init__(self, dataset, nb_folds, train_size_ratio): 6 | self.dataset = dataset 7 | self.nb_folds = nb_folds 8 | self.train_size_ratio = train_size_ratio 9 | self.intent_utterances = { 10 | intent: len(data[UTTERANCES]) for intent, data in dataset[INTENTS].items() 11 | } 12 | 13 | @property 14 | def message(self): 15 | return "Not enough data: %r" % self 16 | 17 | def __repr__(self): 18 | return ", ".join( 19 | [ 20 | "nb folds = %s" % self.nb_folds, 21 | "train size ratio = %s" % self.train_size_ratio, 22 | "intents details = [%s]" 23 | % ", ".join( 24 | "%s -> %d utterances" % (intent, nb) 25 | for intent, nb in sorted(self.intent_utterances.items()) 26 | ), 27 | ] 28 | ) 29 | 30 | def __str__(self): 31 | return repr(self) 32 | -------------------------------------------------------------------------------- /snips_nlu_metrics/utils/metrics_utils.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import logging 3 | import sys 4 | from copy import deepcopy 5 | 6 | import numpy as np 7 | from sklearn.model_selection import StratifiedKFold 8 | from sklearn.utils import check_random_state 9 | 10 | from snips_nlu_metrics.utils.constants import ( 11 | DATA, 12 | ENTITIES, 13 | ENTITY, 14 | FALSE_NEGATIVE, 15 | FALSE_POSITIVE, 16 | INTENTS, 17 | NONE_INTENT_NAME, 18 | SLOT_NAME, 19 | TEXT, 20 | TRUE_POSITIVE, 21 | UTTERANCES, 22 | EXACT_PARSINGS, 23 | ) 24 | from 
snips_nlu_metrics.utils.dataset_utils import ( 25 | get_utterances_subset, 26 | input_string_from_chunks, 27 | update_entities_with_utterances, 28 | ) 29 | from snips_nlu_metrics.utils.exception import NotEnoughDataError 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | INITIAL_METRICS = {TRUE_POSITIVE: 0, FALSE_POSITIVE: 0, FALSE_NEGATIVE: 0} 34 | 35 | 36 | def create_shuffle_stratified_splits( 37 | dataset, 38 | n_splits, 39 | train_size_ratio=1.0, 40 | drop_entities=False, 41 | seed=None, 42 | out_of_domain_utterances=None, 43 | intents_filter=None, 44 | ): 45 | if train_size_ratio > 1.0 or train_size_ratio < 0: 46 | error_msg = "Invalid value for train size ratio: %s" % train_size_ratio 47 | logger.exception(error_msg) 48 | raise ValueError(error_msg) 49 | 50 | nb_utterances = { 51 | intent: len(data[UTTERANCES]) for intent, data in dataset[INTENTS].items() 52 | } 53 | if any((nb * train_size_ratio < n_splits for nb in nb_utterances.values())): 54 | raise NotEnoughDataError(dataset, n_splits, train_size_ratio) 55 | 56 | if drop_entities: 57 | dataset = deepcopy(dataset) 58 | for entity, data in dataset[ENTITIES].items(): 59 | data[DATA] = [] 60 | else: 61 | dataset = update_entities_with_utterances(dataset) 62 | 63 | utterances = np.array( 64 | [ 65 | (intent_name, utterance) 66 | for intent_name, intent_data in dataset[INTENTS].items() 67 | for utterance in intent_data[UTTERANCES] 68 | ] 69 | ) 70 | intents = np.array([u[0] for u in utterances]) 71 | X = np.zeros(len(intents)) 72 | random_state = check_random_state(seed) 73 | sss = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state) 74 | splits = [] 75 | for train_index, test_index in sss.split(X, intents): 76 | train_utterances = utterances[train_index].tolist() 77 | train_utterances = get_utterances_subset(train_utterances, train_size_ratio) 78 | test_utterances = utterances[test_index].tolist() 79 | train_dataset = deepcopy(dataset) 80 | train_dataset[INTENTS] = dict() 81 | 
for intent_name, utterance in train_utterances: 82 | if intent_name not in train_dataset[INTENTS]: 83 | train_dataset[INTENTS][intent_name] = {UTTERANCES: []} 84 | train_dataset[INTENTS][intent_name][UTTERANCES].append(deepcopy(utterance)) 85 | splits.append((train_dataset, test_utterances)) 86 | 87 | if intents_filter is not None: 88 | filtered_splits = [] 89 | for train_dataset, test_utterances in splits: 90 | test_utterances = [ 91 | (intent_name, utterance) 92 | for intent_name, utterance in test_utterances 93 | if intent_name in intents_filter 94 | ] 95 | filtered_splits.append((train_dataset, test_utterances)) 96 | splits = filtered_splits 97 | 98 | if out_of_domain_utterances is not None: 99 | additional_test_utterances = [ 100 | [NONE_INTENT_NAME, {DATA: [{TEXT: utterance}]}] 101 | for utterance in out_of_domain_utterances 102 | ] 103 | for split in splits: 104 | split[1].extend(additional_test_utterances) 105 | 106 | return splits 107 | 108 | 109 | def compute_split_metrics( 110 | engine_class, 111 | split, 112 | intent_list, 113 | include_slot_metrics, 114 | slot_matching_lambda, 115 | intents_filter, 116 | ): 117 | """Fit and run engine on a split specified by train_dataset and 118 | test_utterances""" 119 | train_dataset, test_utterances = split 120 | engine = engine_class() 121 | engine.fit(train_dataset) 122 | return compute_engine_metrics( 123 | engine, 124 | test_utterances, 125 | intent_list, 126 | include_slot_metrics, 127 | slot_matching_lambda, 128 | intents_filter, 129 | ) 130 | 131 | 132 | def compute_engine_metrics( 133 | engine, 134 | test_utterances, 135 | intent_list, 136 | include_slot_metrics, 137 | slot_matching_lambda=None, 138 | intents_filter=None, 139 | ): 140 | if slot_matching_lambda is None: 141 | slot_matching_lambda = exact_match 142 | metrics = dict() 143 | intent_list = intent_list + [NONE_INTENT_NAME] 144 | confusion_matrix = dict( 145 | intents=intent_list, 146 | matrix=[[0 for _ in range(len(intent_list))] for _ in 
range(len(intent_list))], 147 | ) 148 | intents_idx = {intent_name: idx for idx, intent_name in enumerate(intent_list)} 149 | 150 | errors = [] 151 | for actual_intent, utterance in test_utterances: 152 | actual_slots = [chunk for chunk in utterance[DATA] if SLOT_NAME in chunk] 153 | input_string = input_string_from_chunks(utterance[DATA]) 154 | if has_filter_param(engine): 155 | parsing = engine.parse(input_string, intents_filter=intents_filter) 156 | else: 157 | if intents_filter: 158 | logger.warning( 159 | "The provided NLU engine (%r) does not support " 160 | "intents filter through its `parse` API, " 161 | "however one has been passed (%s)", 162 | engine, 163 | intents_filter, 164 | ) 165 | parsing = engine.parse(input_string) 166 | 167 | if parsing["intent"] is not None: 168 | predicted_intent = parsing["intent"]["intentName"] 169 | if predicted_intent is None: 170 | predicted_intent = NONE_INTENT_NAME 171 | else: 172 | # Use a string here to avoid having a None key in the metrics dict 173 | predicted_intent = NONE_INTENT_NAME 174 | 175 | predicted_slots = [] if parsing["slots"] is None else parsing["slots"] 176 | 177 | i = intents_idx.get(actual_intent) 178 | j = intents_idx.get(predicted_intent) 179 | 180 | if i is None or j is None: 181 | continue 182 | 183 | confusion_matrix["matrix"][i][j] += 1 184 | 185 | utterance_metrics = compute_utterance_metrics( 186 | predicted_intent, 187 | predicted_slots, 188 | actual_intent, 189 | actual_slots, 190 | include_slot_metrics, 191 | slot_matching_lambda, 192 | ) 193 | for intent in utterance_metrics: 194 | utterance_metrics[intent][EXACT_PARSINGS] = 0 195 | if contains_errors(utterance_metrics, include_slot_metrics): 196 | if not include_slot_metrics: 197 | parsing.pop("slots") 198 | errors.append( 199 | { 200 | "nlu_output": parsing, 201 | "expected_output": format_expected_output( 202 | actual_intent, utterance, include_slot_metrics 203 | ), 204 | } 205 | ) 206 | else: 207 | 
utterance_metrics[actual_intent][EXACT_PARSINGS] = 1 208 | metrics = aggregate_metrics(metrics, utterance_metrics, include_slot_metrics) 209 | return metrics, errors, confusion_matrix 210 | 211 | 212 | def has_filter_param(engine): 213 | if sys.version_info[0] == 2: 214 | parse_args = inspect.getargspec(engine.parse).args 215 | else: 216 | parse_args = inspect.signature(engine.parse).parameters 217 | return "intents_filter" in parse_args 218 | 219 | 220 | def compute_utterance_metrics( 221 | predicted_intent, 222 | predicted_slots, 223 | actual_intent, 224 | actual_slots, 225 | include_slot_metrics, 226 | slot_matching_lambda, 227 | ): 228 | # initialize metrics 229 | intent_names = {predicted_intent, actual_intent} 230 | slot_names = set( 231 | [(predicted_intent, s["slotName"]) for s in predicted_slots] 232 | + [(actual_intent, u[SLOT_NAME]) for u in actual_slots] 233 | ) 234 | 235 | metrics = dict() 236 | for intent in intent_names: 237 | metrics[intent] = {"intent": deepcopy(INITIAL_METRICS)} 238 | if include_slot_metrics: 239 | metrics[intent]["slots"] = dict() 240 | 241 | if include_slot_metrics: 242 | for (intent_name, slot_name) in slot_names: 243 | metrics[intent_name]["slots"][slot_name] = deepcopy(INITIAL_METRICS) 244 | 245 | if predicted_intent == actual_intent: 246 | metrics[predicted_intent]["intent"][TRUE_POSITIVE] += 1 247 | else: 248 | metrics[predicted_intent]["intent"][FALSE_POSITIVE] += 1 249 | metrics[actual_intent]["intent"][FALSE_NEGATIVE] += 1 250 | return metrics 251 | 252 | if not include_slot_metrics: 253 | return metrics 254 | 255 | # Check if expected slots have been parsed 256 | for slot in actual_slots: 257 | slot_name = slot[SLOT_NAME] 258 | slot_metrics = metrics[actual_intent]["slots"][slot_name] 259 | if any( 260 | s["slotName"] == slot_name and slot_matching_lambda(slot, s) 261 | for s in predicted_slots 262 | ): 263 | slot_metrics[TRUE_POSITIVE] += 1 264 | else: 265 | slot_metrics[FALSE_NEGATIVE] += 1 266 | 267 | # Check if 
there are unexpected parsed slots 268 | for slot in predicted_slots: 269 | slot_name = slot["slotName"] 270 | slot_metrics = metrics[predicted_intent]["slots"][slot_name] 271 | if all( 272 | s[SLOT_NAME] != slot_name or not slot_matching_lambda(s, slot) 273 | for s in actual_slots 274 | ): 275 | slot_metrics[FALSE_POSITIVE] += 1 276 | return metrics 277 | 278 | 279 | def aggregate_metrics(lhs_metrics, rhs_metrics, include_slot_metrics): 280 | acc_metrics = deepcopy(lhs_metrics) 281 | for (intent, intent_metrics) in rhs_metrics.items(): 282 | if intent not in acc_metrics: 283 | acc_metrics[intent] = deepcopy(intent_metrics) 284 | else: 285 | acc_metrics[intent]["intent"] = add_count_metrics( 286 | acc_metrics[intent]["intent"], intent_metrics["intent"] 287 | ) 288 | acc_metrics[intent][EXACT_PARSINGS] += intent_metrics[EXACT_PARSINGS] 289 | if not include_slot_metrics: 290 | continue 291 | acc_slot_metrics = acc_metrics[intent]["slots"] 292 | for (slot, slot_metrics) in intent_metrics["slots"].items(): 293 | if slot not in acc_slot_metrics: 294 | acc_slot_metrics[slot] = deepcopy(slot_metrics) 295 | else: 296 | acc_slot_metrics[slot] = add_count_metrics( 297 | acc_slot_metrics[slot], slot_metrics 298 | ) 299 | return acc_metrics 300 | 301 | 302 | def aggregate_matrices(lhs_matrix, rhs_matrix): 303 | if lhs_matrix is None: 304 | return rhs_matrix 305 | if rhs_matrix is None: 306 | return lhs_matrix 307 | acc_matrix = deepcopy(lhs_matrix) 308 | matrix_size = len(acc_matrix["matrix"]) 309 | for i in range(matrix_size): 310 | for j in range(matrix_size): 311 | acc_matrix["matrix"][i][j] += rhs_matrix["matrix"][i][j] 312 | return acc_matrix 313 | 314 | 315 | def add_count_metrics(lhs, rhs): 316 | return { 317 | TRUE_POSITIVE: lhs[TRUE_POSITIVE] + rhs[TRUE_POSITIVE], 318 | FALSE_POSITIVE: lhs[FALSE_POSITIVE] + rhs[FALSE_POSITIVE], 319 | FALSE_NEGATIVE: lhs[FALSE_NEGATIVE] + rhs[FALSE_NEGATIVE], 320 | } 321 | 322 | 323 | def compute_average_metrics(metrics, 
ignore_none_intent=True): 324 | metrics = deepcopy(metrics) 325 | if ignore_none_intent: 326 | metrics = { 327 | intent: intent_metrics 328 | for intent, intent_metrics in metrics.items() 329 | if intent and intent != NONE_INTENT_NAME 330 | } 331 | 332 | nb_intents = len(metrics) 333 | if not nb_intents: 334 | return None 335 | 336 | average_intent_f1 = ( 337 | sum( 338 | intent_metrics["intent"]["f1"] for intent, intent_metrics in metrics.items() 339 | ) 340 | / nb_intents 341 | ) 342 | average_intent_precision = ( 343 | sum( 344 | intent_metrics["intent"]["precision"] 345 | for intent, intent_metrics in metrics.items() 346 | ) 347 | / nb_intents 348 | ) 349 | average_intent_recall = ( 350 | sum( 351 | intent_metrics["intent"]["recall"] 352 | for intent, intent_metrics in metrics.items() 353 | ) 354 | / nb_intents 355 | ) 356 | 357 | average_metrics = { 358 | "intent": { 359 | "f1": average_intent_f1, 360 | "precision": average_intent_precision, 361 | "recall": average_intent_recall, 362 | }, 363 | } 364 | 365 | nb_slots = sum( 366 | 1 367 | for intent_metrics in metrics.values() 368 | for _ in intent_metrics.get("slots", dict()).values() 369 | ) 370 | if nb_slots == 0: 371 | return average_metrics 372 | 373 | average_slot_f1 = ( 374 | sum( 375 | slot_metrics["f1"] 376 | for intent_metrics in metrics.values() 377 | for slot_metrics in intent_metrics["slots"].values() 378 | ) 379 | / nb_slots 380 | ) 381 | average_slot_precision = ( 382 | sum( 383 | slot_metrics["precision"] 384 | for intent_metrics in metrics.values() 385 | for slot_metrics in intent_metrics["slots"].values() 386 | ) 387 | / nb_slots 388 | ) 389 | average_slot_recall = ( 390 | sum( 391 | slot_metrics["recall"] 392 | for intent_metrics in metrics.values() 393 | for slot_metrics in intent_metrics["slots"].values() 394 | ) 395 | / nb_slots 396 | ) 397 | 398 | average_metrics["slot"] = { 399 | "f1": average_slot_f1, 400 | "precision": average_slot_precision, 401 | "recall": average_slot_recall, 402 | 
} 403 | return average_metrics 404 | 405 | 406 | def compute_precision_recall_f1(metrics): 407 | for intent_metrics in metrics.values(): 408 | prec_rec_metrics = _compute_precision_recall_f1(intent_metrics["intent"]) 409 | intent_metrics["intent"].update(prec_rec_metrics) 410 | if "slots" in intent_metrics: 411 | for slot_metrics in intent_metrics["slots"].values(): 412 | prec_rec_metrics = _compute_precision_recall_f1(slot_metrics) 413 | slot_metrics.update(prec_rec_metrics) 414 | return metrics 415 | 416 | 417 | def _compute_precision_recall_f1(count_metrics): 418 | tp = count_metrics[TRUE_POSITIVE] 419 | fp = count_metrics[FALSE_POSITIVE] 420 | fn = count_metrics[FALSE_NEGATIVE] 421 | precision = 0.0 if tp == 0 else float(tp) / float(tp + fp) 422 | recall = 0.0 if tp == 0 else float(tp) / float(tp + fn) 423 | if precision == 0.0 or recall == 0.0: 424 | f1 = 0.0 425 | else: 426 | f1 = 2 * (precision * recall) / (precision + recall) 427 | return {"precision": precision, "recall": recall, "f1": f1} 428 | 429 | 430 | def contains_errors(utterance_metrics, check_slots): 431 | for metrics in utterance_metrics.values(): 432 | intent_metrics = metrics["intent"] 433 | if intent_metrics.get(FALSE_POSITIVE, 0) > 0: 434 | return True 435 | if intent_metrics.get(FALSE_NEGATIVE, 0) > 0: 436 | return True 437 | if not check_slots: 438 | continue 439 | for slot_metrics in metrics["slots"].values(): 440 | if slot_metrics.get(FALSE_POSITIVE, 0) > 0: 441 | return True 442 | if slot_metrics.get(FALSE_NEGATIVE, 0) > 0: 443 | return True 444 | return False 445 | 446 | 447 | def format_expected_output(intent_name, utterance, include_slots): 448 | char_index = 0 449 | ranges = [] 450 | for chunk in utterance[DATA]: 451 | range_end = char_index + len(chunk[TEXT]) 452 | ranges.append({"start": char_index, "end": range_end}) 453 | char_index = range_end 454 | 455 | expected_output = { 456 | "input": "".join(chunk[TEXT] for chunk in utterance[DATA]), 457 | "intent": {"intentName": 
intent_name, "probability": 1.0}, 458 | } 459 | if include_slots: 460 | expected_output["slots"] = [ 461 | { 462 | "rawValue": chunk[TEXT], 463 | "entity": chunk[ENTITY], 464 | "slotName": chunk[SLOT_NAME], 465 | "range": ranges[chunk_index], 466 | } 467 | for chunk_index, chunk in enumerate(utterance[DATA]) 468 | if ENTITY in chunk 469 | ] 470 | return expected_output 471 | 472 | 473 | def exact_match(lhs_slot, rhs_slot): 474 | return lhs_slot[TEXT] == rhs_slot["rawValue"] 475 | -------------------------------------------------------------------------------- /snips_nlu_metrics/utils/temp_utils.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from tempfile import mkdtemp 3 | 4 | 5 | class tempdir_ctx: 6 | def __init__(self, suffix="", prefix="tmp", dir=None): 7 | self.suffix = suffix 8 | self.prefix = prefix 9 | self.dir = dir 10 | 11 | def __enter__(self): 12 | self.engine_dir = mkdtemp(suffix=self.suffix, prefix=self.prefix, dir=self.dir) 13 | return self.engine_dir 14 | 15 | def __exit__(self, exc_type, exc_val, exc_tb): 16 | shutil.rmtree(self.engine_dir) 17 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py35, py36, py37, py38 3 | 4 | [testenv] 5 | commands = 6 | pip install --upgrade pip 7 | pip install -e '.[test]' 8 | pytest --cache-clear --durations=0 9 | --------------------------------------------------------------------------------