├── .dockerignore ├── .gitignore ├── .travis.yml ├── Dockerfile ├── LICENSE ├── MANIFEST.in ├── README.md ├── SECURITY.md ├── doc ├── dev.md └── formats.md ├── lmchallenge ├── __init__.py ├── __main__.py ├── core │ ├── __init__.py │ ├── common.py │ ├── errors.py │ ├── model.py │ ├── reranking.py │ └── tests │ │ ├── __init__.py │ │ ├── test_common.py │ │ ├── test_errors.py │ │ └── test_performance.py ├── data │ ├── viewer.css │ ├── viewer.html │ └── viewer.js ├── diff.py ├── grep.py ├── log.schema ├── pretty.py ├── run.py ├── stats.py ├── tests │ ├── __init__.py │ ├── conftest.py │ ├── eg_models.py │ ├── test_functional.py │ ├── test_grep.py │ ├── test_pretty.py │ └── test_validate.py └── validate.py ├── requirements-base.txt ├── requirements-dev.txt ├── requirements.txt ├── sample ├── .gitignore ├── ngram.py ├── prepare.sh └── prepare_big.sh ├── scripts ├── Dockerfile.base ├── Dockerfile.dev ├── Dockerfile.notebook └── run ├── setup.cfg ├── setup.py └── version.txt /.dockerignore: -------------------------------------------------------------------------------- 1 | # Exclude by default 2 | 3 | * 4 | 5 | # Include specific files & folders 6 | 7 | !lmchallenge/* 8 | !README.md 9 | !requirements* 10 | !setup.py 11 | !setup.cfg 12 | !version.txt 13 | 14 | # Re-exclude some files 15 | 16 | *.pyc 17 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # pytest cache 104 | .pytest_cache/ 105 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: sh 2 | sudo: required 3 | services: [docker] 4 | branches: {only: [master]} 5 | 6 | before_install: 7 | - ./scripts/run build 8 | 9 | script: 10 | - ./scripts/run check 11 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.5 2 | 3 | ENV LC_ALL=C.UTF-8 \ 4 | LANG=C.UTF-8 5 | 6 | ADD . /tmp/lmc 7 | RUN cd /tmp/lmc/ \ 8 | && python3 setup.py install \ 9 | && rm -r /tmp/lmc 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | LMChallenge 2 | 3 | MIT License 4 | 5 | Copyright (c) Microsoft Corporation. All rights reserved. 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE 24 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include version.txt README.md requirements.txt LICENSE 2 | recursive-include lmchallenge *.css *.html *.js *.schema 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Language Model Challenge (LM Challenge) 2 | 3 | _A library & tools to evaluate predictive language models._ This is a guide for users of LM Challenge; you may also want to see: 4 | 5 | - [data formats](doc/formats.md) for integrators 6 | - [dev notes](doc/dev.md) for developers wishing to extend LM Challenge 7 | 8 | 9 | ## What is LM Challenge for? 10 | 11 | It is hard to compare language model performance in general. Some models output probabilities, others scores; some model words, others morphemes, characters or bytes. Vocabulary coverage varies. Comparing them in a fair way is therefore difficult. So in LM Challenge we have some very simple 'challenge games' that evaluate (and help compare) language models over a test corpus. 12 | 13 | LM Challenge is for researchers and engineers who wish to set a standard for fair comparison of very different language model architectures. It requires a little work to wrap your model in a standard API, but we believe this is often better than writing & testing evaluation tools afresh for each project/investigation. 14 | 15 | Note: most of LM Challenge tools are word-based (although all can be applied to sub-word "character compositional" word models). Additionally, our assumption is that the language model is "forward contextual" - so it predicts a word or character based only on preceding words/characters. 16 | 17 | 18 | ## Getting Started 19 | 20 | Install LM Challenge from the published Python package: 21 | 22 | pip3 install --user lmchallenge 23 | 24 | (Or from this repository `python3 setup.py install --user`.) 25 | 26 | **Setup:** LM Challenge needs a model to evaluate. We include an example ngram model implementation in `sample`. Download data & build models (this may take a couple of minutes): 27 | 28 | cd sample/ 29 | ./prepare.sh 30 | 31 | **Model REPL:** Now you can use the example script to evaluate a very basic ngram model (see [ngram.py](sample/ngram.py), which you may find useful if integrating your own prediction model). _Note that this command will not terminate, as it launches an interactive program:_ 32 | 33 | python3 ngram.py words data/words.3gram 34 | 35 | This starts an interactive program which can accept commands of a single word followed by a hard `TAB` character and any arguments, for example: 36 | 37 | > predict 38 | = 0.0000 The -1.0000 In -2.0000... 39 | 40 | This produces start-of-line predictions, each with an attached score. To query with word context, try the following (making sure you leave a trailing space at the end of the query, after "favourite"): 41 | 42 | > predictMy favourite 43 | of 0.0000 song -1.0000 the -2.0000... 44 | 45 | This provides next-word-prediction based on a context. 
There is more to the API (see [formats](doc/formats.md) for more details), but since you won't usually be using the API directly, let's move on to running LM Challenge over this model (so exit the predictor using `Ctrl+D`, back to your shell). 46 | 47 | **Evaluation:** To run LM Challenge for this model, we'll pipe some text into `lmc run`, and save the result: 48 | 49 | mkdir out 50 | head -n 10 data/wiki.test.tokens | lmc run "python3 ngram.py words data/words.3gram" wc > out/w3.wc.log 51 | 52 | The resulting log contains all of the original text, and can be queried using the `lmc` utilities. Note: `jq` here is optional, but a very convenient program for working with JSON. 53 | 54 | lmc stats out/w3.wc.log | jq . 55 | 56 | You should see some statistics about the model - in particular `completion` & `prediction`. Now let's try comparing with a less powerful model: 57 | 58 | head -n 10 data/wiki.test.tokens | lmc run "python3 ngram.py words data/words.2gram" wc > out/w2.wc.log 59 | lmc stats out/*.wc.log | jq . 60 | 61 | The aggregated level prediction and completion stats should be slightly different for the two models. But we can get a better picture from inspecting the logs in detail: 62 | 63 | lmc pretty out/w3.wc.log 64 | 65 | This shows a pretty-printed dump of the data, according to how well the model performed on each token. We can also pretty-print the difference between two models: 66 | 67 | lmc diff out/w3.wc.log out/w2.wc.log 68 | 69 | Filter the log for only capitalized words, and print summary statistics: 70 | 71 | lmc grep "^[A-Z][a-z]+$" out/w3.wc.log | lmc stats | jq . 72 | 73 | You should notice that capitalized words are (in this small, statistically insignificant example), much harder to predict than words in general. 74 | 75 | **Other challenges:** Other LM challenges can be run & inspected in a similar way, see `lmc run --help`. 76 | 77 | 78 | ## Running LM Challenge 79 | 80 | LM Challenge is quite flexible - it can be used in a variety of ways: 81 | 82 | 1. Command Line Interface 83 | 2. Python API 84 | 3. Log file format 85 | 86 | ### 1. Command Line Interface 87 | 88 | This is the simplest way of using LM Challenge, and works if your model is implemented in any language supporting piped stdout/stdin. See the [Getting Started](#getting-started) guide above, and the CLI help: 89 | 90 | lmc --help 91 | lmc run --help 92 | 93 | ### 2. Python API 94 | 95 | If your model runs in Python 3, and you wish to script evaluation in Python, you can use the API directly: 96 | 97 | import lmchallenge as lmc 98 | help(lmc) 99 | 100 | Our documentation (as in `help(lmc)`) includes a tutorial for getting started with Python. We don't yet publish the HTML, but it has been tested with `pdoc`: 101 | 102 | $ pdoc --http 103 | # use your browser to view generated documentation 104 | 105 | ### 3. Log file format 106 | 107 | If you require batching or distribution for sufficient evaluation speed, you can write the LM Challenge log files yourself. This means you can use LM Challenge to process & analyse the data, without imposing a particular execution model. To do this: 108 | 109 | 1. Write JSONlines files that contain lmchallenge log data: 110 | - See [data formats](doc/formats.md) notes that describe the log format. 111 | - (Optionally) use the [JSON schema](lmchallenge/log.schema) that formally describes an acceptable log datum. 112 | - (Optionally) use the CLI `lmc validate` (or Python API `lmchallenge.validate.validate`) to check that your log conforms to the schema. 
113 | - Note that log files can often be concatenated if they were generated in parallel. 114 | 2. Use the lmchallenge tools to analyse the logs (everything except `lmc run`). 115 | 116 | 117 | ## The details 118 | 119 | An _LM challenge game_ is a runnable Python module that evaluates one or more _language models_ on some task, over some _test text_. 120 | 121 | The **challenge games** we have are: 122 | 123 | - `wc` - Word Completion Challenge - a Next Word Prediction / Completion task (generates Hit@N & completion ratios) 124 | - `we|ce` - Word|Character Entropy Challenges - a language probability distribution task (generates cross entropy given a defined vocabulary) 125 | - `wr` - Word Reranking Challenge - a correction task (generates accuracy) 126 | 127 | **Test text** is pure text data (as typed & understood by real actual humans!) LM Challenge does not define test text - we expect it to be provided. This is the other thing you need to decide on in order to evaluate a _language model_. 128 | 129 | A **language model** is an executable process that responds to commands from a _LM challenge game_ in a specific text format, usually comprising of a pre-trained model of the same language as the _test text_. 130 | 131 | ### Word Completion `wc` 132 | 133 | The Word Completion task scans through words in the test text, at each point querying the language model for next-word predictions & word completions. 134 | 135 | cat DATA | lmc run "PREDICTOR" wc > LOG 136 | 137 | The model should aim to predict the correct next word before other words (i.e. with as low a rank as possible), or failing that to predict it in the top two completions, given as short a typed prefix as possible. Statistics available from `wc` include: 138 | 139 | - next-word-prediction 140 | - `Hit@N` - ratio of correct predictions obtained with rank below `N` 141 | - `MRR` (Mean Reciprocal Rank) - the sum total of `1/rank` over all words 142 | - completion 143 | - `characters` - ratio of characters that were completed (e.g. if typing `"hello"`, and it is predicted after you type `"he"`, the ratio of completed characters would be `0.5`) 144 | - `tokens` - ratio of tokens that were completed before they were fully typed 145 | 146 | Note that the flag `--next-word-only` may be used to speed up evaluation, by skipping all prefixes, only evaluating the model's next-word-prediction performance (so that completion stats are not generated). 147 | 148 | ### Word/Character Entropy `we|ce` 149 | 150 | The Word/Character Entropy task produces stats that are analogous to the standard cross-entropy/perplexity measures used for evaluating language models. These evaluators scan through text, at each point querying the language model for a normalized log-probability for the current word. 151 | 152 | cat DATA | lmc run "PREDICTOR" we > LOG 153 | cat DATA | lmc run "PREDICTOR" ce > LOG 154 | 155 | It is important to note that the entropy metric can only be compared between models that share a common vocabulary. If the vocabulary is different, the entropy task is different, and models should not be compared. Therefore, a model must generate a "fair" normalized log-probability over its vocabulary (and if a word is not in the vocabulary, to omit the score from the results). It should not merge "equivalence classes" of words (except by general agreement with every other model being evaluated). 
An example of this would be example normalizing capitalization to give "fish" the same score as "Fish", or giving many words an "out of vocabulary" score (such that, if you were to calculate `p("fish") + p("Fish") + p(everything else)` it would not sum to one). Simply ommiting any words that cannot be scored (e.g. OOV words) is safe, as this contributes to a special "entropy fingerprint", which checks that two models successfully scored the same set of words, and are therefore comparable under the entropy metric. 156 | 157 | ### Word Reranking `wr` 158 | 159 | The Word Reranking task emulates a sloppy typist entering text, using the language model to correct input after it has been typed. This challenge requires a list of words to use as correction candidates for corrupted words (which should be a large set of valid words in the target language.) Text from the data source is first corrupted (as if by a sloppy typist). The corrupted text is fed into a search for nearby candidate words, which are scored according to the language model under evaluation. The evaluator measures corrected, un-corrected and mis-corrected results. 160 | 161 | cat DATA | lmc run "PREDICTOR" wr VOCAB > LOG 162 | 163 | The aim of the model is to assign high score to the correct word, and low score to all other words. We evaluate this by mixing the score from the language model with an _input score_ for each word (using a minimum score for words that are not scored by the lanugage model), then ranking based on that. If the top-ranked prediction is the correct word, this example was a success, otherwise it counts as a failure. The _input score_ is the log-probability of the particular corrupted text being produced from this word, in the same error model that was used to corrupt the true word. In order to be robust against different ranges of scores from language models, we optimize the _input_ and _language model_ mixing parameters before counting statistics (this is done automatically, but requires the optional dependency `scipy`). The accuracy aggregate measures the maximum proportion of correct top predictions, using the optimum mixing proportions. 164 | 165 | 166 | ## Contributing 167 | 168 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 169 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 170 | the rights to use your contribution. For details, visit https://cla.microsoft.com. 171 | 172 | When you submit a pull request, a CLA-bot will automatically determine whether you need to provide 173 | a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions 174 | provided by the bot. You will only need to do this once across all repos using our CLA. 175 | 176 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 177 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 178 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 
179 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /doc/dev.md: -------------------------------------------------------------------------------- 1 | # Developing 2 | 3 | ## Contributing 4 | 5 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 6 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 7 | the rights to use your contribution. For details, visit https://cla.microsoft.com. 
8 | 9 | When you submit a pull request, a CLA-bot will automatically determine whether you need to provide 10 | a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions 11 | provided by the bot. You will only need to do this once across all repos using our CLA. 12 | 13 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 14 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 15 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 16 | 17 | ## Developing & testing 18 | 19 | We recommend developing using Docker, although a virtualenv is also a good option (these tools help ensure a stable, reproducible environment). The Python script `./scripts/run` is our main development tool, which automates build & checking tasks. 20 | 21 | **Tests** and general checks are currently run with: 22 | 23 | ./scripts/run build 24 | ./scripts/run check 25 | 26 | For quicker tests, while developing, try `./scripts/run test`. 27 | 28 | **Documentation** may be built in the `/site` folder using: 29 | 30 | ./scripts/run doc 31 | 32 | ## Publishing 33 | 34 | ### Configure gpg for signing the release 35 | 36 | - Install [gpg](https://gnupg.org) (e.g. `brew install gpg` on macOS) 37 | - Obtain the private key for the release and its password (ask around) 38 | - Import the key into your keyring: 39 | ```bash 40 | gpg --allow-secret-key-import --import private.key 41 | ``` 42 | - Ensure the key has been imported: 43 | ```bash 44 | gpg --list-keys 45 | ``` 46 | Expected output: 47 | ``` 48 | pub rsa2048 2017-10-06 [SC] 49 | EA7AC0CCA097C391C7AA61F109F7AFCBCB48AC15 50 | uid [ unknown] SwiftKey DL&NLP 51 | sub rsa2048 2017-10-06 [E] 52 | ``` 53 | 54 | The key fingerprint is `EA7A C0CC A097 C391 C7AA 61F1 09F7 AFCB CB48 AC15` 55 | 56 | ### Configure PyPi access 57 | 58 | In your `~/.pypirc` specify: 59 | 60 | ``` 61 | [distutils] 62 | index-servers = 63 | pypi 64 | testpypi 65 | 66 | [pypi] 67 | repository: https://upload.pypi.org/legacy/ 68 | username: USERNAME 69 | password: PASSWORD 70 | 71 | [testpypi] 72 | repository: https://test.pypi.org/legacy/ 73 | username: USERNAME 74 | password: PASSWORD 75 | ``` 76 | 77 | Install `twine` for the current user: 78 | 79 | ```bash 80 | python3 -m pip install --user twine 81 | ``` 82 | 83 | ### Publish a new release 84 | 85 | 1. (optionally) update requirements `./scripts/run -i base build --no-cache && ./scripts/run -i base refreeze` 86 | 2. run the pre-publish checks `./scripts/run check` 87 | 3. check that you're happy with `version.txt` 88 | 4. `rm -rf dist || true` to clean up previously created artefacts 89 | 5. `python3 setup.py sdist` to package a new release 90 | 6. `twine upload --sign --identity "swiftkey-deep@" -r testpypi dist/*` to upload the release to TEST PyPi server 91 | 7. Check the new release on test.pypi.org 92 | 8. You can download release files and verify the signature via 93 | ``` 94 | gpg --verify lmchallenge-$(cat version.txt).tar.gz.asc lmchallenge-$(cat version.txt).tar.gz 95 | ``` 96 | 9. `twine upload --sign --identity "swiftkey-deep@" -r pypi dist/*` to upload the release to MAIN PyPi server at https://pypi.org 97 | 10. `git push origin HEAD:refs/tags/$(cat version.txt)` to push the new release tag to Github 98 | 11. Go to Github and create a release for the tag you've just pushed 99 | 12.
update, commit & push `version.txt` to start a new version 100 | -------------------------------------------------------------------------------- /doc/formats.md: -------------------------------------------------------------------------------- 1 | # LM Challenge formats and APIs 2 | 3 | 1. [Corpus format](#corpus-format) 4 | 2. [Model API](#model-api) 5 | 3. [Log formats](#log-formats) 6 | 7 | 8 | ## Corpus format 9 | 10 | Two corpus formats are supported. The first, _plain text_, is designed for basic 'flat' evaluation without user-specific training and modelling. The second, _marked-up_ format is designed to provide metadata about the text being evaluated - in particular, a User ID and timestamp, which allows LM Challenge to evaluate user adaptation on sequences of text. 11 | 12 | In all cases, input data should be UTF-8 encoded. 13 | 14 | ### Plain text 15 | 16 | The plain text format is simply a line-by-line text format with newline (`U+000A`) as the only delimiter (which should be used to separate paragraphs of potentially unrelated text). 17 | 18 | For example: 19 | 20 | This is a line of text. Everything on this line is related. 21 | Now we have some more, unrelated text, OK. 22 | 23 | ### User marked-up text 24 | 25 | The marked-up text format is based on [jsonlines](http://jsonlines.org/), which is a sequence of newline (`U+000A`)-separated JSON objects. Each line is of the following form (where every key except `text` is optional). The lines should be first grouped by `userId`, then ordered by `timestamp` for each user. 26 | 27 | {"userId": ID, "timestamp": NUMBER, "text": STRING} 28 | 29 | For example: 30 | 31 | {"userId": "aaa", "timestamp": 100000, "text": "I'm happy today :-)"} 32 | {"userId": "aaa", "timestamp": 103000, "text": "Sad today :-("} 33 | {"userId": "bbb", "timestamp": 102000, "text": "Who do you think you are?"} 34 | 35 | 36 | ## Model API 37 | 38 | A language model is an executable process that responds to commands from an _LM challenge game_ in a specific text format, usually comprising a pre-trained model of the same language as the _test text_. Generally this is specified as a shell command with predefined arguments. 39 | 40 | All text format APIs use UTF-8 encoding & only the newline (`U+000A`) & horizontal tab (`U+0009`) as delimiters (in the examples below, the tab is shown as `<TAB>`, and each command or response ends with a newline). Care must be taken to ensure streams are flushed at appropriate points (typically after each newline) in order to avoid deadlock. 41 | 42 | #### `predict` 43 | 44 | `predict` is used to predict the next string of characters given 45 | context. The command specifies a context string (which is a prefix of 46 | a line from the test data). 47 | Optionally, the command also specifies a list of candidates that should 48 | be considered for prediction (if not specified, the model is itself 49 | responsible for generating valid candidates). 50 | Candidates may be multiple words, but should all correspond to the same 51 | amount of input (which should help make the resulting scores comparable). 52 | The language model responds with a list of prediction strings for the 53 | following characters, together with scores for each prediction. 54 | 55 | - Input is an untokenized context string (which may stop abruptly 56 | within a word, or may end in a space) 57 | - Output must be a list of next-string predictions (which may be words, 58 | characters, morphemes or phrases) with a score for each prediction. 59 | Format is `prediction<TAB>score<TAB>prediction<TAB>score` ....
60 | - The prediction should simply follow the characters of context (for 61 | example if the input is `"I like bi"` a prediction might be `"rds"`, 62 | as if completing the string `"I like birds"`, but not `"birds"`, 63 | which would be interpreted as suggesting `"I like bibirds"`). 64 | - Score is a number that's used to determine the ranking of your 65 | predictions; the biggest score ranks first. The predictions need not be 66 | returned in rank order. 67 | - In general we make no further assumptions about what the score 68 | represents -- for example it could be a normalized predictive 69 | probability or log-probability, an unnormalized probability, the 70 | output of some non-probabilistic predictive model, or just the 71 | (reciprocal of the) predicted rank itself. 72 | - Some tools that operate on evaluation results may make additional 73 | assumptions about what model scores represent (e.g. that they are 74 | log-probabilities), but in these cases the requirement will be 75 | documented. 76 | - If a specified candidate is unscorable in a model, it may be omitted 77 | from the results, in which case the treatment of that candidate is 78 | dependent on the evaluator. 79 | 80 | For example: 81 | 82 | predict<TAB>I am your 83 | best<TAB>0.2<TAB>only<TAB>0.1<TAB>friend<TAB>0.08<TAB>Boss<TAB>0.05 84 | 85 | predict<TAB>I am your gu 86 | est<TAB>-1.23<TAB>ess<TAB>-2.51<TAB>y<TAB>-2.82<TAB>errilla<TAB>-6.33 87 | 88 | predict<TAB>I am your <TAB>guest<TAB>guerilla 89 | guest<TAB>-1.23<TAB>guerilla<TAB>-6.33 90 | 91 | #### `train` 92 | 93 | `train` allows a model the opportunity to learn from a line of input, after having been evaluated on it. 94 | 95 | - Input: an untokenized line of text 96 | - Output: none (it is an error to send back even a newline in response to this command) 97 | 98 | For example: 99 | 100 | train<TAB>Hey Rebecca, did you see that lacrosster? 101 | 102 | #### `clear` 103 | 104 | `clear` instructs the model to forget everything it has learnt from previous `train` calls. For example, this will be called when the dataset is changing (e.g. evaluating on data from a different user). 105 | 106 | - Input: none 107 | - Output: none (it is an error to send back even a newline in response to this command) 108 | 109 | For example: 110 | 111 | clear 112 | 113 | 114 | ## Log formats 115 | 116 | All LM Challenge games share a common log schema (defined formally in `log.schema`), which has common metadata and optional payload data for each challenge. Logs should be stored as UTF-8 encoded [jsonlines](http://jsonlines.org/), optionally gzipped. 117 | 118 | The required keys for a log event, which typically represents a single word or character from the data, are as follows: 119 | 120 | {"user": STRING, 121 | "character": NUMBER, 122 | "message": NUMBER, 123 | "token": NUMBER, 124 | "target": STRING} 125 | 126 | - `user` should be a unique identifier for that user (or `null`, if there is no user information) 127 | - `character` is the index of the start of the source data range, relative to the start of the message 128 | - `message` is the message index within a single user 129 | - `token` is the token index within a single message 130 | - `target` is a string from the source data, the text being modelled 131 | 132 | LM Challenge logs should be sorted by `(user, message, token)` - i.e. all events for a single user should be contiguous, and message & token should be in ascending order for that user. 133 | 134 | Note that LM Challenge logs contain the original source text, so should be subject to the same privacy constraints & data protection classification.
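The CLI `lmc validate` (or the Python API `lmchallenge.validate.validate`) checks that a log conforms to the schema. As an illustration of the ordering requirement above, here is a minimal sketch - the helper name `check_log_order` and the log path are illustrative, not part of the library - that loads a log with `lmchallenge.load_jsonlines` and verifies that events are grouped by user and ascending in `(message, token)`:

    import itertools
    import lmchallenge as lmc

    def check_log_order(path):
        # Illustrative helper (not part of LM Challenge): check log ordering
        events = lmc.load_jsonlines(path)   # jsonlines, or gzipped jsonlines
        seen_users = set()
        for user, group in itertools.groupby(events, key=lambda e: e['user']):
            if user in seen_users:
                raise ValueError('events for user {!r} are not contiguous'.format(user))
            seen_users.add(user)
            last = None
            for event in group:
                key = (event['message'], event['token'])
                if last is not None and key < last:
                    raise ValueError('events out of order at {!r}'.format(event))
                last = key

    check_log_order('out/w3.wc.log')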
135 | 136 | ### `wc` logs 137 | 138 | Log lines from `wc` contain an additional key `completions`, which records next-word-predictions and prefix completions for the target word. 139 | 140 | {"completions": [[STRING, STRING, ...], 141 | [STRING, STRING, ...], 142 | ...]} 143 | 144 | Completions is a jagged 2D array of completions, such that `completions[i][j]` corresponds to the suffix predicted after typing `i` characters at prediction index `j`. For example, if the target is `"Hello"`, the completions array might be: 145 | 146 | [["Good", "Hi", "Are"], 147 | ["i", "ow", "e"], 148 | ["llo", "lp", "lpful"], 149 | ["lo", "p", "pful"], 150 | ["o", "enistic"]] 151 | 152 | I.e. the second row corresponds to the predictions `["Hi", "How", "He"]` (or, the whole word is reconstructed using `target[:i] + completions[i][j]`). 153 | 154 | If running `wc` in "fast" mode, only `completions[0][:]` is present - as this corresponds to zero characters of prefix, which is next-word-prediction. 155 | 156 | ### `we` & `ce` logs 157 | 158 | Log lines from `we` or `ce` contain an additional key `logp`, which records the log probability of this word or character target, or `null` if the target is not in the language model vocabulary. It is the responsibility of the model/evaluator to ensure that `logp` is normalized over the vocabulary (therefore it should, in general, be negative). 159 | 160 | ### `wr` logs 161 | 162 | Log lines from `wr` contain the additional keys `verbatim`, which records the most likely corruption and `results`, which records corruption/correction candidates and scores (both from the language model, and the true error model). 163 | 164 | {"verbatim": STRING, 165 | "results": [[STRING, NUMBER, NUMBER|NULL], 166 | ...]} 167 | 168 | Each entry in results is an evaluated candidate, with a candidate string (which may be the same as the target or the verbatim), error model score and language model score. For example if the target is `"can"`: 169 | 170 | {"verbatim": "caj", 171 | "results": [["caj", 0.0, null], 172 | ["can", -3.0, -2.5], 173 | ["cab", -3.0, -2.8], 174 | ["fab", -6.0, -3.2]]} 175 | 176 | In this way, the list of results should include a candidate for the verbatim, and a candidate for the true target, as well as a number of other candidates, which are found by the evaluator to be likely given the corruption, and are included to confuse a language model, forcing it to disambiguate the true target. 177 | 178 | After a error-LM mixture model has been fitted to the log, an additional element is appended to each array, containing the combined score from the combined model (which should not be null). This is the final sort order of candidates. 179 | -------------------------------------------------------------------------------- /lmchallenge/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | '''LM Challenge - language modelling evaluation suite. 5 | 6 | A set of tools for evaluating language models. We find these tools useful 7 | for making a fair comparison of pure language models, of very different 8 | kinds (e.g. traditional ngram models vs. deep learning), and implemented 9 | in different programming languages / using different frameworks. 10 | 11 | LM Challenge is runnable from the command line `lmc --help`, or from 12 | Python `import lmchallenge as lmc`. 
This documentation is primarily 13 | for the users from Python, for the command line API and examples, 14 | see the [README on GitHub](https://github.com/Microsoft/LMChallenge). 15 | 16 | # Python example 17 | 18 | Here is a quick example in Python, defining a custom word model, based 19 | on the class `lmchallenge.FilteringWordModel` (for a more general API to 20 | implement, see `lmchallenge.Model` or `lmchallenge.WordModel`). 21 | 22 | #!python 23 | >>> import lmchallenge as lmc 24 | 25 | >>> class MyModel(lmc.FilteringWordModel): 26 | ... def score_word(self, context, candidates): 27 | ... return [(c, -len(c)) for c in candidates] 28 | ... def predict_word_iter(self, context): 29 | ... return [('one', -1), ('two', -2), ('three', -3)] 30 | 31 | >>> my_model = MyModel(n_predictions=3) 32 | 33 | # This is the core LM Challenge API - 'predict' 34 | >>> my_model.predict('', None) 35 | [('one', -1), ('two', -2), ('three', -3)] 36 | >>> my_model.predict('This is ', ['foo', 'a', 'brilliant']) 37 | [('foo', -3), ('a', -1), ('brilliant', -9)] 38 | 39 | To evaluate this model with LM Challenge, we select a challenge - 40 | for example word completion (`wc`), which measures next-word-prediction 41 | hit rate and word completion statistics. 42 | 43 | #!python 44 | >>> log = list(lmc.wc(my_model, ['one potato two potato three'])) 45 | >>> [x['target'] for x in log] 46 | ['one', 'potato', 'two', 'potato', 'three'] 47 | 48 | We now have a _log_ object, which is the core data type of LM Challenge. Here 49 | are a few things you can do with logs: 50 | 51 | #!python 52 | # Compute aggregate stats 53 | >>> stats = lmc.stats.stats(log) 54 | >>> stats['prediction']['hit1'] 55 | 0.2 56 | >>> stats['prediction']['hit3'] 57 | 0.6 58 | >>> stats['fingerprint'] 59 | 'a33f4773' 60 | 61 | # Pretty-print the log 62 | >>> pretty = lmc.pretty.ansi(log) 63 | >>> for line in pretty: 64 | ... print(line) # doctest: +SKIP 65 | one potato two potato three #...except more colourful 66 | 67 | # Filter and compute stats 68 | >>> f_log = lmc.grep.grep('^t|potato', log) 69 | >>> f_stats = lmc.stats.stats(f_log) 70 | >>> f_stats['skipped'] 71 | 0.2 72 | >>> f_stats['prediction']['hit3'] 73 | 0.5 74 | >>> f_stats['fingerprint'] # note: different fingerprint 75 | '2a1ecfab' 76 | 77 | Logs are simply iterables of dictionaries that conform to a log schema, and 78 | are usually stored in JSONlines format: 79 | 80 | #!python 81 | >>> lmc.dump_jsonlines(log, '/tmp/my_log.jsonl') 82 | >>> log = list(lmc.load_jsonlines('/tmp/my_log.jsonl')) 83 | >>> [x['target'] for x in log] 84 | ['one', 'potato', 'two', 'potato', 'three'] 85 | 86 | 87 | # Command Line Interface example 88 | 89 | All of the same features are available on the command line (try 90 | `$ lmc --help`). For a guide and examples, please see the 91 | [README on GitHub](https://github.com/Microsoft/LMChallenge). 92 | ''' 93 | 94 | # Module 95 | 96 | import click 97 | from . 
import core, diff, grep, pretty, run, stats, validate 98 | from .core.model import FilteringWordModel, Model, WordModel 99 | from .core.common import ( 100 | WORD_TOKENIZER, CHARACTER_TOKENIZER, 101 | load_jsonlines, dump_jsonlines 102 | ) 103 | from .run import wc, we, wr, ce 104 | 105 | __all__ = [ 106 | # submodules 107 | 'core', 108 | 'grep', 109 | 'diff', 110 | 'pretty', 111 | 'run', 112 | 'stats', 113 | 'validate', 114 | 115 | # specifics 116 | 'Model', 117 | 'WordModel', 118 | 'FilteringWordModel', 119 | 'WORD_TOKENIZER', 120 | 'CHARACTER_TOKENIZER', 121 | 'dump_jsonlines', 122 | 'load_jsonlines', 123 | 'wc', 'we', 'wr', 'ce', 124 | ] 125 | 126 | 127 | # Command line interface 128 | 129 | @click.group() 130 | def cli(): 131 | '''The main entry point to LM Challenge. 132 | Use subcommands to perform specific tasks. 133 | ''' 134 | pass 135 | 136 | 137 | # Changes to this list should be synced with setup.py 138 | cli.add_command(diff.cli, 'diff') 139 | cli.add_command(grep.cli, 'grep') 140 | cli.add_command(pretty.cli, 'pretty') 141 | cli.add_command(run.cli, 'run') 142 | cli.add_command(stats.cli, 'stats') 143 | cli.add_command(validate.cli, 'validate') 144 | -------------------------------------------------------------------------------- /lmchallenge/__main__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | from . import cli 5 | import sys 6 | 7 | # Pytest's doctest collector runs this file :-( 8 | if len(sys.argv) == 0 or ('pytest' not in sys.argv[0]): 9 | cli() 10 | -------------------------------------------------------------------------------- /lmchallenge/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | '''Core components supporting LM Challenge, defining 5 | `lmchallenge.core.model.Model`, as well as various utilities 6 | for implementing the top-level functionality (which lives in 7 | the other submodules). 8 | ''' 9 | 10 | from .model import FilteringWordModel, Model, WordModel # NOQA 11 | from .common import * # NOQA 12 | -------------------------------------------------------------------------------- /lmchallenge/core/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 
3 | 4 | import os 5 | import re 6 | import sys 7 | import regex 8 | import click 9 | import emoji 10 | import logging 11 | import json 12 | import importlib 13 | import gzip 14 | import contextlib 15 | import itertools as it 16 | 17 | 18 | WORD_TOKENIZER = regex.compile( 19 | emoji.get_emoji_regexp().pattern + 20 | '''|[\p{L}\p{N}\p{Pc}\p{Pd}'@#]+|[\p{P}\p{S}]+''' 21 | ) 22 | '''Our basic word tokenizer regex.''' 23 | 24 | 25 | CHARACTER_TOKENIZER = regex.compile( 26 | '.|\n', flags=regex.MULTILINE 27 | ) 28 | '''A Unicode character tokenizer regex.''' 29 | 30 | 31 | def shell_docstring(command, name): 32 | '''Utility for creating docstrings: 33 | 34 | __doc__ += shell_docstring(cli, 'command-name') 35 | ''' 36 | # Comment and indentation to recognize code segment 37 | text = '\n## `$ {}`\n\n #!sh'.format(name) 38 | text += '\n ' + click.Context(command, info_name=name) \ 39 | .get_help().replace('\n', '\n ') 40 | text += '\n' 41 | return text 42 | 43 | 44 | def verbosity(level): 45 | '''Set logging verbosity to this level (0 to 2 inclusive). 46 | ''' 47 | LOG_LEVELS = [ 48 | logging.WARNING, logging.INFO, logging.DEBUG 49 | ] 50 | logging.basicConfig( 51 | format='%(levelname)s\t%(message)s', 52 | level=LOG_LEVELS[min(len(LOG_LEVELS) - 1, level)], 53 | ) 54 | 55 | 56 | def unique_by(iterable, func): 57 | '''General itertools-like function for lazy-uniquing via a key function. 58 | ''' 59 | added = set([]) 60 | for x in iterable: 61 | fx = func(x) 62 | if fx not in added: 63 | added.add(fx) 64 | yield x 65 | 66 | 67 | def all_predicates(*predicates): 68 | '''Combine predicates into a single predicate, testing if all of them match. 69 | ''' 70 | return lambda x: all(predicate(x) for predicate in predicates) 71 | 72 | 73 | @contextlib.contextmanager 74 | def not_closing(f): 75 | '''A context manager that doesn't call close on the resource. 76 | 77 | For example, use this when you want to run: 78 | 79 | with sys.stdin as f: 80 | # do something with f, closes sys.stdin at the end 81 | 82 | Instead, you can do: 83 | 84 | with not_closing(sys.stdin) as f: 85 | # do something with f, not closing sys.stdin at the end 86 | 87 | ''' 88 | yield f 89 | 90 | 91 | def auto_open(filename, mode='rt'): 92 | '''Open a file, and return it (should be used in a `with`). 93 | 94 | `filename` -- `string` -- path to a file (or gzip), or `"-"` for 95 | stdin/stdout 96 | 97 | `return` -- `file` -- performing gzip decoding if appopriate 98 | ''' 99 | if filename == '-': 100 | if 'r' in mode: 101 | return not_closing(sys.stdin) 102 | elif '+' in mode: 103 | raise ValueError('Cannot return stdout/stdin with mode "r+"') 104 | else: 105 | return not_closing(sys.stdout) 106 | elif filename.endswith('.gz') or filename.endswith('.gzip'): 107 | return gzip.open(filename, mode) 108 | else: 109 | return open(filename, mode) 110 | 111 | 112 | def load_jsonlines(filename): 113 | '''Generate json objects from a JSONlines file. 114 | 115 | Note that this relies on the iterator being exhausted, or going out of 116 | scope, in order to close the file. 117 | 118 | `filename` -- `string` -- path to a file (jsonlines, or gzipped jsonlines), 119 | or `"-"` for stdin 120 | ''' 121 | with auto_open(filename) as f: 122 | for line in f: 123 | yield json.loads(line.rstrip('\r\n')) 124 | 125 | 126 | def dump_jsonlines(data, filename='-'): 127 | '''Dump data to stdout in jsonlines format. 
128 | 129 | `data` -- `iterable(dict)` -- data to dump 130 | 131 | `filename` -- `string` -- destination to write (jsonlines, or gzipped 132 | jsonlines), or `"-"` for stdout 133 | ''' 134 | with auto_open(filename, 'wt') as f: 135 | for d in data: 136 | f.write(json.dumps(d, sort_keys=True) + '\n') 137 | 138 | 139 | def flatten_keys(d, separator='.'): 140 | '''Flatten a nested dictionary, using 'separator' to separate keys in the 141 | result. For example: 142 | 143 | {'id': {'name': 'James Bond', 'code': 0x07}, 'job': 'Spy'} 144 | = flatten_keys => 145 | {'id.name': 'James Bond', 'id.code': 0x07, 'job: 'Spy'} 146 | ''' 147 | result = {} 148 | 149 | def flatten(d, prefix): 150 | for k, v in d.items(): 151 | if isinstance(v, dict): 152 | flatten(v, prefix + k + separator) 153 | else: 154 | result[prefix + k] = v 155 | flatten(d, '') 156 | return result 157 | 158 | 159 | def rank(items, item, max_rank=None): 160 | '''Find the rank of 'item' in the list 'items', 161 | returning None if the item is missing, where the first element is 162 | considered rank=1. 163 | 164 | `items` -- `list` -- to search through 165 | 166 | `item` -- `any` -- target 167 | 168 | `max_rank` -- `int` or `None` -- stop the search early at this rank; 169 | if not found, return `None` 170 | 171 | `return` -- `int` or `None` -- `rank >= 1`, or `None` if the item was 172 | not found 173 | ''' 174 | try: 175 | stop = max_rank if max_rank is not None else len(items) 176 | return 1 + items.index(item, 0, stop) 177 | except ValueError: 178 | return None 179 | 180 | 181 | def sort_with_override(items, *first_items): 182 | '''Sort a list, but move 'first_items' to the front in the given order. 183 | ''' 184 | primary_order = {k: i for i, k in enumerate(first_items)} 185 | # use a tuple (primary_order, item) as the sort value, which will 186 | # move items matching primary_order to the front (as a major sort index) 187 | return sorted(items, key=lambda item: ( 188 | primary_order.get(item, len(primary_order)), 189 | item 190 | )) 191 | 192 | 193 | def peek(iterable): 194 | '''Get the first item out of an iterable, then reattach it, so you can 195 | dispatch based on the first item, then process all items. 196 | 197 | `iterable` -- `iterable` -- an iterable or collection of items 198 | 199 | `return` -- `(object, iterable)` -- a pair (first_item, iterable) 200 | where iterable contains all items (including the first) 201 | ''' 202 | iterable = iter(iterable) 203 | try: 204 | first_item = next(iterable) 205 | # rebuid an iterable of all items 206 | all_items = it.chain([first_item], iterable) 207 | return (first_item, all_items) 208 | except StopIteration: 209 | return (None, ()) 210 | 211 | 212 | def is_selected(datum): 213 | '''Is this datum selected (note that a missing 'select' key implicitly means 214 | it should be selected. 215 | ''' 216 | return datum.get('select', True) 217 | 218 | 219 | def autodetect_input(data): 220 | '''Convert plain text input data to the dictionary-based format. 
221 | 222 | `data` -- `iterable(dict)` or `iterable(string)` 223 | 224 | `return` -- `iterable(dict)` -- if `data` is plain strings, each 225 | dictionary is `{"text": line}`, otherwise return `data` 226 | unchanged 227 | ''' 228 | first, data = peek(data) 229 | if isinstance(first, str): 230 | # auto-detection if passed an iterable of plain strings 231 | return (dict(text=line) for line in data) 232 | elif isinstance(first, dict): 233 | return data 234 | else: 235 | raise ValueError( 236 | 'unexpected data item {} (expected str or dict)'.format(first)) 237 | 238 | 239 | def zip_combine(common_keys, dict_iterables): 240 | '''Combine a set of iterables, checking that they have identical values 241 | for `common_keys`, nesting any other keys under the iterable's name. 242 | 243 | e.g. `zip_combine(["n"], dict(x=xs, y=ys))` 244 | 245 | | x | y | result | 246 | |----------------|----------------|----------------------------------| 247 | | {n:1, bar:"a"} | {n:1, bar:"b"} | {n:1, x:{bar:"a"}, y:{bar:"b"}} | 248 | | {n:2, bar:"a"} | {n:2} | {n:1, x:{bar:"a"}, y:{}} | 249 | | {n:3, bar:"a"} | {n:4} | throws ValueError | 250 | 251 | `common_keys` -- `list(string)` -- a list of keys that should be equal 252 | in the zipped dicts 253 | 254 | `dict_iterables` -- `dict(string -> iterable)` -- the iterables to be 255 | zipped together with string names 256 | 257 | `return` -- `generator(dict)` -- where the keys of each item are 258 | `common_keys + dict_iterables.keys()` 259 | ''' 260 | common_keys = set(common_keys) 261 | for items in zip(*dict_iterables.values()): 262 | # The first iterable defines the expected values for common_keys 263 | result = {k: items[0][k] for k in common_keys if k in items[0]} 264 | for name, item in zip(dict_iterables.keys(), items): 265 | # Check validity 266 | for k in common_keys: 267 | if result.get(k) != item.get(k): 268 | raise ValueError( 269 | 'zip_combine mismatch between {} and {}' 270 | ' ("{}": {} != {})'.format( 271 | next(iter(dict_iterables.keys())), name, 272 | k, result.get(k), item.get(k))) 273 | # Match - add in the result 274 | result[name] = {k: v 275 | for k, v in item.items() 276 | if k not in common_keys} 277 | yield result 278 | 279 | 280 | def zip_logs(**data): 281 | '''Zip a dictionary of LM Challenge logs together, failing if the logs 282 | don't "match up" (i.e. were generated from different source data). 283 | 284 | The keys that must match (user, character, message, token, target, select) 285 | are returned in the root element of each result, and the log-specific 286 | results are included under the name of that log. 287 | 288 | `data` -- `dict(string -> data)` -- named logs to be zipped together 289 | 290 | `return` -- `generator(dict)` -- zipped logs: 291 | 292 | {"user", "character", "message", "token", "target", "select", 293 | "log_1_name": {"logp"|"completions"|"results"...}, 294 | "log_2_name": {"logp"|"completions"|"results"...}} 295 | ''' 296 | return zip_combine( 297 | ["user", "character", "message", "token", "target", "select"], 298 | data) 299 | 300 | 301 | class JsonParam(click.ParamType): 302 | '''Click parameter type for parsing JSON. 303 | If the parameter is a valid filename, assumes that it is a path to a json 304 | file, and reads that file instead. 
305 | ''' 306 | name = 'json' 307 | 308 | def convert(self, value, param, ctx): 309 | try: 310 | if os.path.exists(value): 311 | with auto_open(value) as f: 312 | return json.load(f) 313 | else: 314 | return json.loads(value) 315 | except ValueError as e: 316 | self.fail(str(e)) 317 | 318 | def get_metavar(self, param): 319 | return 'JSON' 320 | 321 | 322 | class ParamChoice(click.ParamType): 323 | '''Like `click.Choice`, but looks up attributes on specific subclasses. 324 | 325 | To subclass, define: 326 | `name` - the descriptive name for the parameter 327 | `choices` - a list of strings, each of which is a valid attr name 328 | ''' 329 | 330 | def convert(self, value, param, ctx): 331 | if value in type(self).choices: 332 | return getattr(self, value) 333 | else: 334 | self.fail('expected one of {%s}, actually "%s"' % ( 335 | ', '.join(map(repr, type(self).choices)), 336 | value 337 | )) 338 | 339 | def get_metavar(self, param): 340 | return '(%s)' % ('|'.join(type(self).choices)) 341 | 342 | 343 | class ChallengeChoice(ParamChoice): 344 | '''Select processing to run on a generated log. 345 | ''' 346 | name = 'challenge' 347 | choices = ['auto', 'completion', 'entropy', 'reranking'] 348 | 349 | @classmethod 350 | def auto(cls, *data, **args): 351 | first, data = zip(*list(peek(d) for d in data)) 352 | 353 | is_completion = all('completions' in x for x in first) 354 | is_entropy = all('logp' in x for x in first) 355 | is_reranking = all('results' in x for x in first) 356 | if sum([is_completion, is_entropy, is_reranking]) != 1: 357 | raise Exception('Cannot infer log type from data') 358 | 359 | if is_completion: 360 | return cls.completion(*data, **args) 361 | elif is_entropy: 362 | return cls.entropy(*data, **args) 363 | elif is_reranking: 364 | return cls.reranking(*data, **args) 365 | 366 | @staticmethod 367 | def completion(data, **args): 368 | raise NotImplementedError 369 | 370 | @staticmethod 371 | def entropy(data, **args): 372 | raise NotImplementedError 373 | 374 | @staticmethod 375 | def reranking(data, **args): 376 | raise NotImplementedError 377 | 378 | 379 | def single_log(logs): 380 | '''When using @click.argument('log', nargs=-1), this limits the log to zero 381 | (for stdin) or one log file. 382 | 383 | `logs` -- `list(string)` -- a list of log file names 384 | 385 | `return` -- `string` -- exactly one log file name 386 | 387 | `raise` -- `ValueError` -- if more than one file is passed 388 | ''' 389 | if len(logs) == 0: 390 | return '-' 391 | elif len(logs) == 1: 392 | return logs[0] 393 | else: 394 | raise ValueError('Can only process zero or one log files') 395 | 396 | 397 | def qualified_name(x): 398 | '''Return the qualified name of 'x' (including the module). 399 | The format is: `module.submodule:attr.subattr`. 400 | ''' 401 | return '%s:%s' % (x.__module__, x.__qualname__) 402 | 403 | 404 | QUALIFIED_NAME_PATTERN = re.compile('^([^: ]+):([^: ]+)$') 405 | 406 | 407 | def is_qualified_name(name): 408 | '''Determine if this is a qualified name of a type (with module). 409 | ''' 410 | return QUALIFIED_NAME_PATTERN.match(name) is not None 411 | 412 | 413 | def lookup_qualified_name(name, base_package=None): 414 | '''Return the attribute for the qualified 'name' 415 | (which should be module.submodule:attr.subattr, 416 | e.g. from `qualified_name`). 
417 | ''' 418 | m = QUALIFIED_NAME_PATTERN.match(name) 419 | if m is None: 420 | raise ValueError('could not parse qualified name "%s"' % name) 421 | 422 | module_name = m.group(1) 423 | attr_names = m.group(2) 424 | module = importlib.import_module(module_name, base_package) 425 | try: 426 | obj = module 427 | for attr_name in attr_names.split('.'): 428 | obj = getattr(obj, attr_name) 429 | return obj 430 | except AttributeError: 431 | raise AttributeError( 432 | 'module %s has no attribute %r' 433 | % (module, attr_names)) 434 | 435 | 436 | class AnsiRender: 437 | '''A helper for rendering in ANSI color codes. 438 | Usage: 439 | 440 | r = AnsiRender(sys.stdout) 441 | r.color(r.RED, bold=True) 442 | r.write("Hello\n") 443 | r.default() 444 | ''' 445 | 446 | BLACK = 0 447 | RED = 1 448 | GREEN = 2 449 | YELLOW = 3 450 | BLUE = 4 451 | MAGENTA = 5 452 | CYAN = 6 453 | WHITE = 7 454 | DEFAULT = 9 455 | 456 | def __init__(self, outf): 457 | self.f = outf 458 | self.index = self.DEFAULT 459 | self.bold = False 460 | 461 | def default(self): 462 | self.color(self.DEFAULT, False) 463 | 464 | def color(self, index, bold): 465 | if self.bold and not bold: 466 | self.f.write(u'\x1b[0;%dm' % (30 + index)) 467 | elif bold and not self.bold: 468 | self.f.write(u'\x1b[1;%dm' % (30 + index)) 469 | elif self.index != index: 470 | self.f.write(u'\x1b[%dm' % (30 + index)) 471 | self.index = index 472 | self.bold = bold 473 | 474 | def write(self, s): 475 | self.f.write(s) 476 | 477 | def close(self): 478 | self.f.close() 479 | -------------------------------------------------------------------------------- /lmchallenge/core/errors.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | '''Utilities for corrupting and correcting text, for use with `lmchallenge.wr`. 5 | 6 | Currently the implemented corruption is very simple - corrupt each character 7 | with a low probability of it being replaced by another ASCII letter (or in-word 8 | punctuation character). 9 | ''' 10 | 11 | import random 12 | import math 13 | import string 14 | import heapq 15 | 16 | DEFAULT_CONFIG = dict( 17 | p_anykey=0.1, 18 | error_chars=string.ascii_letters + '_-#@' 19 | ) 20 | 21 | 22 | def corrupt(config, word, rand=random): 23 | '''Generate a corrupted version of a word, as if typed by a sloppy typist. 24 | ''' 25 | p_anykey = config['p_anykey'] 26 | error_chars = config['error_chars'] 27 | return ''.join( 28 | rand.choice(error_chars) if rand.random() < p_anykey else ch 29 | for ch in word 30 | ) 31 | 32 | 33 | def score(config, input_word, word): 34 | '''Return an approximate score for this word, given the error model of 35 | 'config'. 36 | ''' 37 | p_anykey = config['p_anykey'] 38 | n_correct = sum(a == b for a, b in zip(input_word, word)) 39 | return n_correct * math.log(1 - p_anykey) + \ 40 | (len(word) - n_correct) * math.log(p_anykey) 41 | 42 | 43 | class Search: 44 | '''Functor for finding a list of nearby candidates to a corrupted word. 
45 | ''' 46 | def __init__(self, words): 47 | self.words = {} 48 | for w in words: 49 | words_l = self.words.get(len(w)) 50 | if words_l is None: 51 | words_l = [] 52 | self.words[len(w)] = words_l 53 | words_l.append(w) 54 | 55 | @staticmethod 56 | def _count_matches(a, b): 57 | n = 0 58 | for ch_a, ch_b in zip(a, b): 59 | if ch_a == ch_b: 60 | n += 1 61 | return n 62 | 63 | def __call__(self, input_word, n): 64 | return heapq.nlargest( 65 | n, self.words.get(len(input_word), []), 66 | key=lambda w: Search._count_matches(input_word, w) 67 | ) 68 | -------------------------------------------------------------------------------- /lmchallenge/core/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | import sys 5 | import subprocess 6 | import shlex 7 | import regex 8 | import itertools as it 9 | from . import common 10 | 11 | __doc__ = '''Core LM Challenge Model APIs for LMC. 12 | 13 | `lmchallenge.core.model.Model` documents the core API, which can be 14 | implemented by subclassing, or by duck-typing the same API. 15 | ''' 16 | 17 | 18 | class Model: 19 | '''Base class for implementing the Model API for LM Challenge. 20 | 21 | **Subclasses must implement:** 22 | 23 | - `lmchallenge.core.model.Model.predict` 24 | 25 | **Optional:** 26 | 27 | - `lmchallenge.core.model.Model.train` 28 | - `lmchallenge.core.model.Model.clear` 29 | - `__enter__` 30 | - `__exit__` 31 | ''' 32 | 33 | def predict(self, context, candidates): 34 | '''Get text completions (or score candidates) following a context. 35 | 36 | `context` -- `string` -- preceding text that should be treated as 37 | fixed, all predictions and candidates follow this. 38 | 39 | `candidates` -- `list(string)` or `None` -- optional candidates to 40 | score following `context`. 41 | If `None`, the model should generate the most likely 42 | candidates itself to score, otherwise it need only 43 | return results from this list (with scores). 44 | 45 | `return` -- `list((string, float))` -- ordered list of result 46 | completions & scores. 47 | The list should be ordered from most to least likely - 48 | in general, that should correspond to descending score 49 | order. 50 | In some cases (e.g. for computing entropy), it is 51 | important that the score is a normalized log-probability. 52 | Each completion (as candidates) should be treated as if it 53 | follows `context` directly (i.e. if `context` stops in the 54 | middle of a word, the results are completions of that 55 | word). 56 | 57 | For example (e.g. 1): 58 | 59 | predict("I am the", None) 60 | -> [("re", -0.5), ("y", -4)] 61 | 62 | This means the model thinks the completion "the" -> "there" is most 63 | likely, followed by the completion "they". 64 | 65 | For example (e.g. 2): 66 | 67 | predict("I am ", ["they", "the", "inevitably"]) 68 | -> [("the", -0.5), ("they", -4)] 69 | 70 | This means the model has been asked to score three words, of which: 71 | 72 | - "the" is the most likely, with a high score 73 | - "they" is unlikely, with a low score 74 | - "inevitably" is out-of-vocab, so is not scored & returned at all 75 | 76 | ''' 77 | raise NotImplementedError 78 | 79 | def train(self, text): 80 | '''Adapt the model to the current user, providing text that has been 81 | entered. 
82 | 83 | `text` -- `string` -- the contents of a single message, for the current 84 | user 85 | 86 | User-specific language information should be aggregated across `train` 87 | calls until `clear` is called. 88 | 89 | (Default implementation: do nothing) 90 | ''' 91 | pass 92 | 93 | def clear(self): 94 | '''Reset all trained state in the model (i.e. we're about to start with 95 | a new user). 96 | 97 | After calling `clear`, the results from `predict` should be as if the 98 | `Model` has been newly created. 99 | 100 | (Default implementation: do nothing) 101 | ''' 102 | pass 103 | 104 | def __enter__(self): 105 | '''Default implementation.''' 106 | return self 107 | 108 | def __exit__(self, type, value, traceback): 109 | '''Default implementation: do nothing.''' 110 | pass 111 | 112 | def run_loop(self, 113 | input_stream=sys.stdin, 114 | output_stream=sys.stdout, 115 | error_stream=sys.stderr): 116 | '''Run the model as a pipeable predictor process between 117 | `input_stream` and `output_stream`. 118 | 119 | `input_stream` -- `stream` -- commands to run (default: STDIN) 120 | 121 | `output_stream` -- `stream` -- prediction output (default: STDOUT) 122 | 123 | `error_stream` -- `stream` -- error message output (default: STDERR) 124 | 125 | This method does not return until `input_stream` is exhausted. 126 | 127 | E.g. 128 | 129 | class MyModel(Model): 130 | ... 131 | 132 | if __name__ == '__main__': 133 | model = MyModel() 134 | model.run_loop() 135 | ''' 136 | for line in input_stream: 137 | parts = line.strip('\n').split('\t') 138 | cmd = parts[0] 139 | if cmd == 'predict': 140 | context = parts[1] if 2 <= len(parts) else '' 141 | candidates = parts[2:] if 3 <= len(parts) else None 142 | results = ((x, s) 143 | for x, s in self.predict(context, candidates) 144 | if s is not None) 145 | response = '\t'.join('%s\t%f' % (candidate, score) 146 | for candidate, score in results) + '\n' 147 | output_stream.write(response) 148 | elif cmd == 'train': 149 | self.train(parts[1]) 150 | elif cmd == 'clear': 151 | self.clear() 152 | else: 153 | error_stream.write('Unrecognized command "' + cmd + '"\n') 154 | output_stream.flush() 155 | 156 | 157 | class WordModel(Model): 158 | '''Optional helper subclass for defining a word-by-word prediction model, 159 | based on a regex tokenizer. 160 | 161 | **Subclasses must implement:** 162 | 163 | - `lmchallenge.core.model.WordModel.predict_word` 164 | - `lmchallenge.core.model.WordModel.score_word` 165 | 166 | **Optional:** 167 | 168 | - `lmchallenge.core.model.WordModel.train_word` 169 | ''' 170 | def __init__(self, token_pattern=None): 171 | if token_pattern is None: 172 | self._tokenizer = common.WORD_TOKENIZER 173 | else: 174 | self._tokenizer = regex.compile(token_pattern) 175 | 176 | def predict(self, context, candidates): 177 | tokens = list(self._tokenizer.finditer(context)) 178 | if len(tokens) and tokens[-1].end() == len(context): 179 | # there is an "in-progress" word 180 | prefix = tokens.pop(-1).group(0) 181 | else: 182 | prefix = '' 183 | context_tokens = [m.group(0) for m in tokens] 184 | 185 | if candidates is None: 186 | return self.predict_word(context_tokens, prefix) 187 | else: 188 | return self.score_word( 189 | context_tokens, 190 | [prefix + candidate for candidate in candidates]) 191 | 192 | def predict_word(self, context, prefix): 193 | '''Predict a next word, or complete the current word. 
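        For example (illustrative, following the `Model.predict` examples above): `predict_word(['I', 'am'], 'th')` might return `[('e', -0.5), ('ey', -4.0)]`, i.e. suffixes completing "the" and "they".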
194 | 195 | `context` -- `list(string)` -- word tokens in the context 196 | 197 | `prefix` -- `string` -- the prefix of the word being typed (if empty, 198 | returns next word predictions) 199 | 200 | `return` -- `list((string, float))` -- a list of pairs (suffix, score) 201 | ''' 202 | raise NotImplementedError 203 | 204 | def score_word(self, context, candidates): 205 | '''Score a set of candidates which follow a context. 206 | 207 | `context` -- `list(string)` -- word tokens in the context 208 | 209 | `candidates` -- `set(string)` -- should return scores for each of 210 | these, if possible 211 | 212 | `return` -- `list((string, float))` -- a list of pairs 213 | (candidate, score) 214 | ''' 215 | raise NotImplementedError 216 | 217 | def train(self, text): 218 | return self.train_word(self._tokenizer.findall(text)) 219 | 220 | def train_word(self, text): 221 | '''Add this sequence of words to a user-adaptive model. 222 | 223 | `text` -- `list(string)` -- word tokens in the message 224 | 225 | (Default implementation: do nothing.) 226 | ''' 227 | pass 228 | 229 | 230 | class FilteringWordModel(WordModel): 231 | '''Specialization of WordModel, which automatically filters prefixes 232 | & limits results. 233 | 234 | **Subclasses must implement:** 235 | 236 | - `lmchallenge.core.model.FilteringWordModel.predict_word_iter` 237 | - `lmchallenge.core.model.WordModel.score_word` 238 | ''' 239 | def __init__(self, n_predictions, filter_pattern='.', **args): 240 | '''Create a filtering word model. 241 | 242 | `n_predictions` -- `int` -- how many predictions/completions to return 243 | (does not apply when scoring candidates). 244 | 245 | `filter_pattern` -- `string` -- a regex pattern to apply to filter 246 | results (does not apply when scoring candidates). 247 | A result will be allowed if the pattern is matched 248 | anywhere in the string. 249 | 250 | `**args` -- see `lmchallenge.core.model.WordModel` 251 | ''' 252 | super().__init__(**args) 253 | self.n_predictions = n_predictions 254 | self.filter_xp = regex.compile(filter_pattern) 255 | 256 | def predict_word_iter(self, context): 257 | '''As per `lmchallenge.core.model.WordModel.predict_word`, but should 258 | return a lazy generator/iterator of next words, which may include 259 | duplicates. 260 | 261 | `context` -- `list(string)` -- list of preceding tokens 262 | 263 | `return` -- `generator(list((string, float)))` -- lazy sequence of 264 | pairs (word, score) 265 | ''' 266 | raise NotImplementedError 267 | 268 | def predict_word(self, context, prefix): 269 | results = self.predict_word_iter(context) 270 | filter_results = ( 271 | (w[len(prefix):], s) 272 | for w, s in results 273 | if self.filter_xp.search(w) is not None 274 | and len(prefix) < len(w) and w.startswith(prefix)) 275 | unique_results = common.unique_by(filter_results, lambda e: e[0]) 276 | top_results = it.islice(unique_results, self.n_predictions) 277 | return list(top_results) 278 | 279 | 280 | class ShellModel(Model): 281 | '''Defines a language model that proxies calls to a subprocess, 282 | over a Unix pipe. 283 | ''' 284 | def __init__(self, cmd, opts): 285 | '''Open a pipe to the model. 286 | 287 | `cmd` -- `string` -- shell command to run in a subprocess 288 | 289 | `opts` -- `dict` -- arguments to send to the subprocess. 290 | The key `"positional"` can refer to a list of positional 291 | arguments. 292 | The key `"verbose"` is used to control model verbosity, 293 | as well as being sent to the subprocess. 
294 | All other keys are sent to the subprocess command as 295 | `"--KEY VALUE"` (unless they already start with "-", in 296 | which case just `"KEY VALUE"`). 297 | ''' 298 | self.verbose = opts.get('verbose', False) 299 | cmd_positional = ' '.join(opts.get('positional', [])) 300 | cmd_options = ' '.join( 301 | '%s %s' % (key if key.startswith('-') else ('--' + key), 302 | shlex.quote(str(value))) 303 | for key, value in opts.items() 304 | if key != 'positional' 305 | ) 306 | self.cmd = '%s %s %s' % (cmd, cmd_positional, cmd_options) 307 | self._debug('$ %s' % self.cmd) 308 | self.proc = subprocess.Popen(self.cmd, shell=True, 309 | stdin=subprocess.PIPE, 310 | stdout=subprocess.PIPE) 311 | self.proc_in = open(self.proc.stdin.fileno(), 'w', encoding='utf-8') 312 | self.proc_out = open(self.proc.stdout.fileno(), 'r', encoding='utf-8') 313 | self._check_return_code() 314 | 315 | def __exit__(self, type, value, traceback): 316 | self._close() 317 | 318 | def _debug(self, line): 319 | if self.verbose: 320 | sys.stderr.write(line + '\n') 321 | 322 | def _close(self): 323 | '''Finish with the process & shut it down.''' 324 | self.proc.communicate() 325 | self._check_return_code() 326 | 327 | def _check_return_code(self): 328 | '''Check that our model process hasn't errored out.''' 329 | rc = self.proc.returncode 330 | if rc is not None and rc != 0: 331 | raise subprocess.CalledProcessError(rc, self.cmd) 332 | 333 | def _send_command(self, command): 334 | '''Issue a tab-delimited command to the model process.''' 335 | self._debug('#> %s' % (' '.join(command))) 336 | self._check_return_code() 337 | self.proc_in.write(u'\t'.join(command) + u'\n') 338 | self.proc_in.flush() 339 | 340 | def predict(self, context, candidates): 341 | command = ['predict', context] 342 | if candidates is not None: 343 | command += candidates 344 | self._send_command(command) 345 | self._check_return_code() 346 | candidates_and_scores = [ 347 | s for s in self.proc_out.readline().strip('\n').split('\t') 348 | ] 349 | self._debug('#< %s' % (' '.join(candidates_and_scores))) 350 | pairs = [(candidates_and_scores[i*2], 351 | float(candidates_and_scores[i*2+1])) 352 | for i in range(len(candidates_and_scores)//2)] 353 | return pairs 354 | 355 | def train(self, text): 356 | self._send_command(['train', text]) 357 | 358 | def clear(self): 359 | self._send_command(['clear']) 360 | -------------------------------------------------------------------------------- /lmchallenge/core/reranking.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | '''Utilities for the optimization and evaluation of reranking models. 5 | ''' 6 | 7 | import numpy as np 8 | import scipy.optimize 9 | 10 | 11 | # Helpers 12 | 13 | def replace_none_vector(values): 14 | '''Convert 'values' to a vector, replacing None with -infinity. 15 | ''' 16 | return np.array([(v if v is not None else -np.inf) for v in values], 17 | dtype=np.float32) 18 | 19 | 20 | def jagged_matrix(values, n): 21 | '''Convert a jagged Python array to a matrix, replacing None & truncated 22 | values with -infinity. 
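    For example (illustrative): `jagged_matrix([[-1.0], [-2.0, None]], 2)` gives `[[-1.0, -inf], [-2.0, -inf]]`.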
23 | 
24 |     `values` -- `list(list(float))` -- a list of 'V' lists, of maximum size 'n'
25 | 
26 |     `n` -- `int` -- the number of columns in the result
27 | 
28 |     `return` -- `array[V, n; float]` -- matrix containing `values`
29 |     '''
30 |     result = np.full((len(values), n), -np.inf, dtype=np.float32)
31 |     for i, row in enumerate(values):
32 |         result[i, :len(row)] = replace_none_vector(row)
33 |     return result
34 | 
35 | 
36 | def count_correct(scores):
37 |     '''Compute number of correct (rank=1) scores for a matrix of scores, where
38 |     the score of the intended "target" is in the first column (index 0) of each row.
39 | 
40 |     N.B. Uses greater-than rather than greater-than-or-equal,
41 |     although this is possibly a bit harsh (you could have achieved
42 |     the correct result via some arbitrary tie-breaking function).
43 | 
44 |     `scores` -- `array[N, C; float]` -- scores of all terms, where
45 |                 `scores[:, 0]` are the intended target's scores
46 | 
47 |     `return` -- `int` -- number of correct (rank=1) results
48 |                 (in the range `[0, N]`)
49 |     '''
50 |     return int((scores[:, 0] > scores[:, 1:].max(axis=1)).sum())
51 | 
52 | 
53 | # Reranking
54 | 
55 | class RerankingModel:
56 |     '''A model that is capable of combining error & language model scores
57 |     to rerank candidates (e.g. for the goal of optimizing combined ranking
58 |     accuracy).
59 |     '''
60 |     @classmethod
61 |     def guess(cls, error, lm):
62 |         '''Return the initial guess at a good set of arguments.
63 | 
64 |         `error` -- `array[N; float]` -- example error scores
65 | 
66 |         `lm` -- `array[N; float]` -- example language model scores
67 | 
68 |         `return` -- `dict` -- `{"arg_name": initial_value}`
69 |         '''
70 |         raise NotImplementedError
71 | 
72 |     def __init__(self, **args):
73 |         self.args = args
74 |         for k, v in args.items():
75 |             setattr(self, k, v)
76 | 
77 |     def __call__(self, error, lm):
78 |         '''Evaluate the reranking model for the given error & LM scores.
79 | 
80 |         `error` -- `array[*; float]` -- error scores (any shape permitted)
81 | 
82 |         `lm` -- `array[*; float]` -- language model scores (any shape
83 |                 permitted, but must match `error`)
84 | 
85 |         `return` -- `array[*; float]` -- combined scores from the model (same
86 |                     shape as `error` & `lm`)
87 |         '''
88 |         raise NotImplementedError
89 | 
90 |     @classmethod
91 |     def optimize(cls, error, lm):
92 |         '''Optimize a reranking model for Hit@1 disambiguation.
93 | 
94 |         `return` -- `lmchallenge.core.reranking.RerankingModel` --
95 |                     an optimized model instance
96 |         '''
97 |         guess = cls.guess(error=error, lm=lm)
98 | 
99 |         def create(argv):
100 |             return cls(**{k: v for k, v in zip(guess.keys(), argv)})
101 | 
102 |         return create(scipy.optimize.fmin(
103 |             lambda argv: -count_correct(create(argv)(error, lm)),
104 |             x0=list(guess.values()),
105 |             disp=False,
106 |         ))
107 | 
108 | 
109 | class InterpolationRerankingModel(RerankingModel):
110 |     '''Implements an interpolation-with-minimum combination model:
111 | 
112 |         score = max(alpha * lm_score, beta) + (1 - alpha) * error_score
113 | 
114 |     Hyperparameters:
115 | 
116 |     `alpha` -- `float` -- how much to trust the language model
117 | 
118 |     `beta` -- `float` -- the minimum contribution from the language model
119 |              (e.g.
for protection against OOV) 120 | ''' 121 | @classmethod 122 | def guess(cls, error, lm): 123 | return dict( 124 | alpha=0.5, 125 | beta=0.5 * float(np.median(lm[lm != -np.inf])), 126 | ) 127 | 128 | def __call__(self, error, lm): 129 | return ( 130 | (1 - self.alpha) * error + 131 | np.maximum(self.alpha * (lm if lm is not None else float('-inf')), 132 | self.beta) 133 | ) 134 | 135 | def __str__(self): 136 | return 'score = {:.3g} * error + max({:.3g} * lm, {:.3g})'.format( 137 | 1 - self.alpha, 138 | self.alpha, 139 | self.beta, 140 | ) 141 | -------------------------------------------------------------------------------- /lmchallenge/core/tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | -------------------------------------------------------------------------------- /lmchallenge/core/tests/test_common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | from .. import common 5 | import itertools as it 6 | import emoji 7 | import io 8 | import pytest 9 | 10 | 11 | def test_unique_by(): 12 | # empty 13 | assert list(common.unique_by([], lambda x: x)) == [] 14 | 15 | # unique strings by length 16 | assert list(common.unique_by( 17 | ["one", "two", "three", "four", "five", "six"], len) 18 | ) == ["one", "three", "four"] 19 | 20 | # check we're lazy (pass an infinite generator) 21 | assert list( 22 | it.islice(common.unique_by(it.count(0), lambda x: x), 0, 5) 23 | ) == [0, 1, 2, 3, 4] 24 | 25 | 26 | def test_not_closing(): 27 | s = io.StringIO() 28 | assert not s.closed 29 | 30 | with common.not_closing(s) as f: 31 | f.write('hello') 32 | assert not s.closed 33 | 34 | # If you don't wrap in "not_closing" the with statement will 35 | # close the resource 36 | with s as f: 37 | f.write(' again') 38 | # You can't do this after it is closed 39 | assert f.getvalue() == 'hello again' 40 | assert f.closed 41 | 42 | 43 | def test_peek(): 44 | x, items = common.peek('abcdef') 45 | assert x == 'a' 46 | assert (''.join(items)) == 'abcdef' 47 | 48 | x, items = common.peek(it.count()) 49 | assert x == 0 50 | assert list(it.islice(items, 4)) == [0, 1, 2, 3] 51 | 52 | x, items = common.peek([]) 53 | assert x is None 54 | assert list(items) == [] 55 | 56 | 57 | def test_autodetect_input(): 58 | assert [dict(text="first line"), dict(text="second line")], \ 59 | list(common.autodetect_input(["first line", "second line"])) 60 | 61 | # already dict => unchanged 62 | data_with_user = [dict(user="a", text="first line"), 63 | dict(user="a", text="second line")] 64 | assert data_with_user, list(common.autodetect_input(data_with_user)) 65 | 66 | with pytest.raises(ValueError): 67 | common.autodetect_input([1, 2]) 68 | 69 | 70 | def test_word_tokenizer(): 71 | def tokenize(x): 72 | return [m.group(0) for m in common.WORD_TOKENIZER.finditer(x)] 73 | 74 | assert [] == tokenize('') 75 | 76 | assert ["one", "two", "@DIGITS", "#What_you're_like@home"] \ 77 | == tokenize("one two \n\t@DIGITS #What_you're_like@home") 78 | 79 | assert ["ready4this", "...", ":-)", "yeah-buddy", ":", "??"] \ 80 | == tokenize("ready4this... 
:-) yeah-buddy: ??") 81 | 82 | assert ["this", "is", "\U0001F4A9", "!"] \ 83 | == tokenize("this is\U0001F4A9!") 84 | 85 | for emo in emoji.UNICODE_EMOJI.keys(): 86 | assert ["pre", emo, "post"] == tokenize("pre {} post".format(emo)) 87 | 88 | 89 | def test_character_tokenizer(): 90 | def tokenize(x): 91 | return [m.group(0) for m in common.CHARACTER_TOKENIZER.finditer(x)] 92 | 93 | assert [] == tokenize('') 94 | assert ['1', '\t', '2', '#', '😀'] == tokenize('1\t2#😀') 95 | assert ['\n', '\n', '\r'] == tokenize('\n\n\r') 96 | 97 | 98 | def test_is_selected(): 99 | assert common.is_selected(dict(target='foo', select=True)) 100 | assert not common.is_selected(dict(target='foo', select=False)) 101 | assert common.is_selected(dict(target='foo')) 102 | 103 | 104 | def test_zip_combine(): 105 | # As documented 106 | for x, y, expected in [ 107 | # general case - non-common data is keyed under the name 108 | (dict(n=1, bar="a"), dict(n=1, bar="b"), 109 | dict(n=1, x=dict(bar="a"), y=dict(bar="b"))), 110 | 111 | # non-common data can be different/missing 112 | (dict(n=2, bar="a"), dict(n=2), 113 | dict(n=2, x=dict(bar="a"), y=dict())), 114 | 115 | # different common data generates an error 116 | (dict(n=3, bar="a"), dict(n=4), 117 | ValueError), 118 | 119 | # mismatched-missing common data generates an error 120 | (dict(n=3, bar="a"), dict(bar="a"), 121 | ValueError), 122 | 123 | # matched-missing common data is OK 124 | (dict(bar="a"), dict(bar="b"), 125 | dict(x=dict(bar="a"), y=dict(bar="b"))), 126 | ]: 127 | try: 128 | assert list(common.zip_combine(['n'], dict(x=[x], y=[y]))) \ 129 | == [expected] 130 | except ValueError as e: 131 | assert expected == ValueError, e 132 | 133 | assert list(common.zip_combine(['n'], dict())) == [] 134 | assert list(common.zip_combine(['n'], dict(x=[], y=[]))) == [] 135 | assert list(common.zip_combine( 136 | ['a'], 137 | dict(x=[dict(a=1), dict(a=2)], y=[dict(a=1)]))) \ 138 | == [dict(a=1, x={}, y={})], \ 139 | 'short iterables are truncated, as per zip()' 140 | 141 | 142 | def test_zip_logs(): 143 | base = dict(user='a', character=1, message=2, token=3, 144 | target='foo', select=True) 145 | transforms = dict(user='b', character=10, message=20, token=30, 146 | target='bar', select=False) 147 | 148 | log_a = [dict(logp=-2.5, **base)] 149 | log_b = [dict(logp=-3.5, **base)] 150 | for key, new_value in transforms.items(): 151 | new_base = base.copy() 152 | new_base[key] = new_value 153 | log_a.append(dict(logp=-2.5, **new_base)) 154 | log_b.append(dict(logp=-3.5, **new_base)) 155 | 156 | result = list(common.zip_logs(a=log_a, b=log_b)) 157 | 158 | assert len(result) == 1 + len(transforms) 159 | 160 | assert result[0] == dict( 161 | a=dict(logp=-2.5), 162 | b=dict(logp=-3.5), 163 | **base) 164 | 165 | 166 | class Foo: 167 | BAR = 123 168 | 169 | 170 | def test_qualified_name(): 171 | assert common.is_qualified_name("abc.def:Ghi") 172 | 173 | assert not common.is_qualified_name("abc.def") 174 | 175 | assert common.is_qualified_name(common.qualified_name(Foo)) 176 | 177 | assert Foo == common.lookup_qualified_name( 178 | common.qualified_name(Foo)) 179 | 180 | assert Foo == common.lookup_qualified_name( 181 | "lmchallenge.core.tests.test_common:Foo") 182 | 183 | assert 123 == common.lookup_qualified_name( 184 | "lmchallenge.core.tests.test_common:Foo.BAR") 185 | -------------------------------------------------------------------------------- /lmchallenge/core/tests/test_errors.py: -------------------------------------------------------------------------------- 1 | # 
Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | from .. import errors 5 | from unittest.mock import MagicMock 6 | import random 7 | import math 8 | import string 9 | 10 | 11 | def test_corrupt(): 12 | config = errors.DEFAULT_CONFIG 13 | p_anykey = config['p_anykey'] 14 | rand = MagicMock() 15 | rand.random = MagicMock( 16 | # ok, error, ok, error, error 17 | side_effect=[1, p_anykey-0.001, p_anykey+0.001, 0, 0] 18 | ) 19 | rand.choice = MagicMock( 20 | side_effect=['q', '_', 'Z'] 21 | ) 22 | assert errors.corrupt(config, "hello", rand) == "hql_Z" 23 | 24 | assert errors.score(config, "hql_Z", "hello") \ 25 | == 3 * math.log(p_anykey) + 2 * math.log(1 - p_anykey) 26 | 27 | 28 | def test_corrupt_fuzz(): 29 | config = errors.DEFAULT_CONFIG 30 | for i in range(1000): 31 | word = ''.join(random.choice(string.ascii_letters) 32 | for j in range(random.randint(1, 10))) 33 | 34 | input_word = errors.corrupt(config, word) 35 | assert len(word) == len(input_word) 36 | 37 | input_score = errors.score(config, input_word, input_word) 38 | assert input_score < 0 39 | 40 | score = errors.score(config, input_word, word) 41 | assert score <= input_score 42 | 43 | 44 | def test_search(): 45 | assert errors.Search([])("csn", 3) == [] 46 | assert errors.Search( 47 | ["can", "cs", "dam", "csn", "csna"])("csn", 3) \ 48 | == ["csn", "can", "dam"] 49 | # if there is a tie, the first result in the list should be retained 50 | assert errors.Search( 51 | ["baa", "bba", "aba", "aab", "aaa", "caa"])("aaa", 3) \ 52 | == ["aaa", "baa", "aba"] 53 | 54 | 55 | def test_search_fuzz(): 56 | # these are all very similar, so there is some interesting overlap 57 | words = [ 58 | "bird", "Bird", "bind", "bard", "Aird", "gird", "Hird", "biro", "byrd", 59 | "birr", "bord", "birt", "Gird", "birs", "birh", "find", "died", "bill", 60 | "hard", "fire" 61 | ] 62 | search = errors.Search(words) 63 | config = errors.DEFAULT_CONFIG 64 | for word in words: 65 | for _ in range(10): 66 | corrupt = errors.corrupt(config, word) 67 | top = search(corrupt, 5) 68 | last_top_score = errors.score(config, corrupt, top[-1]) 69 | 70 | for w in words: 71 | score = errors.score(config, corrupt, w) 72 | if w in top: 73 | assert last_top_score <= score 74 | else: 75 | assert score <= last_top_score 76 | -------------------------------------------------------------------------------- /lmchallenge/core/tests/test_performance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | from .. 
import errors 5 | import time 6 | 7 | 8 | if __name__ == '__main__': 9 | with open('example/words.txt') as f: 10 | words = set(line.rstrip('\r\n') for line in f) 11 | print('Number of words: %d' % len(words)) 12 | 13 | t0 = time.time() 14 | search = errors.Search(words, 20) 15 | t1 = time.time() 16 | print('Build: %.3f s' % (t1 - t0)) 17 | 18 | # ascending word lengths 19 | total = 0 20 | for word in ['I', 'am', 'the', 'bird', 'which', 'should', 'already', 21 | 'increase', 'published', 'investment']: 22 | t0 = time.time() 23 | print('\t%s -> %s' % (word, ' '.join(search(word)))) 24 | t1 = time.time() 25 | total += t1 - t0 26 | print('Search: %.3f s' % (t1 - t0)) 27 | print('Total: %.3f s' % total) 28 | -------------------------------------------------------------------------------- /lmchallenge/data/viewer.css: -------------------------------------------------------------------------------- 1 | .line { 2 | margin-bottom: 1em; 3 | } 4 | 5 | .word { 6 | display: inline-block; 7 | vertical-align: top; 8 | color: #202020; 9 | white-space: pre; 10 | } 11 | .word.spaced { 12 | margin-left: .5em; 13 | } 14 | 15 | .filtered { 16 | color: #0000a0; 17 | } 18 | .entropy-miss { 19 | color: #800080; 20 | } 21 | .entropy-hit { 22 | font-weight: bold; 23 | } 24 | .wc-predicted { 25 | color: #00a030; 26 | } 27 | .wc-unpredicted { 28 | color: #ff0000; 29 | } 30 | .wc-unpredicted i { 31 | color: #ffa000; 32 | font-style: normal; 33 | } 34 | .selected { 35 | background-color: #ddf; 36 | font-weight: bold; 37 | } 38 | .target-row { 39 | background-color: #a0ffc0; 40 | } 41 | .verbatim-row { 42 | background-color: #ffe0c0; 43 | } 44 | .tip { 45 | text-align: left; 46 | } 47 | .results-table { 48 | font-size: small; 49 | } 50 | .results-table-rank { 51 | font-family: mono; 52 | } 53 | -------------------------------------------------------------------------------- /lmchallenge/data/viewer.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | LMC Viewer 4 | 5 | 8 | 11 | 12 | 13 | 14 | 15 |
16 |
17 |
18 |

LMC Viewer

19 |
20 |
21 | 22 |
23 |
24 |
25 | 29 |
30 |
31 | 35 |
36 |
37 |

38 |           
39 |
40 | 41 |
42 |
43 |
44 | 45 |
46 |
47 | 48 | 51 | 54 | 57 | 60 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /lmchallenge/data/viewer.js: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 3 | 4 | // *** A Javascript visualizer for LMC logs *** 5 | 6 | 7 | // Helpers 8 | 9 | function to_fixed(x, n) { 10 | return x === null ? "null" : x.toFixed(n); 11 | } 12 | 13 | function percent(x) { 14 | return to_fixed(100 * x, 1) + " %"; 15 | } 16 | 17 | // Return the log data, grouped by user's messages 18 | // data -- a list of events for tokens 19 | // return -- a list of list of events for each message 20 | // (adds the attribute "skip" which is true if 21 | // the datum is not a consecutive character) 22 | function data_by_line(data) { 23 | var lines = []; 24 | var user = null; 25 | var message = NaN; 26 | var line = null; 27 | var character = null; 28 | for (var i = 0; i < data.length; ++i) { 29 | var d = Object.assign({}, data[i]); 30 | if (d.user === user && d.message === message) { 31 | d.skip = (d.character !== character + 1); 32 | line.push(d); 33 | character = d.character; 34 | } else { 35 | d.skip = false; 36 | line = [d]; 37 | lines.push(line); 38 | user = d.user; 39 | message = d.message; 40 | character = d.character; 41 | } 42 | } 43 | return lines; 44 | } 45 | 46 | function nwp_rank(d) { 47 | return 1 + d.completions[0].indexOf(d.target); 48 | } 49 | 50 | function chars_completed(d) { 51 | for (var i = 0; i < d.completions.length; ++i) { 52 | var rank = 1 + d.completions[i].indexOf(d.target.substr(i)); 53 | if (rank !== 0 && rank <= 2) { 54 | return d.target.length - i; 55 | } 56 | } 57 | return 0; 58 | } 59 | 60 | // Computes summary statistics from the results set, as ratios 61 | function summary_stats(results) { 62 | // Set up aggregators 63 | var stats = {"total": 0, "filter_included": 0}; 64 | if (results[0].verbatim !== undefined) { 65 | stats.wr = { 66 | "inaccurate_incorrect": 0, 67 | "inaccurate_correct": 0, 68 | "accurate_incorrect": 0, 69 | "accurate_correct": 0 70 | }; 71 | } 72 | if (results[0].logp !== undefined) { 73 | stats.entropy = { 74 | "hit": 0, 75 | "entropy": 0 76 | }; 77 | } 78 | if (results[0].completions !== undefined) { 79 | stats.wc = { 80 | "hit1": 0, 81 | "hit3": 0, 82 | "hit10": 0, 83 | "mrr": 0, 84 | "chars_completed": 0 85 | }; 86 | } 87 | 88 | // Compute aggregates 89 | results.forEach(function (d) { 90 | stats.total += 1; 91 | // N.B. null or true mean unfiltered! 
92 | if (d.select !== false) { 93 | stats.filter_included += 1; 94 | if (stats.wr) { 95 | var before = d.verbatim === d.target; 96 | var after = d.results[0][0] === d.target; 97 | stats.wr.inaccurate_incorrect += !before && !after; 98 | stats.wr.inaccurate_correct += !before && after; 99 | stats.wr.accurate_incorrect += before && !after; 100 | stats.wr.accurate_correct += before && after; 101 | } 102 | if (stats.entropy) { 103 | stats.entropy.hit += (d.logp !== null); 104 | if (d.logp !== null) { 105 | stats.entropy.entropy -= d.logp; 106 | } 107 | } 108 | if (stats.wc) { 109 | var r0 = nwp_rank(d); 110 | if (r0 !== 0) { 111 | stats.wc.hit1 += (r0 <= 1); 112 | stats.wc.hit3 += (r0 <= 3); 113 | stats.wc.hit10 += (r0 <= 10); 114 | stats.wc.mrr += 1 / r0; 115 | } 116 | stats.wc.chars_completed += chars_completed(d); 117 | } 118 | } 119 | }); 120 | 121 | // "Sum up" stats 122 | if (stats.wr) { 123 | stats.wr.inaccurate_incorrect /= stats.filter_included; 124 | stats.wr.inaccurate_correct /= stats.filter_included; 125 | stats.wr.accurate_incorrect /= stats.filter_included; 126 | stats.wr.accurate_correct /= stats.filter_included; 127 | } 128 | if (stats.entropy) { 129 | stats.entropy.entropy /= stats.entropy.hit; 130 | stats.entropy.hit /= stats.filter_included; 131 | } 132 | if (stats.wc) { 133 | stats.wc.hit1 /= stats.filter_included; 134 | stats.wc.hit3 /= stats.filter_included; 135 | stats.wc.hit10 /= stats.filter_included; 136 | stats.wc.mrr /= stats.filter_included; 137 | stats.wc.chars_completed /= stats.filter_included; 138 | } 139 | stats.filter_included /= stats.total; 140 | 141 | return stats; 142 | } 143 | 144 | 145 | // Rendering 146 | 147 | function render_summary(stats) { 148 | $(".results").empty(); 149 | if (stats.wr) { 150 | $(".results").append($("").addClass("table").addClass("table-bordered") 151 | .append($("") 152 | .append($("") 154 | .append("")) 155 | .append($("") 156 | .append("") 157 | .append($("") 160 | .append("") 161 | .append($("
").addClass("info").text("Filter " + percent(stats.filter_included))) 153 | .append("IncorrectCorrect
Inaccurate").addClass("warning").text(percent(stats.wr.inaccurate_incorrect))) 158 | .append($("").addClass("success").text(percent(stats.wr.inaccurate_correct)))) 159 | .append($("
Accurate").addClass("danger").text(percent(stats.wr.accurate_incorrect))) 162 | .append($("").addClass("active").text(percent(stats.wr.accurate_correct)))) 163 | ); 164 | } 165 | if (stats.entropy) { 166 | $(".results").append($("").addClass("table").addClass("table-bordered") 167 | .append($("") 168 | .append("") 169 | .append($("") 171 | .append("") 172 | .append($("") 174 | .append("") 175 | .append($("
Filter").addClass("info").text(percent(stats.filter_included)))) 170 | .append($("
Hit (after filter)").addClass("warning").text(percent(stats.entropy.hit)))) 173 | .append($("
Entropy").addClass("success").text(to_fixed(stats.entropy.entropy, 2)))) 176 | ); 177 | } 178 | if (stats.wc) { 179 | $(".results").append($("").addClass("table").addClass("table-bordered") 180 | .append($("") 181 | .append("") 182 | .append($("") 184 | .append("") 185 | .append($("") 187 | .append("") 188 | .append($("") 190 | .append("") 191 | .append($("") 193 | .append("") 194 | .append($("") 196 | .append("") 197 | .append($("
Filter").addClass("info").text(percent(stats.filter_included)))) 183 | .append($("
Hit@1").addClass("success").text(percent(stats.wc.hit1)))) 186 | .append($("
Hit@3").addClass("success").text(percent(stats.wc.hit3)))) 189 | .append($("
Hit@10").addClass("success").text(percent(stats.wc.hit10)))) 192 | .append($("
MRR").addClass("success").text(to_fixed(stats.wc.mrr, 3)))) 195 | .append($("
Chars completed (/word)").addClass("warning").text(to_fixed(stats.wc.chars_completed, 3)))) 198 | ); 199 | } 200 | } 201 | 202 | function set_side_by_side(side_by_side) { 203 | if (side_by_side) { 204 | $(".side-by-side").show(); 205 | } else { 206 | $(".side-by-side").hide(); 207 | } 208 | $(".word").tooltip(side_by_side ? "disable" : "enable"); 209 | } 210 | 211 | function render_wr_detail(datum) { 212 | var detail = $("
"); 213 | 214 | detail.append($("

") 215 | .append($("").text("Target: ")) 216 | .append(datum.target) 217 | .append("
") 218 | .append($("").text("Corrupted: ")) 219 | .append(datum.verbatim)); 220 | 221 | var table = $("") 222 | .addClass("table table-hover table-bordered results-table") 223 | .append(""); 224 | 225 | for (var i = 0; i < datum.results.length; ++i) { 226 | var d = datum.results[i]; 227 | var rank = i + 1; 228 | var entry = $("").append($("
Rank:Result:Score:Error score:LM score:
").text(rank)) 229 | .append($("").text(d[0])) 230 | .append($("").text(to_fixed(d[3], 2))) // score 231 | .append($("").text(to_fixed(d[1], 2))) // error score 232 | .append($("").text(to_fixed(d[2], 2))); // lm score 233 | if (d[0] === datum.target) { 234 | entry.addClass("target-row"); 235 | } else if (d[0] === datum.verbatim) { 236 | entry.addClass("verbatim-row"); 237 | } 238 | table.append(entry); 239 | } 240 | detail.append(table); 241 | 242 | $(".detail").empty().append(detail); 243 | } 244 | 245 | function render_wc_detail(datum) { 246 | var table = $("") 247 | .addClass("table table-hover table-bordered results-table"); 248 | 249 | table.append($("") 250 | .append($("") 264 | .append($("
").attr("scope", "col").addClass("table-dark").text('"' + datum.target + '"')) 251 | .append($("").attr("scope", "col").text('#1')) 252 | .append($("").attr("scope", "col").text('#2')) 253 | .append($("").attr("scope", "col").text('#3')) 254 | .append($("").attr("scope", "col").text('Rank'))); 255 | 256 | for (var i = 0; i < datum.target.length; ++i) { 257 | var prefix = datum.target.substr(0, i); 258 | var suffix = datum.target.substr(i); 259 | var completions = (i < datum.completions.length 260 | ? datum.completions[i] 261 | : []); 262 | var rank = 1 + completions.indexOf(suffix); 263 | var entry = $("
").attr("scope", "row").text(prefix)) 265 | .append($("").text(completions[0])) 266 | .append($("").text(completions[1])) 267 | .append($("").text(completions[2])) 268 | .append($("").addClass("results-table-rank").text(rank == 0 ? "null" : rank)); 269 | if ((1 <= rank && rank <= 2) || (i == 0 && rank == 3)) { 270 | entry.addClass("bg-success"); 271 | } else { 272 | entry.addClass("bg-danger"); 273 | } 274 | table.append(entry); 275 | } 276 | 277 | $(".detail").empty().append(table); 278 | } 279 | 280 | function render_pretty(data) { 281 | var root = d3.select(".pretty"); 282 | 283 | var rows = root.selectAll("p") 284 | .data(data_by_line(data)) 285 | .enter() 286 | .append("p") 287 | .classed("line", true); 288 | 289 | var cells = rows.selectAll("div") 290 | .data(function (d) { return d; }) 291 | .enter() 292 | .append("div") 293 | .classed("word", true) 294 | .classed("spaced", function (d) { return d.skip; }); 295 | 296 | // Basic content - the targets themselves 297 | cells.append("div") 298 | .text(function (d) { return d.target; }); 299 | 300 | // Style unselected cells 301 | // N.B. null or true mean unfiltered! 302 | cells.classed("filtered", function(d) { return d.select === false; }); 303 | 304 | // Only return selected cells 305 | // N.B. null or true mean unfiltered! 306 | return cells.filter(function (d) { return d.select !== false; }); 307 | } 308 | 309 | function render_wr_pretty(data) { 310 | var cells = render_pretty(data); 311 | 312 | // On click - details 313 | cells.style("cursor", "pointer") 314 | .on("click", function (d) { 315 | $(".word.selected").removeClass("selected"); 316 | $(this).addClass("selected"); 317 | render_wr_detail(d); 318 | d3.event.stopPropagation(); 319 | }); 320 | 321 | var changed_cells = cells.filter(function (d) { 322 | var r0 = d.results[0][0]; 323 | return d.target !== d.verbatim || d.verbatim !== r0; 324 | }); 325 | 326 | // Tooltips 327 | changed_cells.attr("data-toggle", "tooltip") 328 | .attr("data-html", true) 329 | .attr("data-placement", "bottom") 330 | .attr("title", function (d) { 331 | return "
Target: " + d.target + 332 | "
Corrupted: " + d.verbatim + 333 | "
Prediction: " + d.results[0][0] + 334 | "
"; 335 | }); 336 | 337 | // Exapandable "side-by-side" content 338 | changed_cells.append("div") 339 | .classed("side-by-side", true) 340 | .text(function (d) { return d.verbatim; }); 341 | changed_cells.append("div") 342 | .classed("side-by-side", true) 343 | .text(function (d) { return d.results[0][0]; }); 344 | set_side_by_side($(".wr-side-by-side")[0].checked); 345 | 346 | // Colours 347 | changed_cells.style("color", function (d) { 348 | var r0 = d.results[0][0]; 349 | if (d.target !== d.verbatim && d.target === r0) { 350 | return "#00a030"; // corrected 351 | } else if (d.target !== d.verbatim && d.target !== r0) { 352 | return "#ff8000"; // uncorrected 353 | } else if (d.target === d.verbatim && d.target !== r0) { 354 | return "#ff0000"; // miscorrected 355 | } else { 356 | console.warn("unexpected case - cell should not be selected"); 357 | return "#000000"; 358 | } 359 | }); 360 | } 361 | 362 | function set_entropy_min(minLogp) { 363 | d3.selectAll(".entropy-hit") 364 | .style("color", function (d) { 365 | var maxLogp = 0; 366 | var x = Math.max(0, Math.min(1, (d.logp - minLogp) / (maxLogp - minLogp))); 367 | return d3.hsl(120 * x, 1, 0.4); 368 | }); 369 | } 370 | 371 | function render_entropy_pretty(data) { 372 | var cells = render_pretty(data); 373 | 374 | // Tooltips 375 | cells.attr("data-toggle", "tooltip") 376 | .attr("data-placement", "bottom") 377 | .attr("title", function (d) { 378 | return d.target + " " + to_fixed(d.logp, 3); 379 | }); 380 | 381 | // Colours 382 | cells.classed("entropy-miss", function (d) { return d.logp === null; }); 383 | cells.classed("entropy-hit", function (d) { return d.logp !== null; }); 384 | set_entropy_min(parseFloat($(".entropy-min").val())); 385 | } 386 | 387 | function render_wc_pretty(data) { 388 | var cells = render_pretty(data); 389 | 390 | cells.style("cursor", "pointer") 391 | .on("click", function (d) { 392 | $(".word.selected").removeClass("selected"); 393 | $(this).addClass("selected"); 394 | render_wc_detail(d); 395 | d3.event.stopPropagation(); 396 | }); 397 | 398 | 399 | cells.classed("wc-predicted", function (d) { 400 | var rank = nwp_rank(d); 401 | return 1 <= rank && rank <= 3; 402 | }); 403 | cells.classed("wc-unpredicted", function (d) { 404 | var rank = nwp_rank(d); 405 | return !(1 <= rank && rank <= 3); 406 | }); 407 | cells.html(function (d) { 408 | var rank = nwp_rank(d); 409 | var completed = chars_completed(d); 410 | if ((1 <= rank && rank <= 3) || completed === 0) { 411 | return d.target; 412 | } else { 413 | var offset = d.target.length - completed; 414 | return d.target.substr(0, offset) + "" + d.target.substr(offset) + ""; 415 | } 416 | }); 417 | } 418 | 419 | 420 | // Toplevel initialization functions 421 | 422 | function setup_wc(results) { 423 | $(".entropy-only").hide(); 424 | $(".wr-only").hide(); 425 | render_summary(summary_stats(results)); 426 | render_wc_pretty(results); 427 | } 428 | 429 | function setup_entropy(results, interval) { 430 | $(".entropy-only").show(); 431 | $(".wr-only").hide(); 432 | $(".entropy-min").val(-interval); 433 | render_summary(summary_stats(results)); 434 | render_entropy_pretty(results); 435 | } 436 | 437 | function setup_wr(results, model) { 438 | $(".wr-only").show(); 439 | $(".entropy-only").hide(); 440 | render_summary(summary_stats(results)); 441 | render_wr_pretty(results); 442 | $(".wr-settings").text(model); 443 | } 444 | 445 | 446 | // Event handler setup 447 | 448 | $(function() { 449 | $(".wr-side-by-side").change(function() { 450 | set_side_by_side(this.checked); 
451 | }); 452 | 453 | // keyup|mouseup as change() doesn't fire as reliably for number input 454 | $(".entropy-min").bind("keyup mouseup", function() { 455 | set_entropy_min(this.value); 456 | }); 457 | 458 | $(".pretty").click(function () { 459 | $(".word.selected").removeClass("selected"); 460 | $(".detail").empty(); 461 | }); 462 | }); 463 | -------------------------------------------------------------------------------- /lmchallenge/diff.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | '''Compare a pair of LM Challenge log files, which were generated by 5 | different models (or settings) but over the same text. 6 | ''' 7 | 8 | import click 9 | from .core import common 10 | from . import pretty, stats 11 | 12 | 13 | class RenderCompletion: 14 | '''Pretty-print a token to show difference in 15 | next-word-prediction/completion. 16 | 17 | +-------------+-------------+--------------+ 18 | | Baseline | Log | Color | 19 | +=============+=============+==============+ 20 | | Predicted | Predicted | Black (Grey) | 21 | | Unpredicted | Unpredicted | Default | 22 | +-------------+-------------+--------------+ 23 | | Unpredicted | Predicted | Bold Green -| 24 | | Predicted | Unpredicted | Bold Red | 25 | +-------------+-------------+--------------+ 26 | ''' 27 | @staticmethod 28 | def ntyped(target, completions): 29 | return next( 30 | (i 31 | for i, cs in enumerate(completions) 32 | if ((common.rank(cs, target[i:]) or float('inf')) 33 | <= (3 if i == 0 else 2))), 34 | len(target)) 35 | 36 | def __call__(self, datum, out): 37 | base = self.ntyped(datum['target'], datum['baseline']['completions']) 38 | ntyped = self.ntyped(datum['target'], datum['log']['completions']) 39 | n = min(base, ntyped) 40 | out.color(out.BLACK, bold=False) 41 | out.write(datum['target'][:n]) 42 | if base < ntyped: 43 | out.color(out.RED, bold=True) 44 | elif ntyped < base: 45 | out.color(out.GREEN, bold=True) 46 | else: 47 | out.color(out.DEFAULT, bold=False) 48 | out.write(datum['target'][n:]) 49 | 50 | 51 | class RenderEntropy: 52 | '''Pretty-print a token to show entropy difference. 53 | 54 | +--------------------+------------+ 55 | | Entropy difference | Color | 56 | +====================+============+ 57 | | Skip | Blue | 58 | +--------------------+------------+ 59 | | Unknown | Magenta | 60 | | Unknown -> Known | White | 61 | | Known -> Unknown | Bold White | 62 | +--------------------+------------+ 63 | | +i/2 - ... | Bold Green | 64 | | +i/6 - +i/2 | Green | 65 | | -i/6 - +i/6 | Yellow | 66 | | -i/2 - -i/6 | Red | 67 | | ... 
- -i/2 | Bold Red | 68 | +--------------------+------------+ 69 | ''' 70 | def __init__(self, interval): 71 | self._interval = interval 72 | 73 | def __call__(self, datum, out): 74 | base = datum['baseline']['logp'] 75 | logp = datum['log']['logp'] 76 | if (base, logp) == (None, None): 77 | out.color(out.MAGENTA, bold=False) 78 | elif logp is None or base is None: 79 | out.color(out.WHITE, bold=(logp is None)) 80 | else: 81 | diff = logp - base 82 | x = self._interval / 6 83 | if 3 * x < diff: 84 | out.color(out.GREEN, True) 85 | elif x < diff: 86 | out.color(out.GREEN, False) 87 | elif -x < diff: 88 | out.color(out.YELLOW, False) 89 | elif -3 * x < diff: 90 | out.color(out.RED, False) 91 | else: 92 | out.color(out.RED, True) 93 | out.write(datum['target']) 94 | 95 | 96 | class RenderReranking: 97 | '''Pretty-print a token to show correction difference. 98 | 99 | +--------------+--------------+-------------+ 100 | | Baseline | Log | Color | 101 | +==============+==============+=============+ 102 | | Skip | Blue | 103 | +--------------+--------------+-------------+ 104 | | Unchanged | Black | 105 | | Corrected | Black | 106 | | Uncorrected | Bold Black | 107 | | Miscorrected | Bold Black | 108 | +--------------+--------------+-------------+ 109 | | Miscorrected | Unchanged | Bold Green | 110 | | Unchanged | Miscorrected | Bold Red | 111 | +--------------+--------------+-------------+ 112 | | Uncorrected | Corrected | Green | 113 | | Corrected | Uncorrected | Red | 114 | +--------------+--------------+-------------+ 115 | ''' 116 | def __init__(self, base_model, model): 117 | self._base_model = base_model 118 | self._model = model 119 | 120 | def __call__(self, datum, out): 121 | target = datum['target'] 122 | base_results = datum['baseline']['results'] 123 | pre = pretty.RenderReranking.is_correct( 124 | target, base_results, lambda e, lm: e) 125 | base = pretty.RenderReranking.is_correct( 126 | target, base_results, self._base_model) 127 | post = pretty.RenderReranking.is_correct( 128 | target, datum['log']['results'], self._model) 129 | 130 | # Changed base->post 131 | if (pre, base, post) == (True, True, False): 132 | # unchanged -> miscorrected 133 | out.color(out.RED, bold=True) 134 | elif (pre, base, post) == (False, True, False): 135 | # corrected -> uncorrected 136 | out.color(out.RED, bold=False) 137 | elif (pre, base, post) == (True, False, True): 138 | # miscorrected -> unchanged 139 | out.color(out.GREEN, bold=True) 140 | elif (pre, base, post) == (False, False, True): 141 | # uncorrected -> corrected 142 | out.color(out.GREEN, bold=False) 143 | 144 | # Unchanged base->post 145 | elif (base, post) == (False, False): 146 | # miscorrected/uncorrected 147 | out.color(out.BLACK, bold=True) 148 | elif (base, post) == (True, True): 149 | # corrected/unchanged 150 | out.color(out.BLACK, bold=False) 151 | 152 | else: 153 | assert False, '(should be) unreachable code' 154 | 155 | out.write(datum['target']) 156 | 157 | 158 | # Script 159 | 160 | class ChallengeChoice(common.ChallengeChoice): 161 | '''Select a pretty printing program. 
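    The default `auto` choice (inherited from `common.ChallengeChoice`) infers the challenge from the fields present in the log data: `completions`, `logp` or `results`.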
162 | ''' 163 | @staticmethod 164 | def completion(baseline, log, **args): 165 | return pretty.render_ansi( 166 | common.zip_logs(baseline=baseline, log=log), 167 | RenderCompletion()) 168 | 169 | @staticmethod 170 | def entropy(baseline, log, entropy_interval, **args): 171 | return pretty.render_ansi( 172 | common.zip_logs(baseline=baseline, log=log), 173 | RenderEntropy(interval=entropy_interval)) 174 | 175 | @staticmethod 176 | def reranking(baseline, log, **args): 177 | baseline = list(baseline) 178 | log = list(log) 179 | return pretty.render_ansi( 180 | common.zip_logs(baseline=baseline, log=log), 181 | RenderReranking( 182 | base_model=stats.Reranking.build_model(baseline), 183 | model=stats.Reranking.build_model(log))) 184 | 185 | 186 | @click.command() 187 | @click.argument('baseline', type=click.Path(dir_okay=False)) 188 | @click.argument('log', type=click.Path(dir_okay=False)) 189 | @click.option('-v', '--verbose', default=0, count=True, 190 | help='How much human-readable detail to print to STDERR') 191 | @click.option('-c', '--challenge', type=ChallengeChoice(), 192 | default='auto', 193 | help='Select which challenge to view (in the case where there' 194 | ' are multiple challenges in a single log)') 195 | @click.option('-i', '--entropy_interval', default=10.0, 196 | help='Interval to show entropy differences over (should be' 197 | ' positive)') 198 | def cli(baseline, log, verbose, challenge, entropy_interval): 199 | '''Pretty-print a comparison of two result logs from LM Challenge 200 | (using ANSI color codes). 201 | ''' 202 | common.verbosity(verbose) 203 | 204 | for line in challenge( 205 | common.load_jsonlines(baseline), 206 | common.load_jsonlines(log), 207 | entropy_interval=entropy_interval): 208 | print(line) 209 | 210 | 211 | __doc__ += common.shell_docstring(cli, 'lmc diff') 212 | if __name__ == '__main__': 213 | cli() 214 | -------------------------------------------------------------------------------- /lmchallenge/grep.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | '''Search through logs for specific token instances (for further processing). 5 | ''' 6 | 7 | import click 8 | import regex 9 | import emoji 10 | import itertools as it 11 | from .core import common 12 | 13 | 14 | def _target_contains(pattern): 15 | xp = regex.compile(pattern) 16 | return lambda datum: xp.search(datum['target']) is not None 17 | 18 | 19 | def _target_does_not_contain(pattern): 20 | xp = regex.compile(pattern) 21 | return lambda datum: xp.search(datum['target']) is None 22 | 23 | 24 | def _key_matches(key, pattern): 25 | xp = regex.compile(pattern) 26 | return lambda datum: xp.fullmatch(str(datum[key])) is not None 27 | 28 | 29 | def parse_pattern(pattern): 30 | '''Parses token filter patterns (regexes, with a few predefined values). 31 | 32 | `pattern` -- `string` -- see `lmchallenge.grep` for the grammar 33 | (essentially a regex). 
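    For example (illustrative): `parse_pattern('^foo$')` returns a predicate that is true only when the target is exactly "foo", and `parse_pattern('$token=0')` selects the first token of each message.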
34 | 35 | `return` -- `predicate(dict)` -- a predicate taking a datum from a log file 36 | ''' 37 | if pattern.startswith('$'): 38 | # predefined target pattern 39 | if pattern == '$nospace': 40 | return _target_does_not_contain(r'\s') 41 | elif pattern == '$alpha': 42 | return _target_contains(r'\p{L}') 43 | elif pattern == '$alphaonly': 44 | return _target_does_not_contain(r'[^\p{L}]') 45 | elif pattern == '$emoji': 46 | return _target_contains(emoji.get_emoji_regexp().pattern) 47 | elif pattern == '$alphaemoji': 48 | return _target_contains(r'(\p{L})|' + 49 | emoji.get_emoji_regexp().pattern) 50 | 51 | # other patterns 52 | m = regex.fullmatch(r'\$(user|message|token|character)=(.+)', 53 | pattern) 54 | if not m: 55 | raise ValueError( 56 | 'Unrecognized special pattern "{}"'.format(pattern)) 57 | return _key_matches(m.group(1), m.group(2)) 58 | 59 | elif pattern.startswith('~'): 60 | # negated target pattern 61 | return _target_does_not_contain(pattern[1:]) 62 | 63 | else: 64 | # target pattern 65 | return _target_contains(pattern) 66 | 67 | 68 | def parse_patterns_all(*patterns): 69 | '''Parse each pattern according to `parse_pattern`, then combine them into 70 | a single predicate, which requires all of them to match. 71 | 72 | See `lmchallenge.grep.parse_pattern`. 73 | ''' 74 | return common.all_predicates(*(parse_pattern(p) for p in patterns)) 75 | 76 | 77 | def select(data, predicate): 78 | '''Return a copy of "data" with the "select" key added to each datum, based 79 | on the outcome of the predicate. 80 | 81 | `data` -- `iterable(dict)` -- LM Challenge log. If a datum already includes 82 | `"select"`, it is combined, `datum["select"] and 83 | predicate(datum)`). 84 | 85 | `predicate` -- `predicate(dict)` -- accepts a datum from `data` 86 | 87 | `return` -- `iterable(dict)` -- LM Challenge log, with the `"select"` key 88 | ''' 89 | for datum in data: 90 | datum = datum.copy() 91 | datum['select'] = common.is_selected(datum) and predicate(datum) 92 | yield datum 93 | 94 | 95 | class Keep(common.ParamChoice): 96 | '''After passing a log through a 'selector', the methods of this class can 97 | remove parts of the log that are unnecessary, for example: 98 | 99 | `"all"` -- keeps everything 100 | 101 | `"message"` -- keeps any message containing a selected token 102 | 103 | `"token"` -- keeps only selected tokens themselves 104 | ''' 105 | name = 'keep' 106 | choices = ['all', 'message', 'token'] 107 | 108 | @staticmethod 109 | def all(log): 110 | '''Keep every token in the log.''' 111 | return log 112 | 113 | @staticmethod 114 | def message(log): 115 | '''Keeps every message in the log which contains a selected token.''' 116 | for _, tokens in it.groupby(log, lambda x: (x['user'], x['message'])): 117 | tokens = list(tokens) # need to pass through twice 118 | if any(map(common.is_selected, tokens)): 119 | yield from tokens 120 | 121 | @staticmethod 122 | def token(log): 123 | '''Keep only selected tokens.''' 124 | return filter(common.is_selected, log) 125 | 126 | 127 | def grep(pattern, data, keep='all', and_patterns=[]): 128 | '''Search for `pattern` in the log `data`, returning a tagged log 129 | which selects part of the original data. 130 | 131 | `pattern` -- `string` -- see `lmchallenge.grep` for the grammar 132 | (essentially a regex). 133 | 134 | `data` -- `iterable(dict)` -- LM Challenge log. If a selection has already 135 | been applied, sub-selects the log satisfying both selections. 
136 | 137 | `keep` -- `string` -- either `"all"`, `"message"`, or `"token"`, what 138 | elements of the log to return (e.g. `"all"` returns the whole 139 | log, with a tagged selection, whereas `"token"` only returns 140 | tokens that match the selection). 141 | 142 | `and_patterns` -- `list(string)` -- additional patterns to apply, all of 143 | which must match. 144 | 145 | `return` -- `iterable(dict)` -- LM Challenge log. 146 | ''' 147 | predicate = parse_patterns_all(*([pattern] + and_patterns)) 148 | return getattr(Keep, keep)(select(data, predicate)) 149 | 150 | 151 | # Script 152 | 153 | @click.command() 154 | @click.argument('pattern') 155 | @click.argument('log', nargs=-1, type=click.Path(exists=True, dir_okay=False)) 156 | @click.option('-v', '--verbose', default=0, count=True, 157 | help='How much human-readable detail to print to STDERR.') 158 | @click.option('-k', '--keep', type=Keep(), default='all', 159 | help='After applying the pattern to select tokens, what should' 160 | ' be kept in the log.') 161 | @click.option('-a', '--and', '--and-pattern', multiple=True, 162 | help='Specify additional patterns that must all match.') 163 | def cli(pattern, log, verbose, keep, and_pattern): 164 | '''Search through logs for specific token instances. 165 | 166 | Pattern grammar: 167 | 168 | # 1. Target regex: 169 | 170 | REGEX -- select any token where the target contains a match of the regular 171 | expression (see the Python regex syntax guide for supported regex) 172 | (N.B. REGEX cannot start with "$" or "~") 173 | 174 | Note that the REGEX matches anywhere in the target, so to match only a 175 | whole string, use the start and end markers, i.e. 176 | 177 | "foo" -- matches "foo", "foobar", "only-foo", etc. 178 | 179 | "^foo$" -- matches "foo" but not "foobar", "only-foo", etc. 180 | 181 | 182 | # 2. Negated target regex: 183 | 184 | "~REGEX" -- select any token where the target does not contain REGEX 185 | 186 | e.g. "~\s" matches tokens that do not contain whitespace 187 | 188 | 189 | # 3. Predefined target patterns: 190 | 191 | "$nospace" -- tokens that don't contain whitespace 192 | 193 | "$alpha" -- tokens that contain alphabetic characters 194 | 195 | "$alphaonly" -- tokens that ONLY contain alphabetic characters 196 | 197 | "$emoji" -- tokens that contain emoji 198 | 199 | "$alphaemoji" -- tokens that contain alphabetic characters or emoji 200 | 201 | 202 | 4. User, message, token, character patterns: 203 | 204 | "$user=REGEX" -- select any user matching REGEX completely 205 | 206 | "$message=REGEX" -- select any message number matching REGEX completely 207 | 208 | "$token=REGEX" -- select any token number matching REGEX completely 209 | 210 | "$character=REGEX" -- select any character number matching REGEX completely 211 | 212 | e.g. 
"$user=MrKim" matches the user "MrKim" but not "MrKim2" 213 | "$message=[0123]" matches the first 4 messages for each user 214 | "$token=0" matches the first token in each message 215 | 216 | ''' 217 | common.verbosity(verbose) 218 | 219 | data = common.load_jsonlines(common.single_log(log)) 220 | predicate = parse_patterns_all(*([pattern] + list(and_pattern))) 221 | common.dump_jsonlines(keep(select(data, predicate))) 222 | 223 | 224 | __doc__ += common.shell_docstring(cli, 'lmc grep') 225 | if __name__ == '__main__': 226 | cli() 227 | -------------------------------------------------------------------------------- /lmchallenge/log.schema: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "http://json-schema.org/schema#", 3 | 4 | "title": "LM Challenge log format", 5 | "description": "Validate any log datum against the LM Challenge log file specification. A log file should consist of JSONlines encoded items that match this schema.", 6 | 7 | "type": "object", 8 | 9 | "properties": { 10 | "user": {"type": ["string", "null"]}, 11 | "character": {"type": "number"}, 12 | "message": {"type": "number"}, 13 | "token": {"type": "number"}, 14 | "target": {"type": "string"}, 15 | "select": {"type": "boolean"}, 16 | 17 | "logp": { 18 | "title": "(Character|Word) Entropy result", 19 | "type": ["number", "null"] 20 | }, 21 | 22 | "completions": { 23 | "title": "Word Completion result", 24 | "type": "array", 25 | "items": { 26 | "type": "array", 27 | "uniqueItems": true, 28 | "items": { 29 | "type": "string" 30 | } 31 | } 32 | }, 33 | 34 | "results": { 35 | "title": "Word Reranking result", 36 | "type": "array", 37 | "items": { 38 | "type": "array", 39 | "minItems": 3, 40 | "additionalItems": false, 41 | "items": [ 42 | { 43 | "title": "Candidate", 44 | "type": "string" 45 | }, 46 | { 47 | "title": "Error score", 48 | "type": "number", 49 | "maximum": 0 50 | }, 51 | { 52 | "title": "LM score", 53 | "type": ["number", "null"] 54 | }, 55 | { 56 | "title": "Combined score (optional)", 57 | "type": "number" 58 | } 59 | ] 60 | } 61 | }, 62 | "verbatim": {"type": "string"} 63 | }, 64 | 65 | "required": [ 66 | "user", "character", "message", "token", "target" 67 | ], 68 | "dependencies": { 69 | "results": ["verbatim"], 70 | "verbatim": ["results"] 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /lmchallenge/pretty.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | '''Pretty-print the model performance from an LM Challenge log file, 5 | in ANSI colour or HTML format. 6 | ''' 7 | 8 | import click 9 | import io 10 | import os 11 | import tempfile 12 | import logging 13 | import urllib.request 14 | import hashlib 15 | import json 16 | import string 17 | import itertools as it 18 | from . import stats 19 | from .core import common 20 | 21 | 22 | # Rendering utilities 23 | 24 | class Renderer: 25 | '''Base class for rendering selected/unselected tokens. 26 | ''' 27 | def __call__(self, datum, out): 28 | '''Called to render a token (which has not been "deselected") to the 29 | AnsiRender instance "out". 
30 | 31 | `datum` -- `dict` -- log datum to render 32 | 33 | `out` -- `lmchallenge.core.common.AnsiRender` -- output of rendering 34 | ''' 35 | raise NotImplementedError 36 | 37 | def html_setup(self, data, float_format): 38 | '''The setup command for a standalone HTML page with the given data. 39 | ''' 40 | raise NotImplementedError 41 | 42 | 43 | class RenderCompletion(Renderer): 44 | '''Pretty-print a token to show next-word-prediction/completion. 45 | 46 | If the token is next-word predicted, the entire token is green (and 47 | bold if it is top-prediction). Otherwise, characters that must be typed 48 | before the token is predicted @2 are coloured red, and completed 49 | characters are yellow: 50 | 51 | +---------------------------+--------------+ 52 | | Case | Color | 53 | +===========================+==============+ 54 | | Next word prediction @1 | Bold Green | 55 | | Next word prediction @3 | Green | 56 | | Unpredicted characters @2 | Red | 57 | | Predicted characters @2 | Black (Grey) | 58 | +---------------------------+--------------+ 59 | ''' 60 | def __call__(self, datum, out): 61 | ranks = [common.rank(cs, datum['target'][i:]) or float('inf') 62 | for i, cs in enumerate(datum['completions'])] 63 | 64 | # First check for prediction 65 | if len(ranks) == 0: 66 | # Should not happen 67 | out.default() 68 | out.write(datum['target']) 69 | elif ranks[0] <= 1: 70 | out.color(out.GREEN, True) 71 | out.write(datum['target']) 72 | elif ranks[0] <= 3: 73 | out.color(out.GREEN, False) 74 | out.write(datum['target']) 75 | else: 76 | # Then completion 77 | typed = next((i for i, r in enumerate(ranks) if r <= 2), 78 | len(datum['target'])) 79 | out.color(out.RED, False) 80 | out.write(datum['target'][:typed]) 81 | out.color(out.BLACK, False) 82 | out.write(datum['target'][typed:]) 83 | 84 | def html_setup(self, data, float_format): 85 | return string.Template( 86 | 'setup_wc(${DATA});' 87 | ).substitute(DATA=_json_dumps_min(data, float_format=float_format)) 88 | 89 | 90 | class RenderEntropy(Renderer): 91 | '''Pretty-print a token to show entropy contribution. 92 | 93 | +-------------+------------+ 94 | | Entropy | Color | 95 | +=============+============+ 96 | | Skip | Blue | 97 | +-------------+------------+ 98 | | Unknown | Magenta | 99 | +-------------+------------+ 100 | | 0 - i/5 | Bold Green | 101 | | i/5 - 2i/5 | Green | 102 | | 2i/5 - 3i/5 | Yellow | 103 | | 3i/5 - 4i/5 | Red | 104 | | 4i/5 - ... | Bold Red | 105 | +-------------+------------+ 106 | ''' 107 | def __init__(self, interval): 108 | self._interval = interval 109 | 110 | def __call__(self, datum, out): 111 | logp = datum.get('logp') 112 | x = self._interval / 5 113 | if logp is None: 114 | out.color(out.MAGENTA, False) 115 | elif -logp < x: 116 | out.color(out.GREEN, True) 117 | elif -logp < 2 * x: 118 | out.color(out.GREEN, False) 119 | elif -logp < 3 * x: 120 | out.color(out.YELLOW, False) 121 | elif -logp < 4 * x: 122 | out.color(out.RED, False) 123 | else: 124 | out.color(out.RED, True) 125 | out.write(datum['target']) 126 | 127 | def html_setup(self, data, float_format): 128 | return string.Template( 129 | 'setup_entropy(${DATA}, ${INTERVAL});' 130 | ).substitute( 131 | DATA=_json_dumps_min(data, float_format=float_format), 132 | INTERVAL=self._interval, 133 | ) 134 | 135 | 136 | class RenderReranking(Renderer): 137 | '''Pretty-print a token to show correction. 
138 | 139 | +-----------+-----------+---------+ 140 | | Before | After | Color | 141 | +===========+===========+=========+ 142 | | Skip | Blue | 143 | +-----------+-----------+---------+ 144 | | Incorrect | Incorrect | Yellow | 145 | +-----------+-----------+---------+ 146 | | Incorrect | Correct | Green | 147 | +-----------+-----------+---------+ 148 | | Correct | Incorrect | Red | 149 | +-----------+-----------+---------+ 150 | | Correct | Correct | White | 151 | +-----------+-----------+---------+ 152 | ''' 153 | def __init__(self, model): 154 | self._model = model 155 | 156 | @staticmethod 157 | def is_correct(target, results, model): 158 | scores = ((t, model(e, lm)) 159 | for t, e, lm in results) 160 | target_score = next(score for t, score in scores if t == target) 161 | return all(score < target_score 162 | for t, score in scores if t != target) 163 | 164 | def __call__(self, datum, out): 165 | target = datum['target'] 166 | results = datum['results'] 167 | pre = self.is_correct(target, results, lambda e, lm: e) 168 | post = self.is_correct(target, results, self._model) 169 | if (pre, post) == (False, False): 170 | # unchanged incorrect 171 | out.color(out.YELLOW, False) 172 | elif (pre, post) == (False, True): 173 | # corrected 174 | out.color(out.GREEN, False) 175 | elif (pre, post) == (True, False): 176 | # miscorrected 177 | out.color(out.RED, False) 178 | else: 179 | # unchanged correct 180 | out.default() 181 | out.write(datum['target']) 182 | 183 | def html_setup(self, data, float_format): 184 | data = list(_log_combined_score(data, self._model)) 185 | return string.Template( 186 | 'setup_wr(${DATA}, "${MODEL}");' 187 | ).substitute( 188 | DATA=_json_dumps_min(data, float_format=float_format), 189 | MODEL=str(self._model), 190 | ) 191 | 192 | 193 | # HTML rendering 194 | 195 | def _read_data_file(name): 196 | '''Read a file from the bundled 'lmchallenge/data' directory, and return 197 | the contents as a string. 198 | ''' 199 | with open(os.path.join(os.path.dirname(__file__), 'data', name)) as f: 200 | return f.read() 201 | 202 | 203 | def _download_cache_cdn(url, sha_384): 204 | '''Download a file from 'url', which should have the SHA384 matching 205 | 'sha_384' (which should be a hex string). 206 | ''' 207 | root = os.path.join(tempfile.gettempdir(), 'lmc_pretty') 208 | if not os.path.isdir(root): 209 | os.makedirs(root) 210 | 211 | target = os.path.join(root, sha_384) 212 | if not os.path.isfile(target): 213 | logging.info('Downloading %s -> %s', url, target) 214 | urllib.request.urlretrieve(url, target) 215 | with open(target, 'rb') as f: 216 | h = hashlib.sha384() 217 | h.update(f.read()) 218 | if h.hexdigest() != sha_384: 219 | logging.error('Checksum mismatch between %s, %s', 220 | url, target) 221 | raise IOError('Checksum mismatch between %s, %s:' 222 | ' expected %s actual %s', 223 | url, target, h.hexdigest(), sha_384) 224 | 225 | with open(target) as f: 226 | return f.read() 227 | 228 | 229 | def _get_viewer_files(): 230 | '''Returns a dictionary of {KEY: DATA} for all the supplementary js & css 231 | data files needed to render the standalone html page. 
232 | ''' 233 | # Note: to get the checksums: 234 | # wget https://URL -O - | sha384sum 235 | return dict( 236 | LMC_VIEWER_HTML=_read_data_file('viewer.html'), 237 | LMC_VIEWER_CSS=_read_data_file('viewer.css'), 238 | LMC_VIEWER_JS=_read_data_file('viewer.js'), 239 | BOOTSTRAP_CSS=_download_cache_cdn( 240 | 'https://' 241 | 'maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css', 242 | 'd6af264c93804b1f23d40bbe6b95835673e2da59057f0c04' 243 | '01af210c3763665a4b7a0c618d5304d5f82358f1a6933b3b' 244 | ), 245 | BOOTSTRAP_JS=_download_cache_cdn( 246 | 'https://' 247 | 'maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js', 248 | 'd2649b24310789a95f9ae04140fe80e10ae9aeae4e55f5b7' 249 | 'ecf451de3e442eac6cb35c95a8eb677a99c754ff5a27bc52' 250 | ), 251 | JQUERY_JS=_download_cache_cdn( 252 | 'https://code.jquery.com/jquery-2.2.4.min.js', 253 | 'ad8fe3bfc98c86a0da6d74a8f940a082a2ad76605f777a82' 254 | 'dbf2afc930cd43a3dc5095dac4ad6d31ea6841d6b8839bc1' 255 | ), 256 | D3_JS=_download_cache_cdn( 257 | 'https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.17/d3.min.js', 258 | '37c10fd189a5d2337b7b40dc5e567aaedfa2a8a53d0a4e9f' 259 | 'd5943e8f6a6ec5ab6706ae24f44f10eafa81718df82cd6e7' 260 | ), 261 | ) 262 | 263 | 264 | def _json_dumps_min(data, float_format=''): 265 | '''Tiny JSON serializer that supports strings, ints, floats, tuples, lists 266 | and dictionaries. 267 | 268 | Compared to json.dumps, allows a format to specified for floating point 269 | values. 270 | ''' 271 | out = io.StringIO() 272 | 273 | def visit(node): 274 | if node is None: 275 | out.write('null') 276 | elif isinstance(node, str): 277 | out.write(json.dumps(node)) 278 | elif isinstance(node, bool): 279 | out.write('true' if node else 'false') 280 | elif isinstance(node, int): 281 | out.write(str(node)) 282 | elif isinstance(node, float): 283 | out.write(format(node, float_format)) 284 | elif isinstance(node, (tuple, list)): 285 | out.write('[') 286 | for i, x in enumerate(node): 287 | if i != 0: 288 | out.write(',') 289 | visit(x) 290 | out.write(']') 291 | elif isinstance(node, dict): 292 | out.write('{') 293 | for i, k in enumerate(node): 294 | if i != 0: 295 | out.write(',') 296 | visit(k) 297 | out.write(':') 298 | visit(node[k]) 299 | out.write('}') 300 | else: 301 | raise ValueError( 302 | 'Unexpected value for JSON conversion: {}'.format(node)) 303 | visit(data) 304 | return out.getvalue() 305 | 306 | 307 | def _log_combined_score(data, model): 308 | '''Add the combined score to the word reranking results in the log, and 309 | sort descending score. 310 | ''' 311 | for datum in data: 312 | datum = datum.copy() 313 | datum['results'] = list(sorted( 314 | ((candidate, error_score, lm_score, model(error_score, lm_score)) 315 | for candidate, error_score, lm_score in datum['results']), 316 | key=lambda x: -x[-1] 317 | )) 318 | yield datum 319 | 320 | 321 | # Toplevel rendering functions 322 | 323 | def render_ansi(data, renderer): 324 | '''Render an LMC log to a colourized line-based output. 
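    A sketch of direct use -- most callers go through `ansi` below, which
    chooses the renderer automatically; `log` here is a placeholder for an
    already-loaded LM Challenge log:

        lines = render_ansi(log, RenderEntropy(interval=10.0))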
325 | 326 | `data` -- `generator(dict)` -- LM Challenge log 327 | 328 | `renderer` -- `lmchallenge.pretty.Renderer` -- to render the log 329 | 330 | `return` -- `generator(string)` -- generates ANSI-formatted lines 331 | ''' 332 | for _, msg_data in it.groupby( 333 | data, lambda d: (d.get('user'), d.get('message'))): 334 | with io.StringIO() as f: 335 | out = common.AnsiRender(f) 336 | char_n = 0 337 | for datum in msg_data: 338 | if char_n < datum['character']: 339 | out.write(' ') 340 | char_n = datum['character'] + len(datum['target']) 341 | if common.is_selected(datum): 342 | renderer(datum, out) 343 | else: 344 | out.color(out.BLUE, bold=False) 345 | out.write(datum['target']) 346 | out.default() 347 | yield f.getvalue() 348 | 349 | 350 | def render_html(data, renderer, float_format): 351 | '''Render an LMC log to a standalone (and mildly interactive) HTML file. 352 | 353 | `data` -- `generator(dict)` -- LM Challenge log 354 | 355 | `renderer` -- `lmchallenge.pretty.Renderer` -- to render the log 356 | 357 | `float_format` -- `string` -- format string for floating point numbers in 358 | the resulting HTML file's compact JSON log 359 | 360 | `return` -- `string` -- standalone HTML 361 | ''' 362 | # Render the HTML file, with all dependencies inlined 363 | files = _get_viewer_files() 364 | return string.Template(files['LMC_VIEWER_HTML']).substitute( 365 | LMC_SETUP=renderer.html_setup(data, float_format), **files 366 | ) 367 | 368 | 369 | # Script 370 | 371 | class ChallengeChoice(common.ChallengeChoice): 372 | '''Select a pretty printing program. 373 | ''' 374 | @staticmethod 375 | def completion(data, **args): 376 | return RenderCompletion() 377 | 378 | @staticmethod 379 | def entropy(data, entropy_interval, **args): 380 | return RenderEntropy(interval=entropy_interval) 381 | 382 | @staticmethod 383 | def reranking(data, **args): 384 | return RenderReranking(model=stats.Reranking.build_model(data)) 385 | 386 | 387 | class OutputChoice(common.ParamChoice): 388 | '''Select an output format. 389 | ''' 390 | name = 'output' 391 | choices = ['ansi', 'html'] 392 | 393 | @staticmethod 394 | def ansi(data, renderer, **args): 395 | for line in render_ansi(data, renderer): 396 | print(line) 397 | 398 | @staticmethod 399 | def html(data, renderer, float_format, **args): 400 | print(render_html(data, renderer, float_format=float_format)) 401 | 402 | 403 | def ansi(data, entropy_interval=10.0): 404 | '''Render and return an ANSI-formatted pretty-printing of a LM Challenge log. 
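    A minimal sketch, where `log` stands for any LM Challenge log already
    loaded as an iterable of dicts (e.g. via `common.load_jsonlines`):

        for line in ansi(log):
            print(line)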
405 | 406 | `data` -- `iterable(dict)` -- LM Challenge log 407 | 408 | `return` -- `iterable(string)` -- ANSI-coloured rendering of the log 409 | ''' 410 | data = list(data) 411 | return render_ansi( 412 | data, 413 | ChallengeChoice.auto( 414 | data, entropy_interval=entropy_interval)) 415 | 416 | 417 | @click.command() 418 | @click.argument('log', nargs=-1, type=click.Path(dir_okay=False)) 419 | @click.option('-v', '--verbose', default=0, count=True, 420 | help='How much human-readable detail to print to STDERR') 421 | @click.option('-c', '--challenge', type=ChallengeChoice(), 422 | default='auto', 423 | help='Select which challenge to view (in the case where there' 424 | ' are multiple challenges in a single log)') 425 | @click.option('-o', '--output', type=OutputChoice(), 426 | default='ansi', 427 | help='Select whether to use a simple ANSI format, or an' 428 | ' all-in-one html page to show results') 429 | @click.option('-i', '--entropy_interval', default=10.0, 430 | help='Interval to show entropy differences over (should be' 431 | ' positive)') 432 | @click.option('-m', '--float-format', default='.4g', 433 | help='The format of floats in the JSON file (use compact' 434 | ' representations to save file size)') 435 | def cli(log, verbose, challenge, output, entropy_interval, float_format): 436 | '''Pretty-print results from LM Challenge (using ANSI color codes). 437 | ''' 438 | common.verbosity(verbose) 439 | 440 | # list() because of multiple traverse (in the case of reranking) 441 | data = list(common.load_jsonlines(common.single_log(log))) 442 | 443 | renderer = challenge(data, entropy_interval=entropy_interval) 444 | 445 | output(data, renderer, float_format=float_format) 446 | 447 | 448 | __doc__ += common.shell_docstring(cli, 'lmc pretty') 449 | if __name__ == '__main__': 450 | cli() 451 | -------------------------------------------------------------------------------- /lmchallenge/run.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | '''Run evaluation utilities over a model, to generate LM Challenge logs. 5 | ''' 6 | 7 | import click 8 | import json 9 | import sys 10 | import logging 11 | import random 12 | import itertools as it 13 | import functools as ft 14 | from .core import common, errors, model 15 | 16 | 17 | # Helpers 18 | 19 | def find_tokens(text, tokenizer): 20 | '''Return a generator of `{"token", "character", "target"}` dictionaries by 21 | tokenizing some text. 22 | 23 | `text` -- `string` -- the text to tokenize 24 | 25 | `tokenizer` -- `regex.Regex` -- the tokenizer to use on the text 26 | (supporting `finditer`) 27 | 28 | `return` -- `generator(dict)` -- a generator of tokens, which have keys: 29 | 30 | - `token` -- `int` -- index of the token in the message 31 | - `character` -- `int` -- index of the first character of the token in 32 | the message 33 | - `target` -- `string` -- the text content of the token itself 34 | ''' 35 | for i, m in enumerate(tokenizer.finditer(text)): 36 | yield dict(token=i, character=m.span()[0], target=m.group()) 37 | 38 | 39 | def get_completions(model, context, target): 40 | '''A generator of completions from a model, with successive "typed" prefixes. 
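    For example (an illustrative walk-through, not extra behaviour): with a
    `target` of "cat", the model is queried with `context`, then with
    `context + "c"`, then with `context + "ca"`, yielding one prediction
    list per prefix typed.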
41 | 42 | `model` -- `lmchallenge.core.model.Model` -- model to evaluate 43 | 44 | `context` -- `string` -- the text before the target 45 | 46 | `target` -- `string` -- the token being typed 47 | 48 | `return` -- `generator(list(string))` -- generates lists of string 49 | predictions, at each point while typing out `target` 50 | ''' 51 | for i in range(len(target)): 52 | yield [w for w, s in model.predict(context + target[:i], None)] 53 | 54 | 55 | def get_logp(model, context, target): 56 | '''Wrap the model.predict API to query the score of a single target. 57 | ''' 58 | results = list(model.predict(context, [target])) 59 | if 2 <= len(results): 60 | logging.warning('multiple results returned for a single candidate') 61 | try: 62 | return next(s for w, s in results if w == target) 63 | except StopIteration: 64 | return None 65 | 66 | 67 | # Evaluation 68 | 69 | def evaluate_completions(model, context, target, next_word_only): 70 | '''Evaluator for Word Completion. 71 | 72 | `next_word_only` -- `bool` -- save time by only producing results for 73 | next-word-prediction (i.e. no completion) 74 | ''' 75 | return dict(completions=list(it.islice( 76 | get_completions( 77 | model=model, 78 | context=context, 79 | target=target, 80 | ), 81 | 1 if next_word_only else None 82 | ))) 83 | 84 | 85 | def evaluate_entropy(model, context, target): 86 | '''Evaluator for Word/Character Entropy. 87 | ''' 88 | return dict(logp=get_logp(model=model, context=context, target=target)) 89 | 90 | 91 | def evaluate_reranking(model, context, target, 92 | error_config, search, num_candidates, rand): 93 | '''Evaluator for Error Reranking. 94 | 95 | `error_config` -- `dict` -- for `lmchallenge.core.errors` 96 | 97 | `serach` -- `lmchallenge.core.errors.Search` -- find nearby words 98 | 99 | `num_candidates` -- `int` -- number of candidates to search 100 | 101 | `rand` -- `random.Random` -- use for corrupting the text 102 | ''' 103 | corrupted = errors.corrupt(error_config, target, rand=rand) 104 | 105 | candidates = search(corrupted, num_candidates) 106 | 107 | # clip off trailing candidates if needed to ensure corrupted & target 108 | # can be added - this should keep the maximum size of candidates 109 | # clamped at 'n' and ensuring that it contains 'target' and 'corrupted' 110 | candidates = set( 111 | candidates[:(num_candidates - 112 | (target not in candidates) - 113 | (corrupted not in candidates))] 114 | ) | set([corrupted, target]) 115 | 116 | lm_scores = dict(model.predict(context, candidates)) 117 | 118 | results = [ 119 | (candidate, 120 | errors.score(error_config, corrupted, candidate), 121 | lm_scores.get(candidate, None)) 122 | for candidate in candidates 123 | ] 124 | return dict( 125 | verbatim=corrupted, 126 | results=list(sorted(results, key=lambda x: -x[1])) 127 | ) 128 | 129 | 130 | def run_tokens(model, data, train, tokenizer, evaluate): 131 | '''Run a per-token evaluation. 132 | 133 | `model` -- `lmchallenge.core.model.Model` -- to evaluate 134 | 135 | `data` -- `iterable(dict)` -- an iterable of "message" dictionaries, 136 | containing: 137 | 138 | - `text` -- `string` -- the contents of the message 139 | - `user` -- `string` (optional) -- a user ID 140 | 141 | `train` -- `bool` -- should we train after every message? 
142 | 143 | `tokenizer` -- `regex.Regex` -- tokenizer for finding tokens in messages 144 | 145 | `evaluate` -- `callable(model, context, target) -> dict` -- runs evaluation 146 | for a single target 147 | 148 | `return` -- `generator(dict)` -- generator of results, containing: 149 | 150 | - `user` -- `string` -- from the message 151 | - `message` -- `int` -- index of message 152 | - `token` -- `int` -- index of token within the message 153 | - `character` -- `int` -- index of character within the message 154 | - `target` -- `string` -- token being typed 155 | ''' 156 | data_by_user = it.groupby(data, lambda x: x.get('user')) 157 | for user, messages in data_by_user: 158 | if train: 159 | model.clear() 160 | for message_n, message in enumerate(messages): 161 | for token in find_tokens(text=message['text'], 162 | tokenizer=tokenizer): 163 | yield dict( 164 | user=user, 165 | message=message_n, 166 | **evaluate( 167 | model=model, 168 | context=message['text'][:token['character']], 169 | target=token['target'] 170 | ), 171 | **token 172 | ) 173 | if train: 174 | model.train(message['text']) 175 | 176 | 177 | def wc(model, data, train=False, next_word_only=False): 178 | '''Run the Word Completion task over `model` and `data` to generate a 179 | result log. 180 | 181 | `model` -- `lmchallenge.core.model.Model` -- to evaluate 182 | 183 | `data` -- `iterable(string)` or `iterable(dict)` -- text data 184 | 185 | `train` -- `bool` -- should the model be trained after each line? 186 | 187 | `next_word_only` -- `bool` -- speed up evaluation by only evaluating 188 | next-word-prediction, not completion 189 | 190 | `return` -- `iterable(dict)` -- LM Challenge log 191 | ''' 192 | return run_tokens( 193 | model, 194 | common.autodetect_input(data), 195 | train=train, 196 | tokenizer=common.WORD_TOKENIZER, 197 | evaluate=ft.partial( 198 | evaluate_completions, 199 | next_word_only=next_word_only 200 | )) 201 | 202 | 203 | def we(model, data, train=False): 204 | '''Run the Word Entropy task over `model` and `data` to generate a 205 | result log. 206 | 207 | Word Entropy scans through the text word-by-word and asks 208 | the model to score each word based only upon previous context. The 209 | model is responsible for returning a normalized log probability as the 210 | score for any context and target queried. Two models may only be compared 211 | if they share the same vocabulary. 212 | 213 | `model` -- `lmchallenge.core.model.Model` -- to evaluate 214 | 215 | `data` -- `iterable(string)` or `iterable(dict)` -- text data 216 | 217 | `train` -- `bool` -- should the model be trained after each line? 218 | 219 | `return` -- `iterable(dict)` -- LM Challenge log 220 | ''' 221 | return run_tokens( 222 | model, 223 | common.autodetect_input(data), 224 | train=train, 225 | tokenizer=common.WORD_TOKENIZER, 226 | evaluate=evaluate_entropy) 227 | 228 | 229 | def ce(model, data, train=False): 230 | '''Run the Character Entropy task over `model` and `data` to generate a 231 | result log. 232 | 233 | Character Entropy scans through the text character-by-character and asks 234 | the model to score each character based only upon previous context. The 235 | model is responsible for returning a normalized log probability as the 236 | score for any context and target queried. Two models may only be compared 237 | if they share the same vocabulary. 
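    A minimal usage sketch (`MyCharModel` is a placeholder for any
    `lmchallenge.Model` implementation):

        log = list(ce(MyCharModel(), ['some test text']))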
238 | 239 | `model` -- `lmchallenge.core.model.Model` -- to evaluate 240 | 241 | `data` -- `iterable(string)` or `iterable(dict)` -- text data 242 | 243 | `train` -- `bool` -- should the model be trained after each line? 244 | 245 | `return` -- `iterable(dict)` -- LM Challenge log 246 | ''' 247 | return run_tokens( 248 | model, 249 | common.autodetect_input(data), 250 | train=train, 251 | tokenizer=common.CHARACTER_TOKENIZER, 252 | evaluate=evaluate_entropy) 253 | 254 | 255 | def wr(model, data, vocab, train=False, seed=42, 256 | num_candidates=100, error_config=errors.DEFAULT_CONFIG): 257 | '''Run the Word Reranking task over `model` and `data` to generate a 258 | result log. 259 | 260 | Word Reranking corrupts the original text word-by-word using a simple 261 | character substitution error model, then looks up nearby candidate words 262 | in a (large) vocabulary of words. Each candidate is paired with the 263 | probability under the error model of generating the corruption (a "perfect" 264 | correction model score), and the task for model is to generate language 265 | model scores for each candidate that can be combined linearly with the 266 | error score to rank the candidates and recover the original text. 267 | 268 | `model` -- `lmchallenge.core.model.Model` -- to evaluate 269 | 270 | `data` -- `iterable(string)` or `iterable(dict)` -- text data 271 | 272 | `vocab` -- `iterable(string)` -- vocabulary of words to corrupt the data to 273 | 274 | `train` -- `bool` -- should the model be trained after each line? 275 | 276 | `seed` -- `int` or `None` -- random seed to use to generate corrupted 277 | candidates 278 | 279 | `num_candidates` -- `int` -- number of corrupted candidates to consider 280 | 281 | `error_config` -- `dict` -- see `lmchallenge.core.errors` - defines the 282 | generation and scoring of error candidates 283 | 284 | `return` -- `iterable(dict)` -- LM Challenge log 285 | ''' 286 | return run_tokens( 287 | model, 288 | common.autodetect_input(data), 289 | train=train, 290 | tokenizer=common.WORD_TOKENIZER, 291 | evaluate=ft.partial( 292 | evaluate_reranking, 293 | error_config=error_config, 294 | search=errors.Search(words=set(vocab)), 295 | num_candidates=num_candidates, 296 | rand=random.Random(seed), 297 | )) 298 | 299 | 300 | # Command line helpers 301 | 302 | class PredictorSpec(click.ParamType): 303 | '''Loads a predictor, either from a Python module or a shell command. 
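    For example, "mymodule.MyClass" is treated as a qualified Python name
    and loaded directly, whereas "./my-predictor model.lm" is treated as a
    shell command (both values are illustrative, echoing the CLI help
    below).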
304 | ''' 305 | class PythonModel: 306 | def __init__(self, ctor): 307 | self.ctor = ctor 308 | 309 | def __call__(self, options): 310 | return self.ctor(**options) 311 | 312 | class ShellModel: 313 | def __init__(self, cmd): 314 | self.cmd = cmd 315 | 316 | def __call__(self, options): 317 | return model.ShellModel(self.cmd, options) 318 | 319 | name = 'predictor_spec' 320 | 321 | def get_metavar(self, param): 322 | return 'SPEC' 323 | 324 | def convert(self, value, param, ctx): 325 | if common.is_qualified_name(value): 326 | return self.PythonModel(common.lookup_qualified_name(value)) 327 | else: 328 | return self.ShellModel(value) 329 | 330 | 331 | class InputFormat(common.ParamChoice): 332 | '''Input handling - text or json, 333 | ''' 334 | name = 'input_format' 335 | choices = ['auto', 'text', 'json'] 336 | 337 | @staticmethod 338 | def _is_json(line): 339 | try: 340 | # if you can parse it as JSON 341 | d = json.loads(line) 342 | # and is an object containing a key called "text" 343 | # then assume it is our "json" format 344 | return isinstance(d, dict) and 'text' in d 345 | except json.JSONDecodeError: 346 | return False 347 | 348 | @classmethod 349 | def auto(cls, lines): 350 | first_line, lines = common.peek(lines) 351 | if first_line is None: 352 | pass 353 | elif cls._is_json(first_line): 354 | yield from cls.json(lines) 355 | else: 356 | yield from cls.text(lines) 357 | 358 | @staticmethod 359 | def text(lines): 360 | for line in lines: 361 | yield dict(text=line.rstrip('\r\n')) 362 | 363 | @staticmethod 364 | def json(lines): 365 | for line in lines: 366 | yield json.loads(line) 367 | 368 | 369 | # Command lines 370 | 371 | @click.group() 372 | @click.argument('predictor', type=PredictorSpec()) 373 | @click.option('-v', '--verbose', default=0, count=True, 374 | help='How much human-readable detail to print to STDERR.') 375 | @click.option('-t', '--train/--no-train', default=False, 376 | help='Train the model on lines of text after predictions have' 377 | ' been given for them (and any others with the same timestamp),' 378 | ' resetting for each userId. This is mainly useful for dynamic' 379 | ' modelling experiments with the json corpus format.') 380 | @click.option('-f', '--format', default='auto', type=InputFormat(), 381 | help='Format for test data. text is just lines of plain text;' 382 | ' json is our json-lines corpus format with rows like' 383 | ' {"text": "A line of text", "user": "user1234"},' 384 | ' grouped (e.g. ordered) by userId. Default "auto" is to try to' 385 | ' auto-detect the format, based on the first line of the log.') 386 | @click.option('-o', '--options', type=common.JsonParam(), default='{}', 387 | help='Additional JSON-formatted options to be parsed and' 388 | ' passed to a Python module predictor, or converted to command' 389 | ' line arguments for a shell predictor. In the case of shell' 390 | ' - the arguments are passed as "--key value", unless key' 391 | ' already starts with a hyphen (in which case just "key value"),' 392 | ' and an optional list of arguments with the key "positional".' 393 | ' For example: \'{"abc": 123, "-n" 10,' 394 | ' "positional": ["hello", "world"]}\' will be passed as' 395 | ' ` hello world -n 10 --abc 123`.') 396 | @click.pass_context 397 | def cli(ctx, verbose, predictor, train, format, options): 398 | '''Run a challenge for a predictor over some test text. 399 | Pipe in text to record an evaluation run of a pipeable predictor, on a 400 | language modelling task. 
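    For example (a hypothetical invocation -- the predictor command, model
    file and test file are placeholders):

        cat test.txt | lmc run "./my-predictor model.lm" wc > wc.log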
401 | 402 | Analyse the output by piping it into `lmchallenge.stats` or 403 | `lmchallenge.pretty`. 404 | 405 | PREDICTOR - either a shell command to run a pipeable predictor 406 | e.g. "./my-predictor model.lm", 407 | or the qualified name of a Python class or function 408 | e.g. "mymodule.MyClass". 409 | ''' 410 | common.verbosity(verbose) 411 | 412 | def _runner(tokenizer, evaluate): 413 | with predictor(options) as model: 414 | common.dump_jsonlines(run_tokens( 415 | model=model, 416 | data=format(sys.stdin), 417 | train=train, 418 | tokenizer=tokenizer, 419 | evaluate=evaluate, 420 | )) 421 | ctx.obj = ctx.obj or {} 422 | ctx.obj['run'] = _runner 423 | 424 | 425 | @cli.command('wc') 426 | @click.option('-p', '--next-word-only/--no-next-word-only', 427 | default=False, 428 | help='Only compute next-word-predictions - don\'t produce' 429 | ' results for prefix completions (for performance)') 430 | @click.pass_context 431 | def cli_wc(ctx, next_word_only): 432 | '''Word Completion Challenge (next-word prediction & completion). 433 | ''' 434 | ctx.obj['run']( 435 | tokenizer=common.WORD_TOKENIZER, 436 | evaluate=ft.partial( 437 | evaluate_completions, 438 | next_word_only=next_word_only) 439 | ) 440 | 441 | 442 | @cli.command('we') 443 | @click.pass_context 444 | def cli_we(ctx): 445 | '''Word Entropy Challenge. 446 | ''' 447 | ctx.obj['run']( 448 | tokenizer=common.WORD_TOKENIZER, 449 | evaluate=evaluate_entropy, 450 | ) 451 | 452 | 453 | @cli.command('wr') 454 | @click.argument('vocab', type=click.File('r')) 455 | @click.option('-r', '--seed', default=42, 456 | help='Random seed to fix (default: fixed), or 0 to get a' 457 | ' pseudorandom seed from the clock') 458 | @click.option('-n', '--num-candidates', default=100, 459 | help='Number of candidates to consider for each word.') 460 | @click.option('--error-chars', default=errors.DEFAULT_CONFIG['error_chars'], 461 | help='Set of characters to sample from when adding' 462 | ' global character errors') 463 | @click.option('-p', '--p-anykey', default=errors.DEFAULT_CONFIG['p_anykey'], 464 | help='Probability of substituting an \'anykey\' error') 465 | @click.pass_context 466 | def cli_wr(ctx, vocab, seed, num_candidates, error_chars, p_anykey): 467 | '''Word Reranking Challenge. 468 | ''' 469 | rand = random.Random(seed) 470 | words = set(line.rstrip('\r\n') for line in vocab) 471 | error_config = dict(error_chars=error_chars, p_anykey=p_anykey) 472 | search = errors.Search(words=words) 473 | ctx.obj['run']( 474 | tokenizer=common.WORD_TOKENIZER, 475 | evaluate=ft.partial( 476 | evaluate_reranking, 477 | error_config=error_config, 478 | search=search, 479 | num_candidates=num_candidates, 480 | rand=rand, 481 | ), 482 | ) 483 | 484 | 485 | @cli.command('ce') 486 | @click.pass_context 487 | def cli_ce(ctx): 488 | '''Character Entropy Challenge. 
489 | ''' 490 | ctx.obj['run']( 491 | tokenizer=common.CHARACTER_TOKENIZER, 492 | evaluate=evaluate_entropy, 493 | ) 494 | 495 | 496 | __doc__ += common.shell_docstring(cli, 'lmc run') 497 | __doc__ += common.shell_docstring(cli_wc, 'lmc run wc') 498 | __doc__ += common.shell_docstring(cli_we, 'lmc run we') 499 | __doc__ += common.shell_docstring(cli_wr, 'lmc run wr') 500 | __doc__ += common.shell_docstring(cli_ce, 'lmc run ce') 501 | if __name__ == '__main__': 502 | cli() 503 | -------------------------------------------------------------------------------- /lmchallenge/stats.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | '''Aggregate results from challenges, and calculate summary statistics. 5 | ''' 6 | 7 | import click 8 | import json 9 | import csv as csvlib 10 | import io 11 | import sys 12 | import struct 13 | import hashlib 14 | import itertools as it 15 | from .core import common 16 | 17 | 18 | class Accumulator: 19 | '''Abstract base class for "accumulation" operations that can be performed 20 | over sequences of events. 21 | ''' 22 | @classmethod 23 | def create(cls): 24 | '''Create an 'empty' accumulator. 25 | ''' 26 | return cls() 27 | 28 | def update(self, datum): 29 | '''Feed a datum into the accumulator, updating the internal state. 30 | 31 | `datum` -- `dict` -- single log datum 32 | ''' 33 | raise NotImplementedError 34 | 35 | @property 36 | def state(self): 37 | '''Get the accumulator's current state. 38 | 39 | `return` -- `object` -- current state (snapshot) of the accumulator 40 | ''' 41 | raise NotImplementedError 42 | 43 | 44 | class Counter(Accumulator): 45 | '''Count the number of events. 46 | ''' 47 | def __init__(self): 48 | self._count = 0 49 | 50 | def update(self, datum): 51 | self._count += 1 52 | 53 | @property 54 | def state(self): 55 | return self._count 56 | 57 | 58 | class UniqueCounter(Accumulator): 59 | '''Count the number of unique datums according to some subclass-defined 60 | keying function. 61 | Note that this only removes consecutive duplicates, e.g. 62 | A A A B B B C C A -- counts as 4 (A B C A) 63 | ''' 64 | def __init__(self): 65 | self._count = 0 66 | self._previous = object() 67 | 68 | @staticmethod 69 | def _key(datum): 70 | '''The key to use to detect consecutive duplicates. 71 | ''' 72 | raise NotImplementedError 73 | 74 | def update(self, datum): 75 | current = self._key(datum) 76 | if current != self._previous: 77 | self._previous = current 78 | self._count += 1 79 | 80 | @property 81 | def state(self): 82 | return self._count 83 | 84 | 85 | class UserCounter(UniqueCounter): 86 | '''Count the number of users. 87 | ''' 88 | @staticmethod 89 | def _key(datum): 90 | return datum.get('user') 91 | 92 | 93 | class MessageCounter(UniqueCounter): 94 | '''Count the total number of messages. 95 | ''' 96 | @staticmethod 97 | def _key(datum): 98 | return (datum.get('user'), datum.get('message')) 99 | 100 | 101 | class CharacterCounter(Accumulator): 102 | def __init__(self): 103 | self._sum = 0 104 | 105 | def update(self, datum): 106 | self._sum += len(datum['target']) 107 | 108 | @property 109 | def state(self): 110 | return self._sum 111 | 112 | 113 | class Hash: 114 | '''A stable hash, for computing fingerprints. 115 | ''' 116 | @staticmethod 117 | def get(*columns): 118 | '''Generate a hash value from an ordered list of values. 119 | Supports only: integers, booleans, strings, None. 
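        For example, `Hash.get('user123', 0, True)` computes a single stable
        fingerprint over those three values (an illustrative call, not one
        made elsewhere in this module).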
120 | ''' 121 | h = 0 122 | for column in columns: 123 | # Arbitrary large prime (for combining hashes) 124 | h *= 48677 125 | if isinstance(column, str): 126 | # Use MD5 to generate a 8-byte int hash from the string 127 | h += struct.unpack( 128 | 'value) pair 174 | self._functions = functions 175 | # These values line up with those of _functions 176 | self._values = [0 for _ in functions] 177 | 178 | @classmethod 179 | def create(cls, *hit_ranks): 180 | # Split this into a function to avoid capturing 'n' lexically 181 | # in the generator below, which would cause all comparators to be 182 | # the same 183 | def comparator(n): 184 | return lambda rank: rank <= n 185 | 186 | return cls( 187 | functions=([('srr', lambda rank: 1 / rank), 188 | ('hit', lambda rank: 1)] 189 | + [('hit{}'.format(n), comparator(n)) 190 | for n in hit_ranks]) 191 | ) 192 | 193 | def update(self, datum): 194 | completions = datum.get('completions') 195 | if completions is not None: 196 | rank = common.rank(completions[0], datum['target']) 197 | if rank is not None: 198 | for idx, (name, fn) in enumerate(self._functions): 199 | self._values[idx] += fn(rank) 200 | 201 | @property 202 | def state(self): 203 | return {name: value 204 | for (name, _), value in zip(self._functions, self._values)} 205 | 206 | 207 | class Completion(Accumulator): 208 | '''Compute word completion stats. 209 | ''' 210 | def __init__(self, max_rank): 211 | self._max_rank = max_rank 212 | self._characters = 0 213 | self._tokens = 0 214 | 215 | @classmethod 216 | def create(cls, max_rank): 217 | return cls(max_rank=max_rank) 218 | 219 | def update(self, datum): 220 | all_completions = datum.get('completions') 221 | if all_completions is not None: 222 | target = datum['target'] 223 | for start, completions in enumerate(all_completions): 224 | if common.rank(completions, 225 | target[start:], 226 | max_rank=self._max_rank): 227 | self._characters += len(datum['target']) - start 228 | self._tokens += 1 229 | break 230 | 231 | @property 232 | def state(self): 233 | return dict( 234 | characters=self._characters, 235 | tokens=self._tokens 236 | ) 237 | 238 | 239 | class Entropy(Accumulator): 240 | '''Accumulate cross-entropy stats. 241 | Includes a custom fingerprint (different from the toplevel fingerprint, 242 | as the rules for comparing entropy values are stricter than the rules 243 | for comparing prediction/completion/reranking. 244 | ''' 245 | def __init__(self): 246 | self._sum = 0 247 | self._tokens = 0 248 | self._fingerprint = Fingerprint.create() 249 | 250 | def update(self, datum): 251 | logp = datum.get('logp') 252 | if logp is not None: 253 | self._sum -= logp 254 | self._tokens += 1 255 | self._fingerprint.update(datum) 256 | 257 | @property 258 | def state(self): 259 | return dict( 260 | sum=self._sum, 261 | tokens=self._tokens, 262 | fingerprint=self._fingerprint.state 263 | ) 264 | 265 | 266 | class Reranking(Accumulator): 267 | '''Optimize a reranking function, by loading all events into 268 | memory, and using scipy. 
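    A typical entry point is the `build_model` helper defined below -- a
    sketch, where `data` is an LM Challenge log containing word-reranking
    results:

        model = Reranking.build_model(data)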
269 | ''' 270 | def __init__(self): 271 | self._scores_error = [] 272 | self._scores_lm = [] 273 | 274 | def update(self, datum): 275 | results = datum.get('results') 276 | if results is not None: 277 | target = datum['target'] 278 | self._scores_error.append(list(it.chain( 279 | (e for t, e, lm in results if t == target), 280 | (e for t, e, lm in results if t != target), 281 | ))) 282 | self._scores_lm.append(list(it.chain( 283 | (lm for t, e, lm in results if t == target), 284 | (lm for t, e, lm in results if t != target), 285 | ))) 286 | 287 | def finalize(self): 288 | '''Finalize the matrices & compute the optimal model. 289 | ''' 290 | from .core import reranking as R 291 | max_candidates = max(len(e) for e in self._scores_error) 292 | self._error = R.jagged_matrix(self._scores_error, max_candidates) 293 | self._lm = R.jagged_matrix(self._scores_lm, max_candidates) 294 | self._model = R.InterpolationRerankingModel.optimize( 295 | error=self._error, 296 | lm=self._lm 297 | ) 298 | 299 | @property 300 | def state(self): 301 | if len(self._scores_error) == 0: 302 | # Take an explicit branch to avoid the "import core.reranking" 303 | # unless the reranking model is used (to keep the numpy/scipy 304 | # dependency optional 305 | return dict( 306 | max_candidates=0, 307 | already_correct=0, 308 | correct=0, 309 | args=None, 310 | ) 311 | else: 312 | self.finalize() 313 | from .core import reranking as R 314 | return dict( 315 | max_candidates=self._error.shape[1], 316 | base_correct=R.count_correct(self._error), 317 | correct=R.count_correct(self._model(self._error, self._lm)), 318 | args=self._model.args, 319 | ) 320 | 321 | @classmethod 322 | def build_model(cls, data): 323 | '''A helper to build a reranking model from data, using this class. 324 | ''' 325 | a = cls.create() 326 | for datum in data: 327 | if common.is_selected(datum): 328 | a.update(datum) 329 | a.finalize() 330 | return a._model 331 | 332 | 333 | class Composite(Accumulator): 334 | '''Combine accumulators, providing results as a dictionary. 335 | ''' 336 | def __init__(self, children): 337 | self._children = children 338 | 339 | @classmethod 340 | def create(cls, **children): 341 | return cls(children) 342 | 343 | def update(self, datum): 344 | for child in self._children.values(): 345 | child.update(datum) 346 | 347 | @property 348 | def state(self): 349 | return {name: accumulator.state 350 | for name, accumulator in self._children.items()} 351 | 352 | 353 | class Stats(Composite): 354 | '''A standard set of useful LM Challenge stats. 355 | ''' 356 | @classmethod 357 | def create(cls): 358 | return super(Stats, cls).create( 359 | users=UserCounter.create(), 360 | messages=MessageCounter.create(), 361 | tokens=Counter.create(), 362 | characters=CharacterCounter.create(), 363 | fingerprint=Fingerprint.create(), 364 | prediction=NextWordPrediction.create( 365 | 1, 3, 10, 20 366 | ), 367 | completion=Completion.create( 368 | max_rank=2 369 | ), 370 | entropy=Entropy.create(), 371 | reranking=Reranking.create(), 372 | ) 373 | 374 | 375 | class Selection(Accumulator): 376 | '''Select data based on the "select" tag in each datum. 
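    A sketch of how it wraps another accumulator (mirroring the `stats`
    function below):

        accumulator = Selection.create(child=Stats.create())
        for datum in data:
            accumulator.update(datum)
        result = accumulator.state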
377 | ''' 378 | def __init__(self, child): 379 | self._skipped = 0 380 | self._child = child 381 | 382 | @classmethod 383 | def create(cls, child): 384 | return cls(child=child) 385 | 386 | def update(self, datum): 387 | if common.is_selected(datum): 388 | self._child.update(datum) 389 | else: 390 | self._skipped += 1 391 | 392 | @property 393 | def state(self): 394 | child_state = self._child.state 395 | return dict(skipped=self._skipped, 396 | **(child_state 397 | if isinstance(child_state, dict) else 398 | dict(value=child_state))) 399 | 400 | 401 | def humanize(stats): 402 | '''To be used with the output of the Selection & Stats accumulators. 403 | 404 | `stats` -- `dict` -- as returned by a `lmchallenge.stats.Selection` 405 | of `lmchallenge.stats.Stats` accumulator `.state` 406 | 407 | `return` -- `dict` -- human-readable staistics 408 | ''' 409 | stats = stats.copy() 410 | r = dict() 411 | 412 | # General info 413 | r['fingerprint'] = Hash.format(stats.pop('fingerprint')) 414 | r['users'] = stats.pop('users') 415 | tokens = stats.pop('tokens') 416 | characters = stats.pop('characters') 417 | r['messages_per_user'] = stats.pop('messages') / r['users'] 418 | r['tokens_per_user'] = tokens / r['users'] 419 | r['characters_per_token'] = characters / tokens 420 | 421 | # Skipped/unselected 422 | if 'skipped' in stats: 423 | skipped = stats.pop('skipped') 424 | r['skipped'] = skipped / (skipped + tokens) 425 | 426 | # NextWordPrediction 427 | prediction = stats.pop('prediction') 428 | if prediction['hit'] != 0: 429 | r['prediction'] = { 430 | ('mrr' if k == 'srr' else k): v / tokens 431 | for k, v in prediction.items() 432 | } 433 | # Since we're iterating through all keys, there is no need to check 434 | # that all are accounted for (unlike Completion, Entropy, top-level) 435 | 436 | # Completion 437 | completion = stats.pop('completion').copy() 438 | if completion['tokens'] != 0: 439 | r['completion'] = dict( 440 | characters=completion.pop('characters') / characters, 441 | tokens=completion.pop('tokens') / tokens, 442 | ) 443 | assert len(completion) == 0, 'Unexpected Completion result keys' 444 | 445 | # Entropy 446 | entropy = stats.pop('entropy').copy() 447 | if entropy['tokens'] != 0: 448 | entropy_tokens = entropy.pop('tokens') 449 | r['entropy'] = dict( 450 | fingerprint=Hash.format(entropy.pop('fingerprint')), 451 | hit=entropy_tokens / tokens, 452 | mean=entropy.pop('sum') / entropy_tokens 453 | ) 454 | assert len(entropy) == 0, 'Unexpected Entropy result keys' 455 | 456 | # Reranking 457 | reranking = stats.pop('reranking').copy() 458 | if reranking['max_candidates'] != 0: 459 | reranking.pop('args') # rarely used information 460 | r['reranking'] = dict( 461 | accuracy=reranking.pop('correct') / tokens, 462 | base_accuracy=reranking.pop('base_correct') / tokens, 463 | max_candidates=reranking.pop('max_candidates'), 464 | ) 465 | assert len(reranking) == 0, 'Unexpected Reranking result keys' 466 | 467 | assert len(stats) == 0, 'Unexpected Stats result keys' 468 | return r 469 | 470 | 471 | def stats(data, human=True): 472 | '''Run the standard set of accumulators over 'data'. 
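    A minimal sketch, assuming `log` is an LM Challenge log loaded as an
    iterable of dicts:

        summary = stats(log)
        print(summary['fingerprint'], summary.get('entropy'))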
473 | 474 | `data` -- `iterable(dict)` -- LM Challenge log 475 | 476 | `human` -- `bool` -- show human-friendly derivative stats 477 | (instead of machine-friendly sums) 478 | 479 | `return` -- `dict` -- an accumulated dictionary of stats 480 | ''' 481 | accumulator = Selection.create(child=Stats.create()) 482 | for datum in data: 483 | accumulator.update(datum) 484 | return humanize(accumulator.state) if human else accumulator.state 485 | 486 | 487 | class Output(common.ParamChoice): 488 | '''How to print the stats output. 489 | ''' 490 | name = 'output_format' 491 | choices = ['json', 'csv'] 492 | 493 | @staticmethod 494 | def json(data): 495 | '''Dump a results set in jsonlines format.''' 496 | out = io.StringIO() 497 | for row in data: 498 | json.dump(row, out, sort_keys=True) 499 | out.write('\n') 500 | return out.getvalue() 501 | 502 | @staticmethod 503 | def csv(data): 504 | '''Dump a dictionary in csv format.''' 505 | out = io.StringIO() 506 | keys = list(common.sort_with_override( 507 | common.flatten_keys(data[0]).keys(), 508 | 'log', 509 | 'fingerprint', 510 | 'users', 511 | 'messages_per_user', 512 | 'tokens_per_user', 513 | 'characters_per_token', 514 | 'skipped', 515 | )) 516 | writer = csvlib.DictWriter(out, fieldnames=keys) 517 | writer.writeheader() 518 | for row in data: 519 | writer.writerow(common.flatten_keys(row)) 520 | return out.getvalue() 521 | 522 | 523 | # Script 524 | 525 | @click.command() 526 | @click.argument('log', nargs=-1, type=click.Path(exists=True, dir_okay=False)) 527 | @click.option('-v', '--verbose', default=0, count=True, 528 | help='How much human-readable detail to print to STDERR.') 529 | @click.option('-n', '--lines', type=click.INT, 530 | help='Limit input to this number of lines') 531 | @click.option('-o', '--output', type=Output(), default='json', 532 | help='Output format.') 533 | @click.option('-h/-H', '--human/--no-human', default=True, 534 | help='Humanize the output.') 535 | def cli(log, verbose, lines, output, human, **args): 536 | '''Extract summary stats from any of the challenge log files. 537 | Specify one or more similar log files, or pipe in results. 538 | ''' 539 | common.verbosity(verbose) 540 | log = log or ['-'] 541 | results = [dict(log=file, 542 | **stats(it.islice(common.load_jsonlines(file), lines), 543 | human=human)) 544 | for file in log] 545 | sys.stdout.write(output(results)) 546 | 547 | 548 | __doc__ += common.shell_docstring(cli, 'lmc stats') 549 | if __name__ == '__main__': 550 | cli() 551 | -------------------------------------------------------------------------------- /lmchallenge/tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | -------------------------------------------------------------------------------- /lmchallenge/tests/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 
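# The hooks below register a "--run-slow" option: tests marked "slow" are
# skipped unless it is passed. A hypothetical example of opting a test in
# (no such test is defined in this file):
#
#   @pytest.mark.slow
#   def test_large_corpus():
#       ...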
3 | 4 | import pytest 5 | 6 | 7 | def pytest_addoption(parser): 8 | parser.addoption('--run-slow', action='store_true', 9 | default=False, help='run slow tests') 10 | 11 | 12 | def pytest_collection_modifyitems(config, items): 13 | if not config.getoption('--run-slow'): 14 | skip = pytest.mark.skip(reason='only runs with --run-slow') 15 | for item in items: 16 | if 'slow' in item.keywords: 17 | item.add_marker(skip) 18 | -------------------------------------------------------------------------------- /lmchallenge/tests/eg_models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | '''Models that don't correctly implement the APIs. 5 | ''' 6 | 7 | from lmchallenge import Model, FilteringWordModel 8 | import math 9 | 10 | 11 | # "Good" models that obey the API 12 | 13 | class SimpleModel(Model): 14 | def predict(self, context, candidates): 15 | return filter( 16 | lambda x: candidates is None or x[0] in candidates, 17 | [('the', math.log(0.5)), 18 | ('a', math.log(0.25)), 19 | ('cat', math.log(0.25))]) 20 | 21 | 22 | class SimpleCharModel(Model): 23 | def predict(self, context, candidates): 24 | return filter( 25 | lambda x: candidates is None or x[0] in candidates, 26 | [('e', math.log(0.5)), 27 | ('s', math.log(0.25)), 28 | ('\t', math.log(0.125)), 29 | ('\n', math.log(0.125))]) 30 | 31 | 32 | class SimpleWordModel(FilteringWordModel): 33 | def score_word(self, context, candidates): 34 | # score according to length 35 | return [(c, -len(c)) for c in candidates] 36 | 37 | def predict_word_iter(self, context): 38 | # return the words from context as predictions, in reverse order 39 | return [(w, -n) for n, w in enumerate(context[::-1])] 40 | 41 | 42 | class DynamicWordModel(FilteringWordModel): 43 | '''A unigram counting dynamic word model. 44 | ''' 45 | def __init__(self): 46 | self._words = {} 47 | self._total = 0 48 | 49 | def score_word(self, context, candidates): 50 | for candidate in candidates: 51 | if candidate in self._words: 52 | yield (candidate, 53 | math.log(self._words[candidate] / self._total)) 54 | 55 | def predict_word_iter(self, context): 56 | words = sorted(self._words.keys(), key=lambda w: -self._words[w]) 57 | return ((w, -n) for n, w in enumerate(words)) 58 | 59 | def train_word(self, text): 60 | for word in text: 61 | self._words[word] = self._words.get(word, 0) + 1 62 | self._total += len(text) 63 | 64 | 65 | # "Bad" models that don't obey the API 66 | 67 | class NotImplementedModel(Model): 68 | def predictions(self, context, candidates): 69 | # wrong name 70 | return [] 71 | 72 | 73 | class WrongArgumentsModel(Model): 74 | def predict(self, context): 75 | # missing candidates 76 | return [] 77 | 78 | 79 | class WrongResultModel(Model): 80 | def predict(self, context, candidates): 81 | # missing result scores 82 | return ['a', 'b', 'c'] 83 | -------------------------------------------------------------------------------- /lmchallenge/tests/test_functional.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | from . 
import eg_models 5 | import lmchallenge as lmc 6 | import math 7 | 8 | 9 | def expect_close(expected, actual): 10 | if expected is None: 11 | assert actual is None 12 | else: 13 | assert abs(expected - actual) < 1e-8 14 | 15 | 16 | def expect_partial_match(expected, actual): 17 | for k in expected: 18 | if isinstance(expected[k], float): 19 | expect_close(expected[k], actual[k]) 20 | else: 21 | assert expected[k] == actual[k] 22 | 23 | 24 | def expect_meta(expected, actual): 25 | for x, y in zip(expected, actual): 26 | expect_partial_match(x, y) 27 | assert len(expected) == len(actual) 28 | 29 | 30 | def expect_logp(expected, actual_log): 31 | for logp, actual in zip(expected, actual_log): 32 | expect_close(logp, actual['logp']) 33 | assert len(expected) == len(actual_log) 34 | 35 | 36 | def expect_completions(expected, actual_log): 37 | for expected_completions, actual in zip(expected, actual_log): 38 | assert expected_completions == actual['completions'] 39 | assert len(expected) == len(actual_log) 40 | 41 | 42 | def expect_results(expected, actual_log): 43 | for expected_results, actual in zip(expected, actual_log): 44 | remaining = {word: lm_score 45 | for word, _, lm_score in actual['results']} 46 | for word, score in expected_results: 47 | expect_close(score, remaining.pop(word)) 48 | assert all(score is None for score in remaining.values()) 49 | assert len(expected) == len(actual_log) 50 | 51 | 52 | class PlainData: 53 | DATA = [dict(text=line) for line in [ 54 | 'the cat ate a cat', 55 | 'The cat', 56 | ]] 57 | META = [ 58 | dict(user=None, message=0, token=0, character=0, target='the'), 59 | dict(user=None, message=0, token=1, character=4, target='cat'), 60 | dict(user=None, message=0, token=2, character=8, target='ate'), 61 | dict(user=None, message=0, token=3, character=12, target='a'), 62 | dict(user=None, message=0, token=4, character=14, target='cat'), 63 | 64 | dict(user=None, message=1, token=0, character=0, target='The'), 65 | dict(user=None, message=1, token=1, character=4, target='cat'), 66 | ] 67 | STATS = dict( 68 | users=1, 69 | messages=2, 70 | tokens=7, 71 | characters=19, 72 | skipped=0, 73 | ) 74 | HUMAN_STATS = dict( 75 | users=1, 76 | messages_per_user=2.0, 77 | tokens_per_user=7.0, 78 | characters_per_token=19/7, 79 | skipped=0.0, 80 | ) 81 | 82 | 83 | class PlainCharData: 84 | DATA = [dict(text=line) for line in [ 85 | 'yes \t!', 86 | '#', 87 | '\n', 88 | ]] 89 | META = [ 90 | dict(user=None, message=0, token=0, character=0, target='y'), 91 | dict(user=None, message=0, token=1, character=1, target='e'), 92 | dict(user=None, message=0, token=2, character=2, target='s'), 93 | dict(user=None, message=0, token=3, character=3, target=' '), 94 | dict(user=None, message=0, token=4, character=4, target='\t'), 95 | dict(user=None, message=0, token=5, character=5, target='!'), 96 | 97 | dict(user=None, message=1, token=0, character=0, target='#'), 98 | 99 | dict(user=None, message=2, token=0, character=0, target='\n'), 100 | ] 101 | STATS = dict( 102 | users=1, 103 | messages=3, 104 | tokens=8, 105 | characters=8, 106 | skipped=0, 107 | ) 108 | HUMAN_STATS = dict( 109 | users=1, 110 | messages_per_user=3.0, 111 | tokens_per_user=8.0, 112 | characters_per_token=1.0, 113 | skipped=0.0, 114 | ) 115 | 116 | 117 | def test_simple_model_we(): 118 | # run 119 | 120 | results = list(lmc.we(eg_models.SimpleModel(), PlainData.DATA)) 121 | expect_meta(PlainData.META, results) 122 | 123 | expect_logp( 124 | [None if p is None else math.log(p) 125 | for p in [0.5, 0.25, None, 0.25, 
0.25, None, 0.25]], 126 | results) 127 | 128 | # stats 129 | 130 | stats = lmc.stats.stats(results, human=False) 131 | expect_partial_match(PlainData.STATS, stats) 132 | expect_partial_match( 133 | dict(tokens=5, 134 | sum=(4 * -math.log(0.25) + -math.log(0.5))), 135 | stats['entropy']) 136 | 137 | human_stats = lmc.stats.stats(results, human=True) 138 | expect_partial_match(PlainData.HUMAN_STATS, human_stats) 139 | expect_partial_match( 140 | dict(hit=5/7, 141 | mean=(4/5 * -math.log(0.25) + 1/5 * -math.log(0.5))), 142 | human_stats['entropy']) 143 | 144 | 145 | def test_simple_model_ce(): 146 | # run 147 | 148 | results = list(lmc.ce(eg_models.SimpleCharModel(), PlainCharData.DATA)) 149 | expect_meta(PlainCharData.META, results) 150 | expect_logp( 151 | [None if p is None else math.log(p) 152 | for p in [None, 0.5, 0.25, None, 0.125, None, 153 | None, 154 | 0.125]], 155 | results) 156 | 157 | # stats 158 | 159 | stats = lmc.stats.stats(results, human=False) 160 | expect_partial_match(PlainCharData.STATS, stats) 161 | expect_partial_match( 162 | dict(tokens=4, 163 | sum=(-math.log(0.5) + -math.log(0.25) + 2 * -math.log(0.125))), 164 | stats['entropy']) 165 | 166 | human_stats = lmc.stats.stats(results, human=True) 167 | expect_partial_match(PlainCharData.HUMAN_STATS, human_stats) 168 | expect_partial_match( 169 | dict(hit=0.5, 170 | mean=(1/4 * -math.log(0.5) + 171 | 1/4 * -math.log(0.25) + 172 | 1/2 * -math.log(0.125))), 173 | human_stats['entropy']) 174 | 175 | 176 | def test_simple_model_wc(): 177 | # run 178 | 179 | results = list(lmc.wc(eg_models.SimpleModel(), PlainData.DATA)) 180 | expect_meta(PlainData.META, results) 181 | 182 | w = ['the', 'a', 'cat'] 183 | expect_completions( 184 | [[w, w, w], [w, w, w], [w, w, w], [w], [w, w, w], 185 | [w, w, w], [w, w, w]], 186 | results) 187 | 188 | # stats 189 | 190 | stats = lmc.stats.stats(results, human=False) 191 | expect_partial_match(PlainData.STATS, stats) 192 | expect_partial_match( 193 | dict( 194 | hit1=1, # the 195 | hit3=5, # the, a, (3*)cat 196 | hit10=5, 197 | hit20=5, 198 | hit=5, 199 | srr=(1.0 + 1/2 + 3 * 1/3), 200 | ), stats['prediction']) 201 | expect_partial_match( 202 | dict( 203 | tokens=2, # the, a 204 | characters=4, 205 | ), stats['completion']) 206 | 207 | human_stats = lmc.stats.stats(results, human=True) 208 | expect_partial_match(PlainData.HUMAN_STATS, human_stats) 209 | expect_partial_match( 210 | dict( 211 | hit1=1/7, # the 212 | hit3=5/7, # the, a, (3*)cat 213 | hit10=5/7, 214 | hit20=5/7, 215 | hit=5/7, 216 | mrr=(1.0 + 1/2 + 3 * 1/3) / 7, 217 | ), human_stats['prediction']) 218 | expect_partial_match( 219 | dict( 220 | tokens=2/7, # the, a 221 | characters=4/19, 222 | ), human_stats['completion']) 223 | 224 | 225 | def test_simple_model_wr(): 226 | # run 227 | 228 | results = list(lmc.wr( 229 | eg_models.SimpleModel(), PlainData.DATA, 230 | ['the', 'cat', 'fat'])) 231 | expect_meta(PlainData.META, results) 232 | 233 | result_the = ('the', math.log(0.5)) 234 | result_cat = ('cat', math.log(0.25)) 235 | result_a = ('a', math.log(0.25)) 236 | result_ate = ('ate', None) 237 | result_The = ('The', None) 238 | result_fat = ('fat', None) 239 | expect_results( 240 | [[result_the, result_cat, result_fat], 241 | [result_cat, result_the, result_fat], 242 | [result_ate, result_the, result_cat, result_fat], 243 | [result_a], 244 | [result_cat, result_the, result_fat], 245 | [result_The, result_the, result_cat, result_fat], 246 | [result_cat, result_the, result_fat]], 247 | results) 248 | 249 | # stats 250 | # - it is 
hard to disentangle the effects of randomness here, so we're 251 | # not checking any actual statistics 252 | expect_partial_match( 253 | PlainData.STATS, lmc.stats.stats(results, human=False)) 254 | expect_partial_match( 255 | PlainData.HUMAN_STATS, lmc.stats.stats(results, human=True)) 256 | -------------------------------------------------------------------------------- /lmchallenge/tests/test_grep.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | from .. import grep 5 | import emoji 6 | 7 | 8 | def test_keep(): 9 | a = dict(user=None, message=0, target='a') 10 | b = dict(user=None, message=0, target='b', select=False) 11 | c = dict(user=None, message=1, target='c', select=False) 12 | d = dict(user=None, message=2, target='d', select=True) 13 | e = dict(user='z', message=2, target='e', select=False) 14 | assert list(grep.Keep.all([a, b, c, d, e])) == [a, b, c, d, e] 15 | assert list(grep.Keep.message([a, b, c, d, e])) == [a, b, d] 16 | assert list(grep.Keep.token([a, b, c, d, e])) == [a, d] 17 | 18 | 19 | def test_parse_pattern_target(): 20 | # TODO: parameterize 21 | for pattern, target, expected in [ 22 | # 1. Target pattern 23 | ('[gh].', 'great', True), 24 | ('[gh].', 'have', True), 25 | ('[gh].', 'leg', False), 26 | ('[gh].', 'yes', False), 27 | ('foo', 'foo', True), 28 | ('foo', 'foobar', True), 29 | ('^foo$', 'foo', True), 30 | ('^foo$', 'foobar', False), 31 | # 2. Negated target pattern 32 | ('~aa', 'have', True), 33 | ('~aa', 'abba', True), 34 | ('~aa', 'aa', False), 35 | ('~aa', 'aardvark', False), 36 | # 3. Predefined target pattern 37 | ('$nospace', 'abc-def', True), 38 | ('$nospace', 'abc\tdef', False), 39 | ('$alpha', 'abc:-0', True), 40 | ('$alpha', ':-0', False), 41 | ('$alphaonly', 'abc', True), 42 | ('$alphaonly', 'abc:-0', False), 43 | ('$emoji', 'a😊', True), 44 | ('$emoji', 'a:-)', False), 45 | ('$alphaemoji', '😊', True), 46 | ('$alphaemoji', 'a-)', True), 47 | ('$alphaemoji', ':-)', False), 48 | ]: 49 | assert grep.parse_pattern(pattern)(dict(target=target)) == expected 50 | 51 | 52 | def test_parse_pattern_other(): 53 | user_alpha = grep.parse_pattern('$user=Alpha') 54 | assert user_alpha(dict(user='Alpha', target='foo')) 55 | assert not user_alpha(dict(user='alpha', target='foo')) 56 | assert not user_alpha(dict(user='AlphaMan', target='foo')) 57 | assert not user_alpha(dict(user='ManAlpha', target='foo')) 58 | 59 | message_0123 = grep.parse_pattern('$message=[0123]') 60 | assert message_0123(dict(message=2, target='foo')) 61 | assert not message_0123(dict(message=4, target='foo')) 62 | assert not message_0123(dict(message=10, target='foo')) 63 | 64 | token_90 = grep.parse_pattern('$token=90') 65 | assert token_90(dict(token=90, target='foo')) 66 | assert not token_90(dict(token=91, target='foo')) 67 | assert not token_90(dict(token=9090, target='foo')) 68 | 69 | character_1x = grep.parse_pattern('$character=1.') 70 | assert character_1x(dict(character=18, target='foo')) 71 | assert not character_1x(dict(character=21, target='foo')) 72 | 73 | 74 | def test_parse_pattern_emoji(): 75 | emoji_pred = grep.parse_pattern('$emoji') 76 | alpha_emoji_pred = grep.parse_pattern('$alphaemoji') 77 | for emo in emoji.UNICODE_EMOJI.keys(): 78 | assert emoji_pred(dict(target=emo)) 79 | assert alpha_emoji_pred(dict(target=emo)) 80 | 81 | 82 | def test_parse_patterns_all(): 83 | pred = grep.parse_patterns_all( 84 | 'p', 85 | '$nospace', 86 
| '$user=[aA].+') 87 | 88 | assert pred(dict(user='andy', target='open-sesame')) 89 | # no "p" 90 | assert not pred(dict(user='andy', target='ouvret-sesame')) 91 | # contains space 92 | assert not pred(dict(user='andy', target='open sesame')) 93 | # user doesn't start "a" 94 | assert not pred(dict(user='bandy', target='open-sesame')) 95 | -------------------------------------------------------------------------------- /lmchallenge/tests/test_pretty.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | from .. import pretty 5 | import json 6 | import math 7 | import pytest 8 | 9 | 10 | @pytest.mark.slow 11 | def test_get_viewer_files(): 12 | # Mainly just check we don't crash! 13 | for name, data in pretty._get_viewer_files().items(): 14 | assert isinstance(data, str) 15 | assert len(data) 16 | 17 | 18 | def test_json_dumps_min(): 19 | for document in [ 20 | None, 21 | "text", 22 | r'\escape\ntext\t\r\"', 23 | 10000, 24 | [], 25 | [123], 26 | dict(abc=1, d=None, e=12.3, g=[4, 5, 0.00000727]), 27 | ]: 28 | assert json.loads( 29 | pretty._json_dumps_min(document, float_format='.3g')) \ 30 | == document 31 | 32 | assert pretty._json_dumps_min(math.pi, '.1f') == '3.1' 33 | assert pretty._json_dumps_min(math.pi, '.3g') == '3.14' 34 | # Tuples are written as JSON lists 35 | assert pretty._json_dumps_min(('abc', 123)) == '["abc",123]' 36 | -------------------------------------------------------------------------------- /lmchallenge/tests/test_validate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | from .. import validate 5 | import jsonschema 6 | 7 | 8 | def test_validate_schema(): 9 | jsonschema.Draft4Validator.check_schema(validate.schema()) 10 | -------------------------------------------------------------------------------- /lmchallenge/validate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | '''Check that logs conform to the correct LM Challenge schema (expressed as a 5 | JSON schema - see [http://json-schema.org](http://json-schema.org)). 6 | 7 | This can be used to check that a method of generating log files other than 8 | `lmchallenge.run` (e.g. using your own parallel execution framework) is 9 | compatible with the LM Challenge analysis tools. 10 | 11 | N.B. This validator can only check the format of the log, not the _fairness_ of 12 | the log. An example of an unfair log is a word entropy log where the total 13 | probability over the specified vocabulary for a given context is not equal to 14 | one. 15 | ''' 16 | 17 | import click 18 | import json 19 | import jsonschema 20 | import os 21 | from .core import common 22 | 23 | 24 | def schema(): 25 | '''Returns the log instance schema as a Python object. 26 | 27 | (Loaded from the schema definition file within the `lmchallenge` package.) 28 | ''' 29 | with open(os.path.join(os.path.dirname(__file__), 'log.schema')) as f: 30 | return json.load(f) 31 | 32 | 33 | def validate(data): 34 | '''Check that a loaded log conforms to the schema, using `jsonschema`.
35 | 36 | `data` -- iterable of log events, each of which should conform to 37 | `lmchallenge.validate.schema` 38 | 39 | `raises` -- `jsonschema.exceptions.ValidationError` if the log does 40 | not conform 41 | ''' 42 | log_schema = schema() 43 | for datum in data: 44 | jsonschema.validate(datum, log_schema) 45 | 46 | 47 | @click.command() 48 | @click.argument('log', nargs=-1, type=click.Path(exists=True, dir_okay=False)) 49 | @click.option('-v', '--verbose', default=0, count=True, 50 | help='How much human-readable detail to print to STDERR.') 51 | def cli(log, verbose): 52 | '''Validate a log file against the standard LM Challenge schema. 53 | ''' 54 | common.verbosity(verbose) 55 | 56 | log = log or ['-'] 57 | 58 | for single_log in log: 59 | validate(common.load_jsonlines(single_log)) 60 | 61 | 62 | __doc__ += common.shell_docstring(cli, 'lmc validate') 63 | if __name__ == '__main__': 64 | cli() 65 | -------------------------------------------------------------------------------- /requirements-base.txt: -------------------------------------------------------------------------------- 1 | click 2 | emoji 3 | jsonschema 4 | regex 5 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | -r requirements-base.txt 2 | flake8 3 | pdoc 4 | pytest 5 | pygments 6 | scipy 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | click==6.7 2 | emoji==0.4.5 3 | jsonschema==2.6.0 4 | regex==2017.12.12 5 | -------------------------------------------------------------------------------- /sample/.gitignore: -------------------------------------------------------------------------------- 1 | /data 2 | /data-small 3 | -------------------------------------------------------------------------------- /sample/ngram.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | import collections 5 | import argparse 6 | import sys 7 | import math 8 | import contextlib 9 | import dbm 10 | import tempfile 11 | import struct 12 | import itertools as it 13 | import lmchallenge as lmc 14 | 15 | 16 | class Context: 17 | '''Helper object to store a context's record of children & total 18 | count. 19 | ''' 20 | __slots__ = ('count', 'children', '_child_order') 21 | 22 | def __init__(self): 23 | self.count = 0 24 | self.children = dict() 25 | self._child_order = None 26 | 27 | def set(self, child, count): 28 | if child in self.children: 29 | raise ValueError('duplicate child {}'.format(child)) 30 | self.count += count 31 | self.children[child] = count 32 | self._child_order = None 33 | 34 | def probability(self, child): 35 | return self.children.get(child, 0) / self.count 36 | 37 | def predictions(self): 38 | # Lazily compute & cache the order of children 39 | if self._child_order is None: 40 | self._child_order = sorted( 41 | self.children, key=lambda k: -self.children[k]) 42 | return self._child_order 43 | 44 | 45 | class BackoffContext: 46 | '''Helper object to represent a backed off context lookup. 
47 | ''' 48 | __slots__ = ('_context_and_weight', '_total_weight') 49 | 50 | def __init__(self, context_and_weight): 51 | self._context_and_weight = [ 52 | (c, w) 53 | for c, w in context_and_weight 54 | if c is not None 55 | ] 56 | self._total_weight = sum(w for c, w in self._context_and_weight) 57 | 58 | def log_probability(self, child): 59 | '''Compute the backed off log-probability of a single child 60 | in this backed off context. 61 | 62 | child -- term to look up in this context 63 | 64 | returns -- `None` if the context or child was missing, 65 | float probability otherwise 66 | ''' 67 | sum_score = sum(c.probability(child) * w 68 | for c, w in self._context_and_weight) 69 | return (None 70 | if sum_score == 0 else 71 | math.log(sum_score / self._total_weight)) 72 | 73 | def predictions(self): 74 | '''Get next-token predictions from this context. 75 | 76 | returns -- a generator of string tokens 77 | ''' 78 | return (prediction 79 | for context, _ in self._context_and_weight[::-1] 80 | for prediction in context.predictions()) 81 | 82 | 83 | class NgramModel: 84 | '''An n-gram model of sequences of tokens, with interpolation backoff 85 | and approximate prediction (which isn't backed off). 86 | ''' 87 | def __init__(self, contexts, order_weights): 88 | '''Create an NgramModel. See `create()`. 89 | 90 | contexts -- dict[tuple(context..) -> Context()] 91 | 92 | order_weights -- list of weights [unigram, bigram, trigram, ...] 93 | ''' 94 | self._contexts = contexts 95 | self._order_weights = order_weights 96 | 97 | @classmethod 98 | def create(cls, ngrams, order_weights): 99 | '''Create the ngram sequence model from a flat sequence of 100 | (ngram, count) pairs. 101 | 102 | ngrams -- iterable of (ngram, count) pairs 103 | 104 | order_weights -- list of weights [unigram, bigram, trigram, ...] 105 | (if too short, the last weight is used for 106 | every higher order) 107 | ''' 108 | contexts = dict() 109 | order = 1 110 | for (*context, head), count in ngrams: 111 | context = tuple(context) 112 | if context not in contexts: 113 | contexts[context] = Context() 114 | contexts[context].set(head, count) 115 | order = max(order, len(context) + 1) 116 | 117 | order_weights = list(it.islice( 118 | it.chain(order_weights, it.repeat(order_weights[-1])), 119 | order 120 | )) 121 | return cls(contexts, order_weights) 122 | 123 | def lookup(self, context): 124 | '''Lookup a context and return a BackoffContext instance, which 125 | can be used to score candidates, or enumerate predictions. 126 | 127 | context -- sequence of tokens (does not need to be padded) 128 | ''' 129 | # add padding to the start 130 | context = tuple( 131 | it.repeat('\x1d', len(self._order_weights) - 1) 132 | ) + tuple(context) 133 | 134 | return BackoffContext([ 135 | (self._contexts.get(() if n == 0 else context[-n:]), 136 | weight) 137 | for n, weight in enumerate(self._order_weights) 138 | ]) 139 | 140 | 141 | class WordModel(lmc.FilteringWordModel): 142 | '''A simple ngram word model. 143 | ''' 144 | def __init__(self, ngrams, order_weights, n_predictions): 145 | '''Create the word model from a flat sequence of 146 | (ngram, count) pairs. 147 | 148 | ngrams -- iterable of (ngram, count) pairs 149 | 150 | order_weights -- list of weights for unigram, bigram, etc. 
151 | (if too short, the last weight is used for 152 | every higher order) 153 | ''' 154 | super().__init__(n_predictions=n_predictions) 155 | self._model = NgramModel.create(ngrams, order_weights) 156 | 157 | def predict_word_iter(self, context): 158 | backoff = self._model.lookup(context) 159 | # Don't bother computing "proper" scores 160 | # (backoff.log_probability(word)) for performance reasons 161 | # - as there is no need in this case, so just create fake 162 | # scores (-rank) 163 | return ((word, -n) 164 | for n, word in enumerate(backoff.predictions())) 165 | 166 | def score_word(self, context, candidates): 167 | backoff = self._model.lookup(context) 168 | return [(candidate, backoff.log_probability(candidate)) 169 | for candidate in candidates] 170 | 171 | 172 | class CharacterModel(lmc.Model): 173 | '''A simple ngram character model 174 | (only supporting scoring, not prediction). 175 | ''' 176 | def __init__(self, ngrams, order_weights): 177 | '''Create the character model from a flat sequence of 178 | (ngram, count) pairs. 179 | 180 | ngrams -- iterable of (ngram, count) pairs 181 | 182 | order_weights -- list of weights for unigram, bigram, etc. 183 | (if too short, the last weight is used for 184 | every higher order) 185 | ''' 186 | self._model = NgramModel.create(ngrams, order_weights) 187 | 188 | def predict(self, context, candidates): 189 | backoff = self._model.lookup(context) 190 | return [(candidate, backoff.log_probability(candidate)) 191 | for candidate in (candidates or [])] 192 | 193 | 194 | def parse_ngram(line): 195 | '''Parse a string-encoded ngram. 196 | 197 | line -- string -- e.g. "aaa\x1ebbb\x1e777\n" 198 | 199 | returns -- (ngram, count) pair -- e.g. (("aaa", "bbb"), 777) 200 | ''' 201 | *ngram, count = line.rstrip('\n').split('\x1e') 202 | return tuple(ngram), int(count) 203 | 204 | 205 | class DictCounter: 206 | '''A simple, memory-hungry counter, backed by a Python dictionary. 207 | ''' 208 | def __init__(self): 209 | self._d = collections.defaultdict(int) 210 | 211 | @classmethod 212 | @contextlib.contextmanager 213 | def open(cls): 214 | yield cls() 215 | 216 | def increment(self, key): 217 | self._d[key] += 1 218 | 219 | def items(self): 220 | return self._d.items() 221 | 222 | 223 | class DbmCounter: 224 | '''A slow counter backed by a database. 225 | ''' 226 | FORMAT = '>I' 227 | 228 | def __init__(self, db): 229 | self._db = db 230 | 231 | @classmethod 232 | @contextlib.contextmanager 233 | def open(cls): 234 | with tempfile.NamedTemporaryFile() as f: 235 | with dbm.open(f.name, 'n') as db: 236 | yield cls(db) 237 | 238 | def increment(self, key): 239 | key = key.encode('utf8') 240 | count = self._db.get(key) 241 | count = (1 242 | if count is None else 243 | struct.unpack(self.FORMAT, count)[0] + 1) 244 | self._db[key] = struct.pack(self.FORMAT, count) 245 | 246 | def items(self): 247 | key = self._db.firstkey() 248 | while key is not None: 249 | yield (key.decode('utf8'), 250 | struct.unpack(self.FORMAT, self._db[key])[0]) 251 | key = self._db.nextkey(key) 252 | 253 | 254 | def sequence(lines, order, counter=None): 255 | '''"Sequence up" the input lines into ngrams of the order "order". 
256 | 257 | lines -- an iterable of lists of tokens 258 | 259 | order -- int 260 | 261 | returns -- an iterable of (ngram, count) pairs, where ngram is a 262 | single string of tokens separated by the ASCII record separator (RS) \x1E; 263 | note that the start-of-sequence is padded with (order-1) 264 | ASCII group separator (GS) \x1D characters 265 | ''' 266 | if counter is None: 267 | counter = DictCounter() 268 | pad = list(it.repeat('\x1d', order - 1)) 269 | for line in lines: 270 | line = pad + line 271 | for n in range(order - 1, len(line)): 272 | for i in range(order): 273 | counter.increment('\x1e'.join(line[(n - i):(n + 1)])) 274 | return counter.items() 275 | 276 | 277 | # Command line wrappers 278 | 279 | def sequence_cli(order, disk, tokenizer): 280 | '''Command line version of `sequence`, applying a tokenizer regex, 281 | between stdin & stdout. 282 | ''' 283 | lines = ([m.group(0) for m in tokenizer.finditer(line.rstrip('\n'))] 284 | for line in sys.stdin) 285 | with (DbmCounter if disk else DictCounter).open() as counter: 286 | for ngram, count in sequence(lines, order, counter): 287 | sys.stdout.write('{}\x1e{}\n'.format(ngram, count)) 288 | 289 | 290 | def sequence_words_cli(order, disk): 291 | '''Command line version of `sequence` for words, between stdin & stdout. 292 | ''' 293 | sequence_cli(order, disk, lmc.core.common.WORD_TOKENIZER) 294 | 295 | 296 | def sequence_chars_cli(order, disk): 297 | '''Command line version of `sequence` for characters, between stdin & stdout. 298 | ''' 299 | sequence_cli(order, disk, lmc.core.common.CHARACTER_TOKENIZER) 300 | 301 | 302 | def prune_cli(count): 303 | '''Command line for pruning ngrams that are below a minimum count. 304 | ''' 305 | for line in sys.stdin: 306 | if count <= int(line.rstrip('\n').split('\x1e')[-1]): 307 | sys.stdout.write(line) 308 | 309 | 310 | def words_cli(ngrams, weights, n_predictions): 311 | '''Start a word model prediction loop. 312 | ''' 313 | with open(ngrams) as f: 314 | WordModel(map(parse_ngram, f), weights, n_predictions).run_loop() 315 | 316 | 317 | def chars_cli(ngrams, weights): 318 | '''Start a character model prediction loop.
319 | ''' 320 | with open(ngrams) as f: 321 | CharacterModel(map(parse_ngram, f), weights).run_loop() 322 | 323 | 324 | # Command line 325 | 326 | if __name__ == '__main__': 327 | parser = argparse.ArgumentParser( 328 | description='Example ngram language model' 329 | ) 330 | subparsers = parser.add_subparsers() 331 | 332 | s = subparsers.add_parser('sequence-words', help='sequence up word ngrams') 333 | s.add_argument('order', type=int, help='order to sequence up to') 334 | s.add_argument('-d', '--disk', action='store_true', 335 | help='use a slow on-disk sequencer') 336 | s.set_defaults(execute=sequence_words_cli) 337 | 338 | s = subparsers.add_parser('sequence-chars', 339 | help='sequence up character ngrams') 340 | s.add_argument('order', type=int, help='order to sequence up to') 341 | s.add_argument('-d', '--disk', action='store_true', 342 | help='use a slow on-disk sequencer') 343 | s.set_defaults(execute=sequence_chars_cli) 344 | 345 | s = subparsers.add_parser('prune', help='prune down ngrams') 346 | s.add_argument('count', type=int, help='minimum count to allow') 347 | s.set_defaults(execute=prune_cli) 348 | 349 | s = subparsers.add_parser('words', 350 | help='start a word model predictor') 351 | s.add_argument('ngrams', help='file path to ngrams dataset') 352 | s.add_argument('-n', '--n-predictions', default=100, type=int, 353 | help='number of predictions to return') 354 | s.add_argument('-w', '--weights', nargs='+', type=float, 355 | default=[1, 2, 2], 356 | help='weights to apply to each order of prediction' 357 | ' (starting with unigram)') 358 | s.set_defaults(execute=words_cli) 359 | 360 | s = subparsers.add_parser('chars', 361 | help='start a character model predictor') 362 | s.add_argument('ngrams', help='file path to ngrams dataset') 363 | s.add_argument('-w', '--weights', nargs='+', type=float, 364 | default=[1, 1, 10, 100, 1000], 365 | help='weights to apply to each order of prediction' 366 | ' (starting with unigram)') 367 | s.set_defaults(execute=chars_cli) 368 | 369 | args = vars(parser.parse_args()) 370 | args.pop('execute')(**args) 371 | -------------------------------------------------------------------------------- /sample/prepare.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | echo "# Downloading data, and building ngram models..." 5 | echo "# (N.B. this should be run from the sample/ directory)" 6 | echo 7 | 8 | DATA="data" 9 | 10 | mkdir -p ${DATA} 11 | 12 | echo "# Downloading data" 13 | wget --show-progress -qO "${DATA}/raw.zip" "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip" 14 | unzip -qj "${DATA}/raw.zip" -d "${DATA}" 15 | 16 | echo "# Creating correction vocab" 17 | cat ${DATA}/*.tokens \ 18 | | tr ' ' '\n' \ 19 | | sort \ 20 | | uniq -c \ 21 | | sort -nr \ 22 | | awk '{print $2}' \ 23 | > ${DATA}/vocab.txt 24 | 25 | echo "# Counting word ngrams" 26 | cat ${DATA}/wiki.train.tokens \ 27 | | env PYTHONPATH=.. python3 ngram.py sequence-words 3 \ 28 | | env PYTHONPATH=.. python3 ngram.py prune 3 \ 29 | | tee "${DATA}/words.3gram" \ 30 | | awk -F'\x1e' '{if (NF <= 3) print $0}' \ 31 | | tee "${DATA}/words.2gram" \ 32 | | awk -F'\x1e' '{if (NF == 2) print $0}' \ 33 | > "${DATA}/words.1gram" 34 | 35 | echo "# Counting character ngrams" 36 | echo "# (N.B. if this were serious, it should use the untokenized data)" 37 | cat ${DATA}/wiki.train.tokens \ 38 | | env PYTHONPATH=..
python3 ngram.py sequence-chars 5 \ 39 | | env PYTHONPATH=.. python3 ngram.py prune 3 \ 40 | > "${DATA}/chars.5gram" 41 | -------------------------------------------------------------------------------- /sample/prepare_big.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | echo "# Downloading data, and building ngram models..." 5 | echo "# (N.B. this should be run from the sample/ directory)" 6 | echo 7 | 8 | DATA="data-big" 9 | 10 | mkdir -p ${DATA} 11 | 12 | echo "# Downloading data" 13 | wget -O "${DATA}/raw.tar.gz" "http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz" 14 | tar -xf "${DATA}/raw.tar.gz" -C "${DATA}" --strip-components=1 15 | 16 | echo "# Creating test set" 17 | cat ${DATA}/heldout-monolingual.tokenized.shuffled/news.en.heldout-* \ 18 | > ${DATA}/test.txt 19 | 20 | echo "# Creating test vocab" 21 | cat ${DATA}/test.txt \ 22 | | tr ' ' '\n' \ 23 | | sort \ 24 | | uniq -c \ 25 | | sort -nr \ 26 | | awk '{print $2}' \ 27 | | head -n 100000 \ 28 | > ${DATA}/test.vocab.100k.txt 29 | 30 | echo "# Setting memory limit at 3 GB" 31 | ulimit -Sv 3000000 32 | 33 | echo "# Counting word ngrams" 34 | time cat ${DATA}/training-monolingual.tokenized.shuffled/news.en-0000* \ 35 | | env PYTHONPATH=.. python3 ngram.py sequence-words 3 --disk \ 36 | | env PYTHONPATH=.. python3 ngram.py prune 3 \ 37 | > "${DATA}/words.ngram" 38 | 39 | echo "# Counting character ngrams" 40 | time cat ${DATA}/training-monolingual.tokenized.shuffled/news.en-0000* \ 41 | | env PYTHONPATH=.. python3 ngram.py sequence-chars 5 \ 42 | | env PYTHONPATH=.. python3 ngram.py prune 3 \ 43 | > "${DATA}/chars.ngram" 44 | -------------------------------------------------------------------------------- /scripts/Dockerfile.base: -------------------------------------------------------------------------------- 1 | FROM python:3.5 2 | 3 | ENV LC_ALL=C.UTF-8 \ 4 | LANG=C.UTF-8 5 | 6 | ADD requirements-base.txt /tmp/lmc/ 7 | RUN cd /tmp/lmc \ 8 | && pip3 install --upgrade -r requirements-base.txt \ 9 | && rm -r /tmp/lmc 10 | -------------------------------------------------------------------------------- /scripts/Dockerfile.dev: -------------------------------------------------------------------------------- 1 | FROM python:3.5 2 | 3 | ENV LC_ALL=C.UTF-8 \ 4 | LANG=C.UTF-8 5 | 6 | ADD requirements-base.txt requirements-dev.txt /tmp/lmc/ 7 | RUN cd /tmp/lmc \ 8 | && pip3 install --upgrade -r requirements-dev.txt \ 9 | && rm -r /tmp/lmc 10 | -------------------------------------------------------------------------------- /scripts/Dockerfile.notebook: -------------------------------------------------------------------------------- 1 | FROM jupyter/scipy-notebook 2 | 3 | USER root 4 | ADD . /home/jovyan/lmchallenge 5 | RUN cd /home/jovyan/lmchallenge \ 6 | && python3 setup.py install \ 7 | && pip3 install -r requirements-dev.txt 8 | 9 | USER jovyan 10 | -------------------------------------------------------------------------------- /scripts/run: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Microsoft Corporation. All rights reserved. 4 | # Licensed under the MIT license.
5 | # 6 | 7 | import argparse 8 | import subprocess 9 | import string 10 | import os 11 | import sys 12 | import itertools as it 13 | 14 | 15 | def sh(cmd, **args): 16 | code = subprocess.call(string.Template(cmd).substitute(**args), 17 | shell=True) 18 | if code != 0: 19 | exit(code) 20 | 21 | 22 | def image_tag(image): 23 | return dict( 24 | prod='lmchallenge', 25 | base='lmchallenge-base', 26 | dev='lmchallenge-dev', 27 | notebook='lmchallenge-notebook' 28 | ).get(image, image) 29 | 30 | 31 | def build(args): 32 | '''Build or pull the Docker image, which may be: 33 | "prod" -- build the production image 34 | "base" -- build the base image 35 | "dev" -- build the development image 36 | "other/name" -- pull a remote image 37 | ''' 38 | if args.image in ['prod', 'base', 'dev', 'notebook']: 39 | sh('docker build --rm ${FLAGS} -t ${TAG} -f ${DOCKERFILE} .', 40 | TAG=image_tag(args.image), 41 | FLAGS='--no-cache' if args.no_cache else '', 42 | DOCKERFILE=( 43 | 'Dockerfile' 44 | if args.image == 'prod' else 45 | ('scripts/Dockerfile.' + args.image))) 46 | else: 47 | sh('docker pull %s' % args.image) 48 | 49 | 50 | def refreeze(args): 51 | '''Refreeze requirements.txt. 52 | ''' 53 | if args.image != 'base': 54 | sys.stderr.write( 55 | 'WARNING: freezing dependencies should generally be done' 56 | ' from the "base" image, not "{}"\n'.format(args.image)) 57 | sh('docker run --rm -i -v ${WORK}:/work -w /work' 58 | ' ${TAG} pip freeze > requirements.txt', 59 | WORK=os.getcwd(), 60 | TAG=image_tag(args.image)) 61 | 62 | 63 | def run(args): 64 | 'Run an arbitrary command.' 65 | sh('docker run --rm -i -v ${WORK}:/work -w /work' 66 | ' -v /var/run/docker.sock:/var/run/docker.sock' 67 | ' ${TAG} ${CMD}', 68 | WORK=os.getcwd(), 69 | TAG=image_tag(args.image), 70 | CMD=' '.join(args.command)) 71 | 72 | 73 | def test(args): 74 | 'Run unit & functional tests.' 75 | sh('docker run --rm -i -v ${WORK}:/work -w /work ${TAG}' 76 | ' pytest lmchallenge --doctest-modules ${SLOW} ${PYTEST_ARGS}', 77 | WORK=os.getcwd(), 78 | TAG=image_tag(args.image), 79 | SLOW='--run-slow' if args.slow else '', 80 | PYTEST_ARGS=' '.join(args.pytest_args)) 81 | 82 | 83 | def flake(args): 84 | 'Run style checker.' 85 | sh('docker run --rm -i -v ${WORK}:/work -w /work ${TAG}' 86 | ' flake8', 87 | WORK=os.getcwd(), 88 | TAG=image_tag(args.image)) 89 | 90 | 91 | COPYRIGHT_NOTICE = [ 92 | '# Copyright (c) Microsoft Corporation. All rights reserved.', 93 | '# Licensed under the MIT license.' 94 | ] 95 | 96 | 97 | def check_copyright(path): 98 | # Strict verbatim test - the file must start with these lines 99 | with open(path) as f: 100 | lines = list(l.rstrip('\n') for l in it.islice(f, 2)) 101 | if lines != COPYRIGHT_NOTICE: 102 | sys.stderr.write( 103 | ('Error! Bad copyright notice in {}:' 104 | ' {}\n').format( 105 | path, 106 | '\\n'.join(lines) 107 | )) 108 | return lines == COPYRIGHT_NOTICE 109 | 110 | 111 | def copyright(args): 112 | 'Check for copyright headers.' 113 | errors = sum( 114 | 1 115 | for base in ['lmchallenge', 'sample'] 116 | for root, _, files in os.walk(base) 117 | for name in files 118 | if os.path.splitext(name)[-1] == '.py' 119 | if not check_copyright(os.path.join(root, name))) 120 | if errors: 121 | exit(errors) 122 | 123 | 124 | def check(args): 125 | 'Run tests & static analysis.' 126 | test(args) 127 | flake(args) 128 | copyright(args) 129 | sys.stderr.write('*** All Checks passed ***\n') 130 | 131 | 132 | def doc(args): 133 | 'Generate documentation.' 
134 | OUT = 'site/' 135 | sh('docker run --rm -i -v ${WORK}:/work -w /work ${TAG}' 136 | ' env PYTHONPATH=.' 137 | ' pdoc --overwrite --html --html-dir ${OUT} lmchallenge', 138 | WORK=os.getcwd(), 139 | TAG=image_tag(args.image), 140 | OUT=OUT) 141 | sys.stderr.write('Documentation generated: {}\n'.format(OUT)) 142 | 143 | 144 | def notebook(args): 145 | 'Start a notebook server with LMC loaded.' 146 | if args.restart: 147 | # Ignore failures e.g. not running already 148 | sh('docker rm -f ${NAME} || true', NAME=args.name) 149 | 150 | # IPython.lib.passwd(passphrase='lmc') 151 | password = 'sha1:d9c0350bce66:7ab513830b9a6688ed423c65f486c8d08b13718c' 152 | 153 | sh('docker run -d -v ${WORK}:/work -w /work' 154 | ' --name ${NAME} -p ${PORT}:${PORT} ${TAG}' 155 | ' sh -c "jupyter notebook --allow-root' 156 | ' --port ${PORT} --ip \'*\'' 157 | ' --NotebookApp.password=\'${PASSWORD}\'"', 158 | WORK=os.getcwd(), 159 | NAME=args.name, 160 | PASSWORD=password, 161 | PORT=args.port, 162 | TAG=image_tag(args.image)) 163 | 164 | sys.stderr.write(' Server: http://localhost:{}\n'.format(args.port)) 165 | sys.stderr.write('Password: lmc\n') 166 | 167 | 168 | parser = argparse.ArgumentParser( 169 | description='Builder, runner, tester for lmchallenge development', 170 | ) 171 | parser.add_argument( 172 | '-i', '--image', 173 | help='Which image to use:' 174 | ' "base" | "prod" | "dev" | "notebook" | "other/name"', 175 | type=str, 176 | default='dev' 177 | ) 178 | parser.set_defaults(action=lambda args: parser.print_help()) 179 | subparsers = parser.add_subparsers() 180 | 181 | sparser = subparsers.add_parser('build', help=build.__doc__) 182 | sparser.add_argument('--no-cache', action='store_true', 183 | help='rebuild the image from scratch') 184 | sparser.set_defaults(action=build) 185 | 186 | sparser = subparsers.add_parser('refreeze', help=refreeze.__doc__) 187 | sparser.set_defaults(action=refreeze) 188 | 189 | sparser = subparsers.add_parser('run', help=run.__doc__) 190 | sparser.add_argument('command', nargs='*') 191 | sparser.set_defaults(action=run) 192 | 193 | sparser = subparsers.add_parser('test', help=test.__doc__) 194 | sparser.add_argument('--slow', action='store_true', help='run slow tests') 195 | sparser.add_argument('pytest_args', nargs='*', default=[]) 196 | sparser.set_defaults(action=test) 197 | 198 | sparser = subparsers.add_parser('flake', help=flake.__doc__) 199 | sparser.set_defaults(action=flake) 200 | 201 | sparser = subparsers.add_parser('check', help=check.__doc__) 202 | sparser.add_argument('--no-slow', dest='slow', action='store_false', 203 | help='don\'t run slow tests') 204 | sparser.add_argument('pytest_args', nargs='*', default=[]) 205 | sparser.set_defaults(action=check) 206 | 207 | sparser = subparsers.add_parser('doc', help=doc.__doc__) 208 | sparser.set_defaults(action=doc) 209 | 210 | sparser = subparsers.add_parser('copyright', help=copyright.__doc__) 211 | sparser.set_defaults(action=copyright) 212 | 213 | sparser = subparsers.add_parser('notebook', help=notebook.__doc__) 214 | sparser.add_argument('-n', '--name', 215 | help='Notebook container name', 216 | default='lmchallenge-notebook') 217 | sparser.add_argument('-r', '--restart', 218 | help='Stop & restart the notebook', 219 | action='store_true') 220 | sparser.add_argument('-p', '--port', help='Which port to use', 221 | default=8888) 222 | sparser.set_defaults(action=notebook) 223 | 224 | args = parser.parse_args() 225 | args.action(args) 226 | 
-------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude=build, 3 | doc, 4 | local, 5 | .eggs 6 | 7 | [metadata] 8 | description-file=README.md 9 | 10 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | with open('version.txt', 'r') as v: 4 | version = v.readline().strip() 5 | 6 | with open('README.md', 'r') as r: 7 | readme = r.read() 8 | 9 | with open('requirements.txt', 'r') as r: 10 | requirements = list(x.strip() for x in r) 11 | 12 | setup(name='lmchallenge', 13 | version=version, 14 | description='LM Challenge' 15 | ' - A library & tools to evaluate predictive language models.', 16 | long_description=readme, 17 | long_description_content_type="text/markdown", 18 | 19 | url='https://github.com/Microsoft/LMChallenge', 20 | author='Microsoft Corporation', 21 | author_email='swiftkey-deep@service.microsoft.com', 22 | license='MIT', 23 | 24 | packages=['lmchallenge', 'lmchallenge.core'], 25 | include_package_data=True, 26 | install_requires=requirements, 27 | entry_points=''' 28 | [console_scripts] 29 | lmc=lmchallenge:cli 30 | 31 | lmc-diff=lmchallenge.diff:cli 32 | lmc-grep=lmchallenge.grep:cli 33 | lmc-pretty=lmchallenge.pretty:cli 34 | lmc-run=lmchallenge.run:cli 35 | lmc-stats=lmchallenge.stats:cli 36 | lmc-validate=lmchallenge.validate:cli 37 | ''') 38 | -------------------------------------------------------------------------------- /version.txt: -------------------------------------------------------------------------------- 1 | 5.2 2 | --------------------------------------------------------------------------------