├── .github └── workflows │ └── tox.yml ├── .gitignore ├── .readthedocs.yaml ├── CHANGELOG.md ├── LICENSE ├── LICENSE_nltk.txt ├── LICENSE_pytorch.txt ├── README.md ├── docs ├── Makefile ├── cli_helper.py ├── make.bat ├── requirements.txt └── source │ ├── api │ └── pydrobert │ │ ├── torch.rst │ │ └── torch │ │ ├── argcheck.rst │ │ ├── config.rst │ │ ├── data.rst │ │ ├── distributions.rst │ │ ├── estimators.rst │ │ ├── functional.rst │ │ ├── lightning.rst │ │ ├── modules.rst │ │ └── training.rst │ ├── cli.rst │ ├── conf.py │ ├── index.rst │ ├── references.rst │ └── tutorials │ ├── advanced-attn.rst │ └── lm.rst ├── environment.yml ├── pyproject.toml ├── pytest.ini ├── setup.cfg ├── src └── pydrobert │ └── torch │ ├── __init__.py │ ├── _attn.py │ ├── _combinatorics.py │ ├── _compat.py │ ├── _dataloaders.py │ ├── _datasets.py │ ├── _decoding.py │ ├── _enumerate_estimator.py │ ├── _estimators.py │ ├── _feats.py │ ├── _img.py │ ├── _lm.py │ ├── _mc.py │ ├── _pad.py │ ├── _parsing.py │ ├── _pl_data.py │ ├── _rl.py │ ├── _straight_through.py │ ├── _string.py │ ├── _textgrid.py │ ├── _wrappers.py │ ├── argcheck.py │ ├── command_line.py │ ├── config.py │ ├── data.py │ ├── distributions.py │ ├── estimators.py │ ├── functional.py │ ├── layers.py │ ├── lightning.py │ ├── modules.py │ ├── training.py │ └── util.py ├── tests ├── conftest.py ├── dense_image_warp │ ├── flow.npy │ ├── img.npy │ └── warped.npy ├── polyharmonic_spline │ ├── o1.npy │ ├── o2.npy │ ├── o3.npy │ ├── q.npy │ ├── x.npy │ └── y.npy ├── republic │ ├── README │ ├── exp.txt │ ├── queries.txt │ ├── republic.arpa │ ├── token2id.map │ └── vocab.txt ├── sclite │ ├── README │ ├── hyp.trn │ ├── per_utt.txt │ ├── ref.trn │ ├── sclite_out.txt │ ├── token2id.txt │ └── total.txt ├── sparse_image_warp │ ├── dst.npy │ ├── flow_0.npy │ ├── flow_2.npy │ ├── img.npy │ ├── src.npy │ ├── warped_0.npy │ └── warped_2.npy ├── test_argcheck.py ├── test_attn.py ├── test_combinatorics.py ├── test_command_line.py ├── test_dataloaders.py ├── test_datasets.py ├── test_decoding.py ├── test_enumerate_estimator.py ├── test_feats.py ├── test_img.py ├── test_lm.py ├── test_mc.py ├── test_metadata.py ├── test_pad.py ├── test_parsing.py ├── test_pl_data.py ├── test_rl.py ├── test_straight_through.py ├── test_string.py └── test_training.py └── tox.ini /.github/workflows/tox.yml: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/tox-dev/tox-gh 2 | 3 | name: tox 4 | on: 5 | push: 6 | branches-ignore: 7 | - docs 8 | pull_request: 9 | schedule: 10 | # run every monday @ 8am 11 | - cron: "0 8 * * 1" 12 | 13 | concurrency: 14 | group: tox-${{ github.ref }} 15 | cancel-in-progress: true 16 | 17 | jobs: 18 | test: 19 | name: Python ${{ matrix.py }} 20 | runs-on: ubuntu-latest 21 | timeout-minutes: 15 22 | strategy: 23 | fail-fast: false 24 | matrix: 25 | py: 26 | - "3.11" 27 | - "3.10" 28 | - "3.9" 29 | - "3.8" 30 | - "3.7" 31 | steps: 32 | - uses: actions/checkout@v3 33 | with: 34 | fetch-depth: 0 35 | - name: Setup python ${{ matrix.py }} 36 | uses: actions/setup-python@v4 37 | with: 38 | python-version: ${{ matrix.py }} 39 | - name: Install tox 40 | run: python -m pip install tox-gh>=1.2 41 | - name: Setup test suite 42 | run: tox -vv --notest 43 | - name: Run test suite 44 | run: tox --skip-pkg-install 45 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / 
optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | .ftpignore
107 | .ftpconfig
108 | .vscode
109 | _version.py
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yaml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | # Set the version of Python and other tools you might need
9 | build:
10 |   os: ubuntu-22.04
11 |   tools:
12 |     python: "3.10"
13 |   # You can also specify other tool versions:
14 |   # nodejs: "19"
15 |   # rust: "1.64"
16 |   # golang: "1.19"
17 |
18 | # Build documentation in the docs/ directory with Sphinx
19 | sphinx:
20 |   builder: html
21 |
22 | # If using Sphinx, optionally build your docs in additional formats such as PDF
23 | # formats:
24 | #   - pdf
25 |
26 | # Optionally declare the Python requirements required to build your docs
27 | python:
28 |   install:
29 |     - requirements: docs/requirements.txt
30 |     - method: pip
31 |       path: .
32 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Change log
2 |
3 | ## v0.4.0
4 |
5 | - Support `param` version 2
6 | - Added `pydrobert.torch.argcheck` and standardized module argument checking a
7 |   bit.
8 | - Added `environment.yml` for local dev.
9 | - `LookupLanguageModel` has been refactored and reimplemented. It is no longer
10 |   compatible with previous versions of the model.
11 | - `parse_arpa` has been enhanced: it can handle log-probs in scientific
12 |   notation (e.g. `1e4`); conversion from (implicitly) log-base-10 to log-base-e
13 |   probabilities; and storing log probabilities as NumPy floating-point types.
14 | - `LookupLanguageModel` and `parse_arpa` now have an optional logger argument
15 |   to log progress building the trie and parsing the file, respectively.
16 | - Added a `best_is_train` flag to `TrainingController.update_for_epoch`
17 | - Refactored `get-torch-spect-data-dir-info` to be faster
18 | - `subset-torch-spect-data-dir` command has been added
19 | - `print-torch-{ali,ref}-data-dir-length-moments` commands have been added
20 | - `LookupLanguageModel.prob_list` has been renamed to `prob_dicts`
21 | - Added `ShallowFusionLanguageModel`, `ExtractableShallowFusionLanguageModel`,
22 |   and `MixableShallowFusionLanguageModel`
23 | - Added the slicing and chunking modules `SliceSpectData`, `ChunkBySlices`, and
24 |   `ChunkTokenSequencesBySlices`, as well as the command
25 |   `chunk-torch-spect-data-dir`, which puts them together.
26 | - Added code for handling TextGrid files, including the functions
27 |   `read_textgrid` and `write_textgrid`, as well as the commands
28 |   `torch-token-data-dir-to-textgrids` and `textgrids-to-torch-token-data-dir`.
29 | - Added commands for switching between ref and ali format:
30 |   `torch-ali-data-dir-to-torch-token-data-dir` and
31 |   `torch-token-data-dir-to-torch-ali-data-dir`.
32 | - Added py 3.10 support; removed py 3.6 support.
33 | - Initial (undocumented) support for
34 |   [PyTorch-Lightning](https://www.pytorchlightning.ai/) in the
35 |   `pydrobert.torch.lightning` submodule. Will document when I get some time.
36 | - Refactored much of `pydrobert.torch.data`. Best just to look at the API.
37 |   `ContextWindowEvaluationDataLoader`, `ContextWindowTrainingDataLoader`,
38 |   `SpectEvaluationDataLoader`, `SpectTrainingDataLoader`, `DataSetParams`,
39 |   `SpectDataSetParams`, and `ContextWindowDataSetParams` are now deprecated.
40 |   The data loaders have been simplified to `ContextWindowDataLoader` and
41 |   `SpectDataLoader`. Keyword arguments (like `shuffle`) now control their
42 |   behaviour. The `*DataSetParams` have been renamed `*DataLoaderParams` with
43 |   some of the parameters moved around. Notably, `LangDataParams` now stores
44 |   `sos`, `eos`, and `subset_ids` parameters, from which a number of parameter
45 |   objects inherit. `SpectDataLoaderParams` inherits from
46 |   `LangDataLoaderParams`, which in turn inherits from
47 |   `DynamicLengthDataLoaderParams`. The latter allows the loader's batch
48 |   elements to be bucketed by length using the new `BucketBatchSampler`. It and
49 |   a number of other samplers inherit from `AbstractEpochSampler` to help
50 |   facilitate the simplified loaders and better resemble the PyTorch API.
51 |   Mean-variance normalization of features is possible through the loaders and
52 |   the new `MeanVarianceNormalization` module. `LangDataSet` and
53 |   `LangDataLoader` have been introduced to facilitate language model training.
54 |   Finally, loaders (and samplers) are compatible with `DistributedDataParallel`
55 |   environments.
56 | - Mean-variance statistics for normalization may be estimated from a data
57 |   partition using the command `compute-mvn-stats-for-torch-feat-data-dir`.
58 | - Added `torch-spect-data-dir-to-wds` to convert a data dir to a
59 |   [WebDataset](https://github.com/webdataset/webdataset).
60 | - Changed method of constructing random state in `EpochRandomSampler`.
61 |   Rerunning training on this new version with the same seed will likely result
62 |   in different results from the old version!
63 | - `FeatureDeltas` is now a module, in case you want to compute deltas online
64 |   rather than waste disk space.
65 | - Added `PadMaskedSequence`.
66 | - Added `FillAfterEndOfSequence`.
67 | - Added `binomial_coefficient`, `enumerate_binary_sequences`,
68 |   `enumerate_vocab_sequences`, and
69 |   `enumerate_binary_sequences_with_cardinality`.
70 | - Docstrings updated to hopefully be clearer. They now use "Call Parameters"
71 |   and "Returns" sections for pytorch modules.
72 | - readthedocs updated.
73 | - Fixed up formatting of CLI help documentation.
74 | - Data sets can now initialize some of their parameters with the values in
75 |   their associated param containers. For example, `sos` and `eos` are now
76 |   set in `SpectDataSet` by passing an optional `SpectDataParams` instance. The
77 |   old method (by argument) is now deprecated.
78 | - Renamed `DataSetParams` to `DataLoaderParams` and deprecated the former
79 |   naming to better mesh with their use in data loaders.
80 | - Moved `pydrobert.torch.util.parse_arpa_lm` to `pydrobert.torch.data`
81 | - `SimpleRandomSamplingWithoutReplacement` has been added as a new
82 |   distribution.
83 | - `EnumerateEstimator`, `ImportanceSamplingEstimator`, and
84 |   `IndependentMetropolisHastingsEstimator` have been added as new estimators.
85 | - `pydrobert.torch.estimators` has been rewritten from the ground up, with old
86 |   functionality deprecated. Distribution-related functions have been rewritten
87 |   as `torch.distributions.Distribution` classes implementing a
88 |   `ConditionalStraightThrough` interface and stored in
89 |   `pydrobert.torch.distributions`. The REINFORCE and RELAX estimators now have
90 |   an object-oriented interface subclassing `MonteCarloEstimator` as
91 |   `DirectEstimator` and `RelaxEstimator`, respectively. The REBAR control
92 |   variate is now distribution-specific and found in `pydrobert.torch.modules`.
93 | - Bug fixes to `OptimalCompletion` and `HardOptimalCompletionDistillationLoss`
94 |   involving batch sizes.
95 | - Refactored code to move modules to `pydrobert.torch.modules` and functions
96 |   to `pydrobert.torch.functional`.
97 | - Deprecated `pydrobert.torch.layers` and `pydrobert.torch.util`.
98 | - Added a number of modules to `pydrobert.torch.modules` to wrap the
99 |   functional API. Moved docstrings to modules.
100 | - Fixed a problem with `warp_1d_grid`/`SpecAugment` which made it sensitive
101 |   to the length of other elements in the batch.
102 | - Added compatibility wrappers to avoid warnings across supported pytorch
103 |   versions.
104 | - Refactored code and added tests to support JIT tracing and scripting for
105 |   most functions/modules in pytorch >= 1.8.1. I'll write up documentation
106 |   shortly.
107 | - Added `pydrobert.torch.config` to store constants used in the module.
108 | - Removed `setup.py`.
109 | - Deleted conda recipe in prep for [conda-forge](https://conda-forge.org/).
110 | - Compatibility/determinism fixes for 1.5.1.
111 | - Bumped minimum PyTorch version to 1.5.1. Actually testing this minimum!
112 | - `version.py` -> `_version.py`.
113 | - A number of modifications and additions related to decoding and language
114 |   models, including:
115 |   - `beam_search_advance` and `random_walk_advance` have been simplified, with
116 |     much of the end-of-sequence logic punted to their associated modules.
117 |   - Rejigged `SequentialLanguageModel` and `LookupLanguageModel` to be both
118 |     simpler and compatible with decoder interfaces.
119 |   - `ctc_greedy_search` and `ctc_prefix_search_advance` functions have been
120 |     added.
121 |   - `ExtractableSequentialLanguageModel`, `MixableSequentialLanguageModel`,
122 |     `BeamSearch`, `RandomWalk`, and `CTCPrefixSearch` modules have been added.
123 |   - A `SequentialLanguageModelDistribution` wrapping `RandomWalk` which
124 |     implements PyTorch's `Distribution` interface. Language models now work
125 |     with estimators!
126 |   - A new documentation page on how to deal with all of that.
127 | - Fixed bug in controller that always compared thresholds against the best,
128 |   not the last point that reset the countdown (#55)
129 | - Added `pad_variable` and `RandomShift` (#54)
130 | - Modified `error_rate`, `prefix_error_rates` to actually compute error rates
131 |   when non-default costs are used. Old functionality is now in `edit_distance`
132 |   and `prefix_edit_distances` (#51)
133 | - Fixed bug in how padding is handled in string matching utilities.
134 | - Fixed logic errors in `compute-torch-token-data-dir-error-rates` (#50)
135 | - Modified frame end in `pydrobert.torch.data.transcript_to_token` and added
136 |   some notes on the ambiguity of the conversion.
137 | - Added some more checks and a 'fix' flag to
138 |   `pydrobert.torch.data.validate_spect_data_set`. The entry point
139 |   `get-torch-spect-data-dir-info` now has a `--fix` flag, too.
140 |
141 | ## v0.3.0
142 |
143 | A considerable amount of refactoring occurred for this build, chiefly to get
144 | rid of Python 2.7 support. While the functionality did not change much for this
145 | version, we have switched from a `pkgutil`-style `pydrobert` namespace to
146 | PEP-420-style namespaces. As a result, *this package is not
147 | backwards-compatible with previous `pydrobert` packages!* Make sure that if any
148 | of the following are installed, they exceed the following version thresholds:
149 |
150 | - `pydrobert-param >0.2.0`
151 | - `pydrobert-kaldi >0.5.3`
152 | - `pydrobert-speech >0.1.0`
153 |
154 | Miscellaneous other stuff:
155 |
156 | - Type hints everywhere
157 | - Shifted python source to `src/`
158 | - Black-formatted remaining source
159 | - Removed `future` dependency
160 | - Shifted most of the configuration to `setup.cfg`, leaving only a shell
161 |   in `setup.py` to remain compatible with Conda builds
162 | - Added `pyproject.toml` for [PEP
163 |   517](https://www.python.org/dev/peps/pep-0517/).
164 | - `tox.ini` for TOX testing
165 | - Switched to AppVeyor for CI
166 | - Added changelog :D
167 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |                                  Apache License
2 |                            Version 2.0, January 2004
3 |                         http://www.apache.org/licenses/
4 |
5 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 |    1. Definitions.
8 |
9 |       "License" shall mean the terms and conditions for use, reproduction,
10 |       and distribution as defined by Sections 1 through 9 of this document.
11 |
12 |       "Licensor" shall mean the copyright owner or entity authorized by
13 |       the copyright owner that is granting the License.
14 |
15 |       "Legal Entity" shall mean the union of the acting entity and all
16 |       other entities that control, are controlled by, or are under common
17 |       control with that entity. For the purposes of this definition,
18 |       "control" means (i) the power, direct or indirect, to cause the
19 |       direction or management of such entity, whether by contract or
20 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 |       outstanding shares, or (iii) beneficial ownership of such entity.
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /LICENSE_nltk.txt: -------------------------------------------------------------------------------- 1 | Copyright (C) 2001-2011 NLTK Project 2 | 3 | Licensed under the Apache License, Version 2.0 (the 'License'); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an 'AS IS' BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. -------------------------------------------------------------------------------- /LICENSE_pytorch.txt: -------------------------------------------------------------------------------- 1 | From PyTorch: 2 | 3 | Copyright (c) 2016- Facebook, Inc (Adam Paszke) 4 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala) 5 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 6 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 7 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 8 | Copyright (c) 2011-2013 NYU (Clement Farabet) 9 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 10 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 11 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 12 | 13 | From Caffe2: 14 | 15 | Copyright (c) 2016-present, Facebook Inc. All rights reserved. 16 | 17 | All contributions by Facebook: 18 | Copyright (c) 2016 Facebook Inc. 19 | 20 | All contributions by Google: 21 | Copyright (c) 2015 Google Inc. 22 | All rights reserved. 23 | 24 | All contributions by Yangqing Jia: 25 | Copyright (c) 2015 Yangqing Jia 26 | All rights reserved. 27 | 28 | All contributions by Kakao Brain: 29 | Copyright 2019-2020 Kakao Brain 30 | 31 | All contributions from Caffe: 32 | Copyright(c) 2013, 2014, 2015, the respective contributors 33 | All rights reserved. 34 | 35 | All other contributions: 36 | Copyright(c) 2015, 2016 the respective contributors 37 | All rights reserved. 38 | 39 | Caffe2 uses a copyright model similar to Caffe: each contributor holds 40 | copyright over their contributions to Caffe2. The project versioning records 41 | all such contribution and copyright details. If a contributor wants to further 42 | mark their specific copyright on a particular contribution, they should 43 | indicate their copyright solely in the commit message of the change when it is 44 | committed. 45 | 46 | All rights reserved. 47 | 48 | Redistribution and use in source and binary forms, with or without 49 | modification, are permitted provided that the following conditions are met: 50 | 51 | 1. Redistributions of source code must retain the above copyright 52 | notice, this list of conditions and the following disclaimer. 53 | 54 | 2. Redistributions in binary form must reproduce the above copyright 55 | notice, this list of conditions and the following disclaimer in the 56 | documentation and/or other materials provided with the distribution. 57 | 58 | 3. 
Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
59 |    and IDIAP Research Institute nor the names of its contributors may be
60 |    used to endorse or promote products derived from this software without
61 |    specific prior written permission.
62 |
63 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
64 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
65 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
66 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
67 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
68 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
69 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
70 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
71 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
72 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
73 | POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [![tox](https://github.com/sdrobert/pydrobert-pytorch/actions/workflows/tox.yml/badge.svg)](https://github.com/sdrobert/pydrobert-pytorch/actions/workflows/tox.yml)
2 | [![Documentation Status](https://readthedocs.org/projects/pydrobert-pytorch/badge/?version=latest)](https://pydrobert-pytorch.readthedocs.io/en/latest/?badge=latest)
3 | [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
4 |
5 | # pydrobert-pytorch
6 |
7 | PyTorch utilities for Machine Learning. This is an eclectic mix of utilities
8 | that I've used in my various projects. There is a definite leaning towards
9 | speech, specifically end-to-end ASR. The primary benefit `pydrobert-pytorch`
10 | has over other packages is modularity: you can pick and choose the
11 | functionality you desire without subscribing to an entire ecosystem. You can
12 | find out more about what the package offers in the documentation links below.
13 |
14 | **This is student-driven code, so don't expect a stable API. I'll try to use
15 | semantic versioning, but the best way to keep functionality stable is by
16 | pinning the version in the requirements or by forking.**
17 |
18 | ## Documentation
19 |
20 | - [Latest](https://pydrobert-pytorch.readthedocs.io/en/latest/)
21 | - [v0.3.0](https://pydrobert-pytorch.readthedocs.io/en/v0.3.0/)
22 |
23 | ## Installation
24 |
25 | `pydrobert-pytorch` is available through both Conda and PyPI.
26 |
27 | ``` bash
28 | conda install -c sdrobert pydrobert-pytorch
29 | pip install pydrobert-pytorch
30 | ```
31 |
32 | ## Licensing and How to Cite
33 |
34 | Please see the [pydrobert page](https://github.com/sdrobert/pydrobert) for more
35 | details on how to cite this package.
36 |
37 | Implementations of
38 | `pydrobert.torch._img.{polyharmonic_spline,sparse_image_warp}` are based on
39 | TensorFlow's codebase, which is Apache 2.0 licensed.
40 |
41 | Implementations of
42 | `pydrobert.torch._compat.{broadcast_shapes,TorchVersion,one_hot}` were directly
43 | taken from the PyTorch codebase. A number of methods and functions in
44 | `pydrobert.torch._straight_through` modify PyTorch code (see the file for more
45 | info). PyTorch has a BSD-style license which can be found in the file
46 | `LICENSE_pytorch.txt`.
47 |
48 | The implementation of `pydrobert.torch._compat.check_methods` was taken
49 | directly from the CPython codebase, Copyright 2007 Google with additional
50 | notices at .
51 |
52 | The file `pydrobert.torch._textgrid.py` was taken with some minor modifications
53 | from
54 | [nltk_contrib](https://github.com/nltk/nltk_contrib/blob/95d1806e2f4e89e960b76a685b1fba2eaa7d5142/nltk_contrib/textgrid.py#L1).
55 | It is Apache 2.0-licensed, with the specific license text saved to
56 | `LICENSE_nltk.txt`.
57 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS    ?=
7 | SPHINXBUILD   ?= sphinx-build
8 | SOURCEDIR     = source
9 | BUILDDIR      = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/cli_helper.py:
--------------------------------------------------------------------------------
1 | # Generate CLI page
2 | # Needs imports, so run with the package fully installed
3 |
4 | import pydrobert.torch.command_line as cli
5 | import os
6 | from io import StringIO
7 | import sys
8 | import inspect
9 | import warnings
10 |
11 | warnings.simplefilter("ignore")
12 |
13 | # Modified from
14 | # https://stackoverflow.com/questions/16571150/how-to-capture-stdout-output-from-a-python-function-call
15 | class Capturing(list):
16 |     def __enter__(self):
17 |         self._stdout = sys.stdout
18 |         self._stderr = sys.stderr
19 |         sys.stdout = sys.stderr = self._stringio = StringIO()
20 |         return self
21 |
22 |     def __exit__(self, *args):
23 |         self.extend(self._stringio.getvalue().splitlines())
24 |         del self._stringio  # free up some memory
25 |         sys.stdout = self._stdout
26 |         sys.stderr = self._stderr
27 |
28 |
29 | DIR = os.path.dirname(__file__)
30 | CLI_RST = os.path.join(DIR, "source", "cli.rst")
31 |
32 | buff = "Command-Line Interface\n======================\n\n"
33 | for cmd_name in sorted(
34 |     (
35 |         "chunk-torch-spect-data-dir",
36 |         "compute-mvn-stats-for-torch-feat-data-dir",
37 |         "compute-torch-token-data-dir-error-rates",
38 |         "ctm-to-torch-token-data-dir",
39 |         "get-torch-spect-data-dir-info",
40 |         "print-torch-ali-data-dir-length-moments",
41 |         "print-torch-ref-data-dir-length-moments",
42 |         "subset-torch-spect-data-dir",
43 |         "textgrids-to-torch-token-data-dir",
44 |         "torch-ali-data-dir-to-torch-token-data-dir",
45 |         "torch-spect-data-dir-to-wds",
46 |         "torch-token-data-dir-to-ctm",
47 |         "torch-token-data-dir-to-textgrids",
48 |         "torch-token-data-dir-to-torch-ali-data-dir",
49 |         "torch-token-data-dir-to-trn",
50 |         "trn-to-torch-token-data-dir",
51 |     )
52 | ):
53 |     buff += cmd_name + "\n" + ("-" * len(cmd_name)) + "\n\n::\n\n "
54 |     sys.argv[0] = cmd_name
55 |     func = next(
56 |         x[1] for x in inspect.getmembers(cli) if x[0] == cmd_name.replace("-", "_")
57 |     )
58 | with Capturing() as c: 59 | try: 60 | func(["-h"]) 61 | except SystemExit: 62 | pass 63 | buff += "\n ".join(c) + "\n\n" 64 | 65 | with open(CLI_RST, "w") as f: 66 | f.write(buff) 67 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | docutils>=0.17 2 | myst-parser 3 | param 4 | pydrobert-param>=0.4.0 5 | sphinx-rtd-theme>=0.5 6 | sphinx-autodoc-typehints 7 | Sphinx>=4.4 8 | typing_extensions 9 | numpy 10 | lightning 11 | -f https://download.pytorch.org/whl/cpu/torch_stable.html 12 | torch -------------------------------------------------------------------------------- /docs/source/api/pydrobert/torch.rst: -------------------------------------------------------------------------------- 1 | pydrobert.torch 2 | =============== 3 | 4 | .. toctree:: 5 | :glob: 6 | :maxdepth: 1 7 | 8 | torch/* -------------------------------------------------------------------------------- /docs/source/api/pydrobert/torch/argcheck.rst: -------------------------------------------------------------------------------- 1 | argcheck 2 | ======== 3 | 4 | .. toctree:: 5 | 6 | .. automodule:: pydrobert.torch.argcheck 7 | 8 | .. autofunction:: as_bool 9 | .. autofunction:: as_closed01 10 | .. autofunction:: as_dir 11 | .. autofunction:: as_file 12 | .. autofunction:: as_float 13 | .. autofunction:: as_int 14 | .. autofunction:: as_nat 15 | .. autofunction:: as_negf 16 | .. autofunction:: as_negi 17 | .. autofunction:: as_nonnegf 18 | .. autofunction:: as_nonnegi 19 | .. autofunction:: as_nonposf 20 | .. autofunction:: as_nonposi 21 | .. autofunction:: as_open01 22 | .. autofunction:: as_path_dir 23 | .. autofunction:: as_path_file 24 | .. autofunction:: as_path 25 | .. autofunction:: as_posf 26 | .. autofunction:: as_posi 27 | .. autofunction:: as_str 28 | .. autofunction:: as_tensor 29 | .. autofunction:: has_ndim 30 | .. autofunction:: is_a 31 | .. autofunction:: is_bool 32 | .. autofunction:: is_btw_closed 33 | .. autofunction:: is_btw_closedf 34 | .. autofunction:: is_btw_closedi 35 | .. autofunction:: is_btw_closedt 36 | .. autofunction:: is_btw_open 37 | .. autofunction:: is_btw_openf 38 | .. autofunction:: is_btw_openi 39 | .. autofunction:: is_btw_opent 40 | .. autofunction:: is_btw 41 | .. autofunction:: is_btwf 42 | .. autofunction:: is_btwi 43 | .. autofunction:: is_btwt 44 | .. 
autofunction:: is_closed01 45 | .. autofunction:: is_closed01f 46 | .. autofunction:: is_closed01i 47 | .. autofunction:: is_closed01t 48 | .. autofunction:: is_dir 49 | .. autofunction:: is_equal 50 | .. autofunction:: is_equalf 51 | .. autofunction:: is_equali 52 | .. autofunction:: is_equalt 53 | .. autofunction:: is_exactly 54 | .. autofunction:: is_file 55 | .. autofunction:: is_float 56 | .. autofunction:: is_gt 57 | .. autofunction:: is_gte 58 | .. autofunction:: is_gtef 59 | .. autofunction:: is_gtei 60 | .. autofunction:: is_gtet 61 | .. autofunction:: is_gtf 62 | .. autofunction:: is_gti 63 | .. autofunction:: is_gtt 64 | .. autofunction:: is_in 65 | .. autofunction:: is_int 66 | .. autofunction:: is_lt 67 | .. autofunction:: is_lte 68 | .. autofunction:: is_ltef 69 | .. autofunction:: is_ltei 70 | .. autofunction:: is_ltet 71 | .. autofunction:: is_ltf 72 | .. autofunction:: is_lti 73 | .. autofunction:: is_ltt 74 | .. autofunction:: is_nat 75 | .. autofunction:: is_neg 76 | .. autofunction:: is_negf 77 | .. autofunction:: is_negi 78 | .. autofunction:: is_negt 79 | .. autofunction:: is_nonempty 80 | .. autofunction:: is_nonneg 81 | .. autofunction:: is_nonnegf 82 | .. autofunction:: is_nonnegi 83 | .. autofunction:: is_nonnegt 84 | .. autofunction:: is_nonpos 85 | .. autofunction:: is_nonposf 86 | .. autofunction:: is_nonposi 87 | .. autofunction:: is_nonpost 88 | .. autofunction:: is_numlike 89 | .. autofunction:: is_open01 90 | .. autofunction:: is_open01f 91 | .. autofunction:: is_open01i 92 | .. autofunction:: is_open01t 93 | .. autofunction:: is_path 94 | .. autofunction:: is_pos 95 | .. autofunction:: is_posf 96 | .. autofunction:: is_posi 97 | .. autofunction:: is_post 98 | .. autofunction:: is_str 99 | .. autofunction:: is_tensor 100 | .. autofunction:: is_token 101 | -------------------------------------------------------------------------------- /docs/source/api/pydrobert/torch/config.rst: -------------------------------------------------------------------------------- 1 | config 2 | ====== 3 | 4 | .. toctree:: 5 | 6 | .. automodule:: pydrobert.torch.config 7 | :members: 8 | 9 | -------------------------------------------------------------------------------- /docs/source/api/pydrobert/torch/data.rst: -------------------------------------------------------------------------------- 1 | data 2 | ==== 3 | 4 | .. toctree:: 5 | 6 | .. automodule:: pydrobert.torch.data 7 | :members: 8 | -------------------------------------------------------------------------------- /docs/source/api/pydrobert/torch/distributions.rst: -------------------------------------------------------------------------------- 1 | distributions 2 | ============= 3 | 4 | .. toctree:: 5 | 6 | .. automodule:: pydrobert.torch.distributions 7 | :members: 8 | -------------------------------------------------------------------------------- /docs/source/api/pydrobert/torch/estimators.rst: -------------------------------------------------------------------------------- 1 | estimators 2 | ========== 3 | 4 | .. toctree:: 5 | 6 | .. automodule:: pydrobert.torch.estimators 7 | :members: 8 | -------------------------------------------------------------------------------- /docs/source/api/pydrobert/torch/functional.rst: -------------------------------------------------------------------------------- 1 | functional 2 | ========== 3 | 4 | .. toctree:: 5 | 6 | .. automodule:: pydrobert.torch.functional 7 | 8 | Combinatorics 9 | ------------- 10 | .. autofunction:: binomial_coefficient 11 | .. 
autofunction:: enumerate_vocab_sequences 12 | .. autofunction:: enumerate_binary_sequences 13 | .. autofunction:: enumerate_binary_sequences_with_cardinality 14 | .. autofunction:: simple_random_sampling_without_replacement 15 | 16 | Decoding 17 | -------- 18 | .. autofunction:: beam_search_advance 19 | .. autofunction:: ctc_greedy_search 20 | .. autofunction:: ctc_prefix_search_advance 21 | .. autofunction:: random_walk_advance 22 | .. autofunction:: sequence_log_probs 23 | 24 | Features 25 | -------- 26 | .. autofunction:: chunk_by_slices 27 | .. autofunction:: chunk_token_sequences_by_slices 28 | .. autofunction:: dense_image_warp 29 | .. autofunction:: feat_deltas 30 | .. autofunction:: mean_var_norm 31 | .. autofunction:: pad_masked_sequence 32 | .. autofunction:: pad_variable 33 | .. autofunction:: polyharmonic_spline 34 | .. autofunction:: random_shift 35 | .. autofunction:: slice_spect_data 36 | .. autofunction:: sparse_image_warp 37 | .. autofunction:: spec_augment 38 | .. autofunction:: spec_augment_apply_parameters 39 | .. autofunction:: spec_augment_draw_parameters 40 | .. autofunction:: warp_1d_grid 41 | 42 | Reinforcement Learning 43 | ---------------------- 44 | .. autofunction:: time_distributed_return 45 | 46 | String Matching 47 | --------------- 48 | .. autofunction:: edit_distance 49 | .. autofunction:: error_rate 50 | .. autofunction:: fill_after_eos 51 | .. autofunction:: hard_optimal_completion_distillation_loss 52 | .. autofunction:: minimum_error_rate_loss 53 | .. autofunction:: optimal_completion 54 | .. autofunction:: prefix_edit_distances 55 | .. autofunction:: prefix_error_rates -------------------------------------------------------------------------------- /docs/source/api/pydrobert/torch/lightning.rst: -------------------------------------------------------------------------------- 1 | lightning 2 | ========= 3 | 4 | .. toctree:: 5 | 6 | .. automodule:: pydrobert.torch.lightning 7 | :members: 8 | -------------------------------------------------------------------------------- /docs/source/api/pydrobert/torch/modules.rst: -------------------------------------------------------------------------------- 1 | modules 2 | ======= 3 | 4 | .. toctree:: 5 | 6 | .. automodule:: pydrobert.torch.modules 7 | 8 | Attention 9 | --------- 10 | .. autoclass:: GlobalSoftAttention 11 | .. autoclass:: ConcatSoftAttention 12 | .. autoclass:: DotProductSoftAttention 13 | .. autoclass:: GeneralizedDotProductSoftAttention 14 | .. autoclass:: MultiHeadedAttention 15 | 16 | Decoding 17 | -------- 18 | .. autoclass:: BeamSearch 19 | .. autoclass:: CTCGreedySearch 20 | .. autoclass:: CTCPrefixSearch 21 | .. autoclass:: RandomWalk 22 | .. autoclass:: SequenceLogProbabilities 23 | 24 | Features 25 | -------- 26 | .. autoclass:: ChunkBySlices 27 | .. autoclass:: ChunkTokenSequencesBySlices 28 | .. autoclass:: DenseImageWarp 29 | .. autoclass:: FeatureDeltas 30 | .. autoclass:: MeanVarianceNormalization 31 | .. autoclass:: PadMaskedSequence 32 | .. autoclass:: PadVariable 33 | .. autoclass:: PolyharmonicSpline 34 | .. autoclass:: RandomShift 35 | .. autoclass:: SliceSpectData 36 | .. autoclass:: SparseImageWarp 37 | .. autoclass:: SpecAugment 38 | .. autoclass:: Warp1DGrid 39 | 40 | Language Models 41 | --------------- 42 | .. autoclass:: ExtractableSequentialLanguageModel 43 | .. autoclass:: MixableSequentialLanguageModel 44 | .. autoclass:: SequentialLanguageModel 45 | .. autoclass:: ExtractableShallowFusionLanguageModel 46 | .. autoclass:: LookupLanguageModel 47 | .. 
autoclass:: MixableShallowFusionLanguageModel 48 | .. autoclass:: ShallowFusionLanguageModel 49 | 50 | Reinforcement Learning 51 | ---------------------- 52 | .. autoclass:: GumbelOneHotCategoricalRebarControlVariate 53 | .. autoclass:: LogisticBernoulliRebarControlVariate 54 | .. autoclass:: TimeDistributedReturn 55 | 56 | String Matching 57 | --------------- 58 | .. autoclass:: EditDistance 59 | .. autoclass:: ErrorRate 60 | .. autoclass:: FillAfterEndOfSequence 61 | .. autoclass:: HardOptimalCompletionDistillationLoss 62 | .. autoclass:: MinimumErrorRateLoss 63 | .. autoclass:: OptimalCompletion 64 | .. autoclass:: PrefixEditDistances 65 | .. autoclass:: PrefixErrorRates -------------------------------------------------------------------------------- /docs/source/api/pydrobert/torch/training.rst: -------------------------------------------------------------------------------- 1 | training 2 | ======== 3 | 4 | .. toctree:: 5 | 6 | .. automodule:: pydrobert.torch.training 7 | :members: -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # http://www.sphinx-doc.org/en/master/config 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | import param 16 | 17 | param.parameterized.docstring_signature = False 18 | param.parameterized.docstring_describe_params = False 19 | 20 | sys.path.insert(0, os.path.abspath("../../src")) 21 | 22 | 23 | # -- Project information ----------------------------------------------------- 24 | 25 | project = "pydrobert-pytorch" 26 | copyright = "2023, Sean Robertson" 27 | author = "Sean Robertson" 28 | 29 | language = "en" 30 | 31 | # -- General configuration --------------------------------------------------- 32 | 33 | # Add any Sphinx extension module names here, as strings. They can be 34 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 35 | # ones. 36 | extensions = [ 37 | "myst_parser", 38 | "sphinx.ext.autodoc", 39 | "sphinx.ext.intersphinx", 40 | "sphinx.ext.viewcode", 41 | "sphinx.ext.napoleon", 42 | "sphinx_autodoc_typehints", 43 | "sphinx_rtd_theme", 44 | ] 45 | 46 | # Add any paths that contain templates here, relative to this directory. 47 | templates_path = ["_templates"] 48 | 49 | # List of patterns, relative to source directory, that match files and 50 | # directories to ignore when looking for source files. 51 | # This pattern also affects html_static_path and html_extra_path. 
52 | exclude_patterns = [] 53 | 54 | napoleon_numpy_docstring = True 55 | napoleon_google_docstring = False 56 | napoleon_include_init_with_doc = True 57 | autodoc_typehints = "none" 58 | autodoc_type_aliases = napoleon_type_aliases = { 59 | "np.ndarray": "numpy.ndarray", 60 | "Literal": "typing.Literal", 61 | } 62 | autodoc_inherit_docstrings = False 63 | napoleon_preprocess_types = True 64 | typehints_document_rtype = False 65 | napoleon_use_rtype = False 66 | napoleon_custom_sections = [ 67 | ("Call Parameters", "returns_style"), 68 | ("Variables", "returns_style"), 69 | ] 70 | 71 | intersphinx_mapping = { 72 | "numpy": ("https://docs.scipy.org/doc/numpy/", None), 73 | "param": ("https://param.holoviz.org/", None), 74 | "pydrobert.kaldi": ("https://pydrobert-kaldi.readthedocs.io/en/latest", None), 75 | "pydrobert.param": ("https://pydrobert-param.readthedocs.io/en/latest", None), 76 | "pydrobert.speech": ("https://pydrobert-speech.readthedocs.io/en/latest", None), 77 | "python": ("https://docs.python.org/", None), 78 | "pytorch_lightning": ("https://lightning.ai/docs/pytorch/stable/", None), 79 | "torch": ("https://pytorch.org/docs/stable/", None), 80 | } 81 | 82 | # -- Options for HTML output ------------------------------------------------- 83 | 84 | # on_rtd = os.environ.get("READTHEDOCS") == "True" 85 | # if on_rtd: 86 | # html_theme = "default" 87 | # else: 88 | html_theme = "sphinx_rtd_theme" 89 | 90 | # Add any paths that contain custom static files (such as style sheets) here, 91 | # relative to this directory. They are copied after the builtin static files, 92 | # so a file named "default.css" will overwrite the builtin "default.css". 93 | # html_static_path = ["_static"] 94 | 95 | highlight_language = "none" 96 | 97 | master_doc = "index" 98 | 99 | html_context = { 100 | "display_github": True, # Integrate GitHub 101 | "github_user": "sdrobert", # Username 102 | "github_repo": "pydrobert-pytorch", # Repo name 103 | "github_version": "master", # Version 104 | "conf_py_path": "/docs/source/", # Path in the checkout to the docs root 105 | } 106 | 107 | 108 | # def docstring_handler(app, what, name, obj, options, lines): 109 | # if "Params" in name.split(".")[-1]: 110 | # try: 111 | # pdict = obj.param.objects(instance=False) 112 | # except: 113 | # return 114 | # del pdict["name"] 115 | # new_lines = [] 116 | # for name, p in pdict.items(): 117 | # doc = p.doc 118 | # deft = p.default 119 | # bounds = p.bounds if hasattr(p, "bounds") else None 120 | # new_lines.append( 121 | # "- **{}**: {}. *default={}{}*".format( 122 | # name, doc, deft, ", bounds={}".format(bounds) if bounds else "" 123 | # ) 124 | # ) 125 | # new_lines.append("") 126 | # new_lines.append("") 127 | # if new_lines: 128 | # new_lines.insert(0, "") 129 | # new_lines.insert(0, "") 130 | # new_lines.insert(1, "**Parameters**") 131 | # new_lines.insert(2, "") 132 | # new_lines.insert(2, "") 133 | # lines += new_lines 134 | # options["undoc-members"] = False 135 | 136 | 137 | # def preprocess_signature(app, obj, bound_method): 138 | # import inspect 139 | 140 | # print(obj, inspect.signature(obj)) 141 | 142 | 143 | # def setup(app): 144 | # # app.connect("autodoc-before-process-signature", preprocess_signature) 145 | # app.connect("autodoc-process-docstring", docstring_handler) 146 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. 
include:: ../../README.md
2 |    :parser: myst_parser.sphinx_
3 |
4 | .. toctree::
5 |    :glob:
6 |    :caption: Tutorials
7 |
8 |    tutorials/*
9 |
10 | .. toctree::
11 |    :caption: API
12 |    :maxdepth: 2
13 |
14 |    api/pydrobert/torch
15 |
16 | .. toctree::
17 |    :caption: Other
18 |
19 |    cli
20 |    references
21 |
22 | Indices and tables
23 | ==================
24 |
25 | * :ref:`genindex`
26 | * :ref:`modindex`
27 | * :ref:`search`
28 |
--------------------------------------------------------------------------------
/docs/source/references.rst:
--------------------------------------------------------------------------------
1 | References
2 | ==========
3 |
4 | .. [fan1962] C. T. Fan, M. E. Muller, and I. Rezucha, "Development of sampling
5 |    plans by using sequential (item by item) selection techniques and digital
6 |    computers," Journal of the American Statistical Association, vol. 57, no.
7 |    298, pp. 387-402, Jun. 1962, doi: 10.1080/01621459.1962.10480667.
8 | .. [howard1972] S. Howard, "Discussion on Professor Cox's paper," Journal of
9 |    the Royal Statistical Society, vol. 34, no. 2, pp. 210-211, Jan. 1972, doi:
10 |    10.1111/j.2517-6161.1972.tb00900.x.
11 | .. [williams1992] R. J. Williams, "Simple statistical gradient-following
12 |    algorithms for connectionist reinforcement learning," Machine Learning,
13 |    vol. 8, no. 3, pp. 229-256, May 1992.
14 | .. [chen1994] X.-H. Chen, A. P. Dempster, and J. S. Liu, "Weighted finite
15 |    population sampling to maximize entropy," Biometrika, vol. 81, no. 3, pp.
16 |    457-69, 1994, doi: 10.2307/2337119.
17 | .. [mengerson1996] K. L. Mengersen and R. L. Tweedie, "Rates of convergence of
18 |    the Hastings and Metropolis algorithms," The Annals of Statistics, vol. 24,
19 |    no. 1, pp. 101-121, Feb. 1996, doi: 10.1214/aos/1033066201.
20 | .. [graves2006] A. Graves, S. Fernández, F. Gomez, and J. Schmidhuber,
21 |    "Connectionist Temporal Classification: Labelling unsegmented sequence data
22 |    with recurrent neural networks," New York, NY, USA, 2006, pp. 369-376.
23 |    doi: 10.1145/1143844.1143891.
24 | .. [mikolov2010] T. Mikolov, M. Karafiát, L. Burget, J. Černocký, and S.
25 |    Khudanpur, "Recurrent neural network based language model," presented at
26 |    Interspeech, Makuhari, Japan, 2010.
27 | .. [heafield2011] K. Heafield, "KenLM: Faster and smaller language model
28 |    queries," in Proceedings of the Sixth Workshop on Statistical Machine
29 |    Translation, Edinburgh, Scotland, 2011, pp. 187-197.
30 | .. [bengio2013] Y. Bengio, N. Léonard, and A. C. Courville, "Estimating or
31 |    Propagating Gradients Through Stochastic Neurons for Conditional
32 |    Computation," CoRR, vol. abs/1308.3432, 2013, [Online]. Available:
33 |    http://arxiv.org/abs/1308.3432
34 | .. [cho2014] K. Cho et al., "Learning phrase representations using RNN
35 |    Encoder-Decoder for Statistical Machine Translation," Doha, Qatar, 2014,
36 |    pp. 1724-1734. [Online]. Available:
37 |    https://www.aclweb.org/anthology/D14-1179
38 | .. [bahdanau2015] D. Bahdanau, K. Cho, and Y. Bengio, "Neural machine
39 |    translation by jointly learning to align and translate," in 3rd
40 |    International Conference on Learning Representations, ICLR 2015, San Diego,
41 |    CA, USA, May 7-9, 2015, Conference Track Proceedings, 2015.
42 | .. .. [burda2015] Y. Burda, R. B. Grosse, and R. Salakhutdinov, "Importance
43 |    .. weighted autoencoders," San Juan, Puerto Rico, 2016. [Online].
44 |    .. Available: http://arxiv.org/abs/1509.00519
45 | .. [gulcehre2015] Ç. Gülçehre et al., "On using monolingual corpora in neural
46 |    machine translation," CoRR, vol. abs/1503.03535, 2015, [Online]. Available:
47 |    http://arxiv.org/abs/1503.03535
48 | .. [luong2015] T. Luong, H. Pham, and C. D. Manning, "Effective approaches to
49 |    attention-based neural machine translation," in Proceedings of the 2015
50 |    Conference on Empirical Methods in Natural Language Processing, Lisbon,
51 |    Portugal, 2015, pp. 1412-1421.
52 | .. [chan2016] W. Chan, N. Jaitly, Q. V. Le, and O. Vinyals, "Listen, Attend and
53 |    Spell: A neural network for Large Vocabulary Conversational Speech
54 |    Recognition," in Proc. ICASSP, Mar. 2016, pp. 4960-4964, doi: 10.1109/ICASSP.2016.7472621.
55 | .. [grathwohl2017] W. Grathwohl, D. Choi, Y. Wu, G. Roeder, and D. K. Duvenaud,
56 |    "Backpropagation through the Void: Optimizing control variates for
57 |    black-box gradient estimation," CoRR, vol. abs/1711.00123, 2017.
58 | .. [vaswani2017] A. Vaswani et al., "Attention is all you need," in Advances in
59 |    Neural Information Processing Systems 30, I. Guyon, U. V. Luxburg, S.
60 |    Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett, Eds. Curran
61 |    Associates, Inc., 2017, pp. 5998-6008.
62 | .. [tucker2017] G. Tucker, A. Mnih, C. J. Maddison, J. Lawson, and J.
63 |    Sohl-Dickstein, "REBAR: Low-variance, unbiased gradient estimates for
64 |    discrete latent variable models," in Advances in Neural Information
65 |    Processing Systems 30, I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach,
66 |    R. Fergus, S. Vishwanathan, and R. Garnett, Eds. Curran Associates,
67 |    Inc., 2017, pp. 2627-2636.
68 | .. [prabhavalkar2018] R. Prabhavalkar et al., "Minimum Word Error Rate Training
69 |    for Attention-Based Sequence-to-Sequence Models," presented at the IEEE
70 |    International Conference on Acoustics, Speech and Signal Processing
71 |    (ICASSP), 2018, pp. 4839-4843.
72 | .. [sabour2018] S. Sabour, W. Chan, and M. Norouzi, "Optimal Completion
73 |    Distillation for Sequence Learning," CoRR, vol. abs/1810.01398, 2018.
74 | .. [bert2019] J. Devlin, M.-W. Chang, K. Lee, and K. Toutanova, "BERT:
75 |    Pre-training of deep bidirectional Transformers for language understanding,"
76 |    in NAACL-HLT, Minneapolis, USA, 2019, vol. 1, pp. 4171-4186. [Online].
77 |    Available: https://aclweb.org/anthology/papers/N/N19/N19-1423/
78 | .. [park2019] D. S. Park et al., "SpecAugment: A simple data augmentation
79 |    method for automatic speech recognition," in Proc. Interspeech, 2019, pp.
80 |    2613-2617, doi: 10.21437/Interspeech.2019-2680.
81 | .. [park2020] D. S. Park et al., "SpecAugment on large scale datasets," in
82 |    Proc. ICASSP, May 2020, pp. 6879-6883, doi: 10.1109/ICASSP40776.2020.9053205.
83 | 
84 | 
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced-attn.rst:
--------------------------------------------------------------------------------
 1 | .. _advanced-attn:
 2 | 
 3 | Attention and Transformer Networks
 4 | ==================================
 5 | 
 6 | This document is a supplement for advanced usage of
 7 | :class:`pydrobert.torch.modules.GlobalSoftAttention`, such as for Transformer
 8 | Networks [vaswani2017]_. It picks up where the class' summary left off.
 9 | 
10 | `query` is an (n - 1)-dimensional tensor for ``n > 1``, `key` is an
11 | n-dimensional tensor, and `value` is another n-dimensional tensor. Let
12 | :math:`t` index the `dim`-th dimension of `key`, :math:`q` index the last
13 | dimension of `query`, and :math:`k` index the last dimension of `key`. Let
14 | :math:`query_{t=0}` indicate the "unsqueezed" version of `query` where
15 | :math:`t` is inserted as the `dim`-th dimension. Then :math:`query_{t=0,q}`
16 | must `broadcast
17 | <https://pytorch.org/docs/stable/notes/broadcasting.html>`__
18 | with :math:`key_k`. If specified, `mask` is an (n - 1)-dimensional tensor that
19 | broadcasts with :math:`e`, that is, broadcasts with a tensor of the same shape
20 | as :math:`key_k` after it has been broadcast to :math:`query_{t=0,q}`. Finally,
21 | `value` must broadcast with :math:`a_{k=0}`, that is, :math:`a` with an
22 | unsqueezed final dimension. Care should be taken that any dimensions added
23 | to `query`, `key`, and `value` leave the dimension that is to be attended to
24 | (reduced) broadcasting to the correct location.
25 | 
26 | We'll illustrate with an example. Here, we've designed a barebones version of a
27 | transformer network. There are lots of extra bits in a full transformer network
28 | -- check [vaswani2017]_. Here we focus on the single-headed attention mechanism
29 | (though a multi-headed version would be trivial to implement with
30 | :class:`pydrobert.torch.modules.MultiHeadedAttention`). You can probably skip
31 | the explanation in the middle if all you want to make is a transformer network
32 | -- these settings should work.
33 | 
34 | First the requisite imports:
35 | 
36 | >>> import torch
37 | >>> from pydrobert.torch.modules import *
38 | 
39 | The encoder is going to take in transcripts `inp` of shape ``(T, num_batch)``,
40 | which have been right-padded along dimension 0. It will output both its
41 | encoding in the shape ``(T, num_batch, model_size)`` and a mask of shape ``(T,
42 | 1, num_batch)`` that will be used by the decoder to only consider the region of
43 | the encoding that was unpadded. By not specifying `dim` when initializing
44 | :class:`pydrobert.torch.modules.DotProductSoftAttention`, the attention
45 | dimension is implicitly set to 0, which turns out to be our sequence dimension.
46 | 
47 | >>> class Encoder(torch.nn.Module):
48 | >>>     def __init__(self, model_size, num_classes, padding_idx=-1):
49 | >>>         super(Encoder, self).__init__()
50 | >>>         self.model_size = model_size
51 | >>>         self.num_classes = num_classes
52 | >>>         self.embedder = torch.nn.Embedding(
53 | >>>             num_classes, model_size, padding_idx=padding_idx)
54 | >>>         self.attention = DotProductSoftAttention(
55 | >>>             model_size, scale_factor=model_size ** -.5)
56 | >>>
57 | >>>     def forward(self, inp):
58 | >>>         embedding = self.embedder(inp)
59 | >>>         query = embedding  # (T, num_batch, model_size)
60 | >>>         kv = embedding.unsqueeze(1)  # (T, 1, num_batch, model_size)
61 | >>>         mask = inp.ne(self.embedder.padding_idx)
62 | >>>         enc_mask = mask.unsqueeze(1)
63 | >>>         out = self.attention(query, kv, kv, enc_mask)
64 | >>>         return out, mask.unsqueeze(1)
65 | 
66 | The ``unsqueeze()`` calls are intended to ensure broadcasting occurs properly.
67 | We're going to reduce the 0-th dimension (of size ``T``) of `kv`, but the 0-th
68 | dimension of :math:`query_{t=0,q}` has to be accounted for when creating
69 | :math:`e`. Then, through broadcasting, we expect :math:`e` to be shaped as
70 | 
71 | .. code-block:: none
72 | 
73 |     query_{t=0,q}   1   T   num_batch
74 |     key_k           T   1   num_batch
75 |     ---------------------------------
76 |     e               T   T   num_batch
77 | 
78 | (The attention mechanism gets rid of the last dimension of `query` and `key`,
79 | in this case by taking the inner product). In :math:`e`, the 0-th dimension is
80 | going to refer to each index of the sequence in `key`, whereas the 1-st
81 | dimension refers to each index in the sequence of `query`. Effectively, a
Effectively, a 82 | Cartesian Product has been produced between the sequence dimensions of both 83 | `query` and `key`. 84 | 85 | We've unsqueezed `mask` to have shape ``(T, 1, num_batch)``. `mask` is 86 | responsible for ensuring only non-padded values of `key` are considered. It 87 | broadcasts with :math:`e` as: 88 | 89 | .. code-block:: none 90 | 91 | mask T 1 num_batch 92 | e T T num_batch 93 | --------------------------------- 94 | e & mask T T num_batch 95 | 96 | Which means that the mask is being applied to the 0-th (`key` sequence) 97 | dimension and copied for every 1-st (`query` sequence) dimension. Had we 98 | instead unsqueezed the mask into shape ``(1, T, num_batch)``, the mask would 99 | have been applied to the 1-st dimension and copied to the 0-th instead. This 100 | mask would've introduced ``NaN`` into ``a[:, i]`` for some ``i``. 101 | 102 | Finally, `value` must broadcast with :math:`a_{k=0}`: 103 | 104 | .. code-block:: none 105 | 106 | a_{k=0} T T num_batch 107 | value T 1 num_batch 108 | --------------------------------- 109 | a_{k=0} * value T T num_batch 110 | 111 | The 0-th dimension of `value` corresponds to its sequence dimension, which is 112 | lined up with the `key` sequence dimension, which is the one to be attended to. 113 | Had `value` been shaped as ``(1, T, num_batch)``, its sequence value would line 114 | up with that of `query`, :math:`a_{k=0} * value` would be constant along the 115 | attention dimension, and the weighted combination of terms would just yield the 116 | original `value` tensor. 117 | 118 | Now on to the decoder 119 | 120 | >>> class Decoder(torch.nn.Module): 121 | >>> def __init__(self, model_size, num_classes, padding_idx=-2): 122 | >>> super(Decoder, self).__init__() 123 | >>> self.model_size = model_size 124 | >>> self.num_classes = num_classes 125 | >>> self.embedder = torch.nn.Embedding( 126 | >>> num_classes, model_size, padding_idx=padding_idx) 127 | >>> self.attention = DotProductSoftAttention( 128 | >>> model_size, scale_factor=model_size ** -.5) 129 | >>> self.ff = torch.nn.Linear(model_size, num_classes) 130 | >>> 131 | >>> def forward(self, enc_out, dec_in, enc_mask=None): 132 | >>> embedding = self.embedder(dec_in) 133 | >>> query = embedding # (S, num_batch, model_size) 134 | >>> kv = embedding.unsqueeze(1) # (S, 1, num_batch, model_size) 135 | >>> pad_mask = dec_in.ne(self.embedder.padding_idx) 136 | >>> pad_mask = pad_mask.unsqueeze(1) # (S, 1, num_batch) 137 | >>> auto_mask = torch.ones( 138 | >>> query.shape[0], query.shape[0], dtype=torch.uint8) 139 | >>> auto_mask = torch.triu(auto_mask) 140 | >>> auto_mask = auto_mask.unsqueeze(-1) # (S, S, 1) 141 | >>> dec_mask = pad_mask & auto_mask # (S, S, num_batch) 142 | >>> dec_out = self.attention(query, kv, kv, dec_mask) 143 | >>> query = dec_out # (S, num_batch, model_size) 144 | >>> kv = enc_out.unsqueeze(1) # (T, 1, num_batch, model_size) 145 | >>> out = self.attention(query, kv, kv, enc_mask) 146 | >>> out = self.ff(out) 147 | >>> return out, pad_mask 148 | 149 | You can follow a similar logic as from the encoder to figure out most of the 150 | sizes here. The only not-so-clear part is the self-attention mask for the 151 | decoder. `pad_mask` does the same job as the encoder's mask: it ensures only 152 | non-padded values are considered in the attention vector. `auto_mask` ensures 153 | the auto-regressive property of key-value computations. 
154 | :math:`s` index the sequence dimension of `dec_in`, we want :math:`out_s` not
155 | to depend on any :math:`dec\_in_{>s}`. Recall `query`, `key`, and `value` are
156 | all `dec_in`. Letting :math:`s` be the sequence dimension for `key` (dim=0,
157 | attended to), and :math:`s'` be the sequence dimension for `query` (dim=1,
158 | kept), we find the upper-triangular `auto_mask` satisfies
159 | 
160 | .. math::
161 | 
162 |     auto\_mask_{s,s'} = \begin{cases}
163 |         1 & \mbox{if } s \leq s' \\
164 |         0 & \mbox{if } s > s'
165 |     \end{cases}
166 | 
167 | Since `auto_mask` should be applied indiscriminately to all batches, we
168 | unsqueeze a final dimension so that it broadcasts to the batch dimension of
169 | `pad_mask`.
170 | 
171 | The rest is straightforward. Here is some prep for a random data set:
172 | 
173 | >>> T, num_batch, model_size = 100, 5, 1000
174 | >>> num_classes, start, eos = 20, 0, 1
175 | >>> padding = num_classes - 1
176 | >>> inp_lens = torch.randint(1, T + 1, (num_batch,))
177 | >>> inp = torch.nn.utils.rnn.pad_sequence(
178 | >>>     [
179 | >>>         torch.randint(2, num_classes - 1, (x + 1,))
180 | >>>         for x in inp_lens
181 | >>>     ],
182 | >>>     padding_value=padding,
183 | >>> )
184 | >>> inp[inp_lens, range(num_batch)] = eos
185 | >>> target_lens = torch.randint(1, T + 1, (num_batch,))
186 | >>> y = torch.nn.utils.rnn.pad_sequence(
187 | >>>     [
188 | >>>         torch.randint(2, num_classes - 1, (x + 2,))
189 | >>>         for x in target_lens
190 | >>>     ],
191 | >>>     padding_value=padding,
192 | >>> )
193 | >>> y[0] = start
194 | >>> y[target_lens + 1, range(num_batch)] = eos
195 | >>> dec_inp, targets = y[:-1], y[1:]
196 | >>> encoder = Encoder(model_size, num_classes, padding_idx=padding)
197 | >>> decoder = Decoder(model_size, num_classes, padding_idx=padding)
198 | >>> loss = torch.nn.CrossEntropyLoss(ignore_index=padding)
199 | >>> optim = torch.optim.Adam(
200 | >>>     list(encoder.parameters()) + list(decoder.parameters()))
201 | 
202 | Here's training a batch (you'll have to do this a whole lot of times to get
203 | it to converge)
204 | 
205 | >>> optim.zero_grad()
206 | >>> enc_out, enc_mask = encoder(inp)
207 | >>> logits, _ = decoder(enc_out, dec_inp, enc_mask)
208 | >>> logits = logits[..., :-1]  # get rid of padding logit
209 | >>> l = loss(logits.view(-1, num_classes - 1), targets.flatten())
210 | >>> l.backward()
211 | >>> optim.step()
212 | 
213 | And finally, decoding a batch (test time) using greedy search
214 | 
215 | >>> enc_out, enc_mask = encoder(inp)
216 | >>> dec_hyp = torch.full((1, num_batch), start, dtype=torch.long)
217 | >>> done_mask = torch.zeros(num_batch, dtype=torch.uint8)
218 | >>> while not done_mask.all():
219 | >>>     logits, _ = decoder(enc_out, dec_hyp, enc_mask)
220 | >>>     logits = logits[..., :-1]  # get rid of padding logit
221 | >>>     pred = logits[-1].argmax(1)
222 | >>>     pred.masked_fill_(done_mask, eos)
223 | >>>     done_mask = pred.eq(eos)
224 | >>>     dec_hyp = torch.cat([dec_hyp, pred.unsqueeze(0)], 0)
225 | >>> dec_hyp = dec_hyp[1:]
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
 1 | name: pydrobert-torch
 2 | 
 3 | channels:
 4 |   - sdrobert
 5 |   - pytorch
 6 |   - conda-forge
 7 | 
 8 | dependencies:
 9 |   - pytorch
10 |   - python
11 |   - pytorch-lightning>=1.7
12 |   - pydrobert-param>=0.4.0
13 |   - pytest
14 |   - webdataset
15 |   - pydrobert-speech>=0.2.0
16 |   - ipython
17 |   - sphinx
18 |   - myst-parser
19 |   - sphinx-autodoc-typehints
20 |   - sphinx_rtd_theme
21 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=45", "wheel", "setuptools_scm>=6.2"]
3 | build-backend = "setuptools.build_meta"
4 | 
5 | [tool.setuptools_scm]
6 | write_to = "src/pydrobert/torch/_version.py"
7 | 
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | markers =
3 |     cpu : tests are on the cpu
4 |     gpu : tests are on the gpu
5 |     trace : tests involve tracing code (TorchScript)
6 |     script : tests involve scripting code (TorchScript)
7 |     nojit : tests could involve tracing or scripting, but these versions do not (TorchScript)
8 | 
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
 1 | [metadata]
 2 | name = pydrobert-pytorch
 3 | description = PyTorch utilities for ML, specifically speech
 4 | long_description = file: README.md
 5 | long_description_content_type = text/markdown
 6 | license = Apache-2.0
 7 | license_files =
 8 |     LICENSE
 9 |     LICENSE_pytorch.txt
10 |     LICENSE_nltk.txt
11 | url = https://github.com/sdrobert/pydrobert-pytorch
12 | project_urls =
13 |     Documentation = https://pydrobert-pytorch.readthedocs.io
14 | author = Sean Robertson
15 | author_email = sdrobert@cs.toronto.edu
16 | classifiers =
17 |     Development Status :: 3 - Alpha
18 |     License :: OSI Approved :: Apache Software License
19 |     Programming Language :: Python :: 3
20 | 
21 | [options]
22 | zip_safe = False
23 | packages = find_namespace:
24 | package_dir =
25 |     = src
26 | python_requires = >= 3.6
27 | install_requires =
28 |     numpy
29 |     torch>=1.5.1
30 |     param
31 | 
32 | [options.entry_points]
33 | console_scripts =
34 |     chunk-torch-spect-data-dir = pydrobert.torch.command_line:chunk_torch_spect_data_dir
35 |     compute-mvn-stats-for-torch-feat-data-dir = pydrobert.torch.command_line:compute_mvn_stats_for_torch_feat_data_dir
36 |     compute-torch-token-data-dir-error-rates = pydrobert.torch.command_line:compute_torch_token_data_dir_error_rates
37 |     ctm-to-torch-token-data-dir = pydrobert.torch.command_line:ctm_to_torch_token_data_dir
38 |     get-torch-spect-data-dir-info = pydrobert.torch.command_line:get_torch_spect_data_dir_info
39 |     print-torch-ali-data-dir-length-moments = pydrobert.torch.command_line:print_torch_ali_data_dir_length_moments
40 |     print-torch-ref-data-dir-length-moments = pydrobert.torch.command_line:print_torch_ref_data_dir_length_moments
41 |     subset-torch-spect-data-dir = pydrobert.torch.command_line:subset_torch_spect_data_dir
42 |     textgrids-to-torch-token-data-dir = pydrobert.torch.command_line:textgrids_to_torch_token_data_dir
43 |     torch-ali-data-dir-to-torch-token-data-dir = pydrobert.torch.command_line:torch_ali_data_dir_to_torch_token_data_dir
44 |     torch-spect-data-dir-to-wds = pydrobert.torch.command_line:torch_spect_data_dir_to_wds
45 |     torch-token-data-dir-to-ctm = pydrobert.torch.command_line:torch_token_data_dir_to_ctm
46 |     torch-token-data-dir-to-textgrids = pydrobert.torch.command_line:torch_token_data_dir_to_textgrids
47 |     torch-token-data-dir-to-torch-ali-data-dir = pydrobert.torch.command_line:torch_token_data_dir_to_torch_ali_data_dir
48 |     torch-token-data-dir-to-trn = 
pydrobert.torch.command_line:torch_token_data_dir_to_trn 49 | trn-to-torch-token-data-dir = pydrobert.torch.command_line:trn_to_torch_token_data_dir 50 | 51 | [options.packages.find] 52 | where = src 53 | 54 | [options.extras_require] 55 | lightning = 56 | pytorch_lightning>=1.7 57 | pydrobert-param[yaml]>=0.4.0 58 | torch>=1.10 59 | -------------------------------------------------------------------------------- /src/pydrobert/torch/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Sean Robertson 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | __author__ = "Sean Robertson" 16 | __email__ = "sdrobert@cs.toronto.edu" 17 | __license__ = "Apache 2.0" 18 | __copyright__ = "Copyright 2022 Sean Robertson" 19 | 20 | try: 21 | from ._version import version as __version__ # type: ignore 22 | except ImportError: 23 | __version__ = "inplace" 24 | 25 | __all__ = [ 26 | "config", 27 | "data", 28 | "distributions", 29 | "estimators", 30 | "functional", 31 | "modules", 32 | "training", 33 | ] 34 | -------------------------------------------------------------------------------- /src/pydrobert/torch/_compat.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Sean Robertson 2 | # 3 | # Code for broadcast_shapes was adapted from PyTorch 4 | # https://github.com/pytorch/pytorch/blob/2367face24afb159f73ebf40dc6f23e46132b770/torch/functional.py 5 | # Code for TorchVersion was taken directly from PyTorch 6 | # https://github.com/pytorch/pytorch/blob/b737e09f60dd56dbae520e436648e1f3ebc1f937/torch/torch_version.py 7 | # Code for one_hot was taken directly from PyTorch. 8 | # https://github.com/pytorch/pytorch/blob/89c844db9b3120223bc4e45a1dcbb2368301e956/torch/distributions/constraints.py 9 | # See LICENSE_pytorch in project root directory for PyTorch license. 10 | # 11 | # Code for check_methods was taken directly from CPython 12 | # https://github.com/python/cpython/blob/2085bd0877e17ad4d98a4586d5eabb6faecbb190/Lib/_collections_abc.py 13 | # With the following PSF license 14 | # 15 | # Copyright 2007 Google, Inc. All Rights Reserved. 16 | # Licensed to PSF under a Contributor Agreement. 17 | # 18 | # with the additional notices 19 | # https://docs.python.org/3/copyright.html?highlight=copyright 20 | 21 | # Licensed under the Apache License, Version 2.0 (the "License"); 22 | # you may not use this file except in compliance with the License. 23 | # You may obtain a copy of the License at 24 | 25 | # http://www.apache.org/licenses/LICENSE-2.0 26 | 27 | # Unless required by applicable law or agreed to in writing, software 28 | # distributed under the License is distributed on an "AS IS" BASIS, 29 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 30 | # See the License for the specific language governing permissions and 31 | # limitations under the License. 
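# NOTE: everything below defines version-gated shims (e.g. `script`,
# `pad_sequence`, `linalg_solve`, `meshgrid`, `movedim`, `logaddexp`,
# `jit_isinstance`) so the rest of the package can support a range of PyTorch
# versions through a single code path.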
32 | 33 | from typing import ( 34 | Any, 35 | Iterable, 36 | List, 37 | Optional, 38 | Tuple, 39 | Union, 40 | NamedTuple, 41 | Set, 42 | ) 43 | 44 | import torch 45 | import torch.jit.annotations 46 | 47 | from . import config 48 | 49 | 50 | __all__ = [ 51 | "broadcast_shapes", 52 | "check_methods", 53 | "euler_constant", 54 | "jit_isinstance", 55 | "linalg_solve", 56 | "meshgrid", 57 | "one_hot", 58 | "pad_sequence", 59 | "script", 60 | "SpoofPackedSequence", 61 | "trunc_divide", 62 | ] 63 | 64 | 65 | def check_methods(C, *methods): 66 | try: 67 | mro = C.__mro__ 68 | for method in methods: 69 | for B in mro: 70 | if method in B.__dict__: 71 | if B.__dict__[method] is None: 72 | return NotImplemented 73 | break 74 | else: 75 | return NotImplemented 76 | except AttributeError: 77 | for method in methods: 78 | if getattr(C, method, None) is None: 79 | return NotImplemented 80 | return True 81 | 82 | 83 | # to avoid some scripting issues with torch.utils.nn.PackedSequence 84 | class SpoofPackedSequence(NamedTuple): 85 | data: torch.Tensor 86 | batch_sizes: torch.Tensor 87 | sorted_indices: Optional[torch.Tensor] 88 | unsorted_indices: Optional[torch.Tensor] 89 | 90 | 91 | try: 92 | from torch.torch_version import __version__ as _v # type: ignore 93 | except ModuleNotFoundError: 94 | from torch.version import __version__ as internal_version 95 | from pkg_resources import packaging # type: ignore[attr-defined] 96 | 97 | Version = packaging.version.Version 98 | InvalidVersion = packaging.version.InvalidVersion 99 | 100 | class TorchVersion(str): 101 | """A string with magic powers to compare to both Version and iterables! 102 | Prior to 1.10.0 torch.__version__ was stored as a str and so many did 103 | comparisons against torch.__version__ as if it were a str. In order to not 104 | break them we have TorchVersion which masquerades as a str while also 105 | having the ability to compare against both packaging.version.Version as 106 | well as tuples of values, eg. 
(1, 2, 1) 107 | Examples: 108 | Comparing a TorchVersion object to a Version object 109 | TorchVersion('1.10.0a') > Version('1.10.0a') 110 | Comparing a TorchVersion object to a Tuple object 111 | TorchVersion('1.10.0a') > (1, 2) # 1.2 112 | TorchVersion('1.10.0a') > (1, 2, 1) # 1.2.1 113 | Comparing a TorchVersion object against a string 114 | TorchVersion('1.10.0a') > '1.2' 115 | TorchVersion('1.10.0a') > '1.2.1' 116 | """ 117 | 118 | # fully qualified type names here to appease mypy 119 | def _convert_to_version( 120 | self, inp: Union[packaging.version.Version, str, Iterable] 121 | ) -> packaging.version.Version: 122 | if isinstance(inp, Version): 123 | return inp 124 | elif isinstance(inp, str): 125 | return Version(inp) 126 | elif isinstance(inp, Iterable): 127 | # Ideally this should work for most cases by attempting to group 128 | # the version tuple, assuming the tuple looks (MAJOR, MINOR, ?PATCH) 129 | # Examples: 130 | # * (1) -> Version("1") 131 | # * (1, 20) -> Version("1.20") 132 | # * (1, 20, 1) -> Version("1.20.1") 133 | return Version(".".join((str(item) for item in inp))) 134 | else: 135 | raise InvalidVersion(inp) 136 | 137 | def __gt__(self, cmp): 138 | try: 139 | return Version(self).__gt__(self._convert_to_version(cmp)) 140 | except InvalidVersion: 141 | # Fall back to regular string comparison if dealing with an invalid 142 | # version like 'parrot' 143 | return super().__gt__(cmp) 144 | 145 | def __lt__(self, cmp): 146 | try: 147 | return Version(self).__lt__(self._convert_to_version(cmp)) 148 | except InvalidVersion: 149 | # Fall back to regular string comparison if dealing with an invalid 150 | # version like 'parrot' 151 | return super().__lt__(cmp) 152 | 153 | def __eq__(self, cmp): 154 | try: 155 | return Version(self).__eq__(self._convert_to_version(cmp)) 156 | except InvalidVersion: 157 | # Fall back to regular string comparison if dealing with an invalid 158 | # version like 'parrot' 159 | return super().__eq__(cmp) 160 | 161 | def __ge__(self, cmp): 162 | try: 163 | return Version(self).__ge__(self._convert_to_version(cmp)) 164 | except InvalidVersion: 165 | # Fall back to regular string comparison if dealing with an invalid 166 | # version like 'parrot' 167 | return super().__ge__(cmp) 168 | 169 | def __le__(self, cmp): 170 | try: 171 | return Version(self).__le__(self._convert_to_version(cmp)) 172 | except InvalidVersion: 173 | # Fall back to regular string comparison if dealing with an invalid 174 | # version like 'parrot' 175 | return super().__le__(cmp) 176 | 177 | _v = TorchVersion(internal_version) 178 | 179 | try: 180 | _v < "1.8.0" 181 | except TypeError: 182 | # This occurs in autodoc when torch is being mocked. 
183 | _v = "" 184 | 185 | if config.USE_JIT: 186 | script = torch.jit.script 187 | else: 188 | try: 189 | script = torch.jit.script_if_tracing 190 | except AttributeError: 191 | 192 | def script(func): 193 | return func 194 | 195 | 196 | if _v < "1.10.0": 197 | meshgrid = torch.meshgrid 198 | 199 | trunc_divide = torch.floor_divide 200 | else: 201 | 202 | def trunc_divide(input: torch.Tensor, other: Any) -> torch.Tensor: 203 | if not torch.jit.is_scripting(): 204 | return input.div(other, rounding_mode="trunc") 205 | elif torch.jit.isinstance(other, float): 206 | return input.div(other, rounding_mode="trunc") 207 | elif torch.jit.isinstance(other, int): 208 | return input.div(other, rounding_mode="trunc") 209 | elif torch.jit.isinstance(other, torch.Tensor): 210 | return input.div(other, rounding_mode="trunc") 211 | else: 212 | assert False 213 | 214 | def meshgrid(a: torch.Tensor, b: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 215 | x = torch.meshgrid(a, b, indexing="ij") 216 | return x[0], x[1] 217 | 218 | 219 | if _v < "1.8.0": 220 | from torch.distributions.gumbel import euler_constant 221 | 222 | @script 223 | def pad_sequence( 224 | sequences: List[torch.Tensor], 225 | batch_first: bool = False, 226 | padding_value: float = 0.0, 227 | ) -> torch.Tensor: 228 | shape = sequences[0].size() 229 | shape_rest = shape[1:] 230 | lens = [x.size(0) for x in sequences] 231 | max_len = max(lens) 232 | pad_shapes = [(max_len - x,) + shape_rest for x in lens] 233 | sequences = [ 234 | torch.cat( 235 | [ 236 | seq, 237 | torch.full(ps, padding_value, device=seq.device, dtype=seq.dtype), 238 | ], 239 | 0, 240 | ) 241 | for seq, ps in zip(sequences, pad_shapes) 242 | ] 243 | return torch.stack(sequences, 0 if batch_first else 1) 244 | 245 | def linalg_solve(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: 246 | return torch.solve(B, A)[0] 247 | 248 | @torch.jit.unused 249 | def _jit_isinstance(obj: Any, x: type) -> bool: 250 | if isinstance(obj, torch.nn.utils.rnn.PackedSequence): 251 | obj = obj.data, obj.batch_sizes, obj.sorted_indices, obj.unsorted_indices 252 | origin = getattr(x, "__origin__", None) 253 | if origin is None: 254 | return isinstance(obj, x) 255 | if origin in {tuple, list, set, List, Set, Tuple}: 256 | args = getattr(x, "__args__", None) 257 | if not args: 258 | return ( 259 | (origin in {tuple, Tuple} and obj == tuple()) 260 | or (origin in {list, List} and obj == list()) 261 | or (origin in {set, Set} and obj == set()) 262 | ) 263 | if origin in {tuple, Tuple}: 264 | return (len(obj) is len(args)) and all( 265 | _jit_isinstance(*y) for y in zip(obj, args) 266 | ) 267 | else: 268 | assert len(args) == 1 269 | return all(_jit_isinstance(o, args[0]) for o in obj) 270 | elif origin is Union: 271 | args = x.__args__ 272 | return any(_jit_isinstance(obj, y) for y in args) 273 | return False 274 | 275 | def jit_isinstance(obj: Any, x: Any) -> bool: 276 | if torch.jit.is_scripting(): 277 | raise RuntimeError("Refinement isn't possible with this version of pytorch") 278 | else: 279 | return _jit_isinstance(obj, x) 280 | 281 | from torch.distributions.constraints import Constraint 282 | 283 | class one_hot(Constraint): 284 | is_discrete = True 285 | event_dim = 1 286 | 287 | def check(self, value): 288 | is_boolean = (value == 0) | (value == 1) 289 | is_normalized = value.sum(-1).eq(1) 290 | return is_boolean.all(-1) & is_normalized 291 | 292 | 293 | else: 294 | from torch.distributions.utils import euler_constant 295 | from torch.distributions.constraints import one_hot 296 | 
297 | if config.USE_JIT: 298 | script = torch.jit.script 299 | else: 300 | script = torch.jit.script_if_tracing 301 | 302 | pad_sequence = torch.nn.utils.rnn.pad_sequence 303 | linalg_solve = torch.linalg.solve 304 | jit_isinstance = torch.jit.isinstance 305 | 306 | 307 | if _v < "1.7.0": 308 | 309 | @script 310 | def movedim(a: torch.Tensor, source: int, dest: int) -> torch.Tensor: 311 | D = a.ndim 312 | if source < -D or source >= D: 313 | raise RuntimeError( 314 | f"Dimension 'source' expected to be in the range [{-D},{D - 1}], " 315 | f"got {source}" 316 | ) 317 | source = (source + D) % D 318 | if dest < -D or dest >= D: 319 | raise RuntimeError( 320 | f"Dimension 'dest' expected to be in the range [{-D},{D - 1}], " 321 | f"got {dest}" 322 | ) 323 | dest = (dest + D) % D 324 | if source < dest: 325 | for s in range(source, dest): 326 | a = a.transpose(s, s + 1) 327 | elif source > dest: 328 | for s in range(source, dest, -1): 329 | a = a.transpose(s - 1, s) 330 | return a 331 | 332 | 333 | else: 334 | movedim = torch.movedim 335 | 336 | 337 | if _v < "1.6.0": 338 | 339 | def logaddexp(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: 340 | max_, min_ = torch.max(a, b), torch.min(a, b) 341 | return torch.where( 342 | torch.isfinite(max_), (min_ - max_).exp().log1p() + max_, max_ 343 | ) 344 | 345 | 346 | else: 347 | logaddexp = torch.logaddexp 348 | 349 | 350 | def broadcast_shapes(a: List[int], b: List[int]) -> List[int]: 351 | scalar = torch.zeros((), device="cpu") 352 | tensor_a = scalar.expand(a) 353 | tensor_b = scalar.expand(b) 354 | tensor_a, tensor_b = torch.broadcast_tensors(tensor_a, tensor_b) 355 | return tensor_a.shape 356 | 357 | 358 | @script 359 | def unflatten(x: torch.Tensor, dim: int, shape: List[int]) -> torch.Tensor: 360 | ndim = x.dim() 361 | if dim < -ndim or dim > ndim - 1: 362 | raise RuntimeError(f"Expected dim to be between [{-ndim},{ndim-1}], got {dim}") 363 | dim = (dim + ndim) % ndim 364 | full_shape = list(x.shape) 365 | full_shape = full_shape[:dim] + shape + full_shape[dim + 1 :] 366 | return x.view(full_shape) 367 | -------------------------------------------------------------------------------- /src/pydrobert/torch/_enumerate_estimator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Sean Robertson 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from ._estimators import Estimator, FunctionOnSample 18 | 19 | 20 | class EnumerateEstimator(Estimator): 21 | r"""Calculate expectation exactly by enumerating the support of the distribution 22 | 23 | An unbiased, zero-variance "estimate" of an expectation over a discrete variable 24 | may be calculated brute force by enumerating the support and taking the product of 25 | function values with their probabilities under the distribution. 26 | 27 | .. math:: 28 | 29 | v = \mathbb{E}_{b \sim P}[f(b)] = \sum_b P(b) f(b). 30 | 31 | When called, the instance does just that. 
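    For instance, if :math:`P` were a Bernoulli distribution with parameter
    :math:`\theta` (whose support, :math:`\{0, 1\}`, can be enumerated), the call
    would compute :math:`v = (1 - \theta) f(0) + \theta f(1)` exactly.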
32 | 
33 |     Parameters
34 |     ----------
35 |     proposal
36 |         The distribution over which the expectation is taken, :math:`P`. Must be able to
37 |         enumerate its support through
38 |         :func:`torch.distributions.Distribution.enumerate_support`
39 |         (``proposal.has_enumerate_support == True``).
40 |     func
41 |     is_log
42 | 
43 |     Returns
44 |     -------
45 |     v : torch.Tensor
46 | 
47 |     Warnings
48 |     --------
49 |     The call may be both compute- and memory-intensive, depending on the size of the
50 |     support.
51 |     """
52 | 
53 |     return_log: bool
54 | 
55 |     def __init__(
56 |         self,
57 |         proposal: torch.distributions.distribution.Distribution,
58 |         func: FunctionOnSample,
59 |         is_log: bool = False,
60 |     ) -> None:
61 |         if not proposal.has_enumerate_support:
62 |             raise ValueError(
63 |                 "proposal must be able to enumerate its support "
64 |                 "(proposal.has_enumerate_support == True)"
65 |             )
66 |         super().__init__(proposal, func, is_log)
67 | 
68 |     def __call__(self) -> torch.Tensor:
69 |         b = self.proposal.enumerate_support()
70 |         log_pb = self.proposal.log_prob(b)
71 |         fb = self.func(b)
72 |         if self.is_log:
73 |             v = fb + log_pb
74 |             v = v.logsumexp(0)
75 |         else:
76 |             v = (fb * log_pb.exp()).sum(0)
77 |         return v
78 | 
--------------------------------------------------------------------------------
/src/pydrobert/torch/_estimators.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2022 Sean Robertson
 2 | 
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | 
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | 
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import abc
16 | 
17 | from typing import Callable
18 | from typing_extensions import TypeAlias
19 | 
20 | import torch
21 | 
22 | from . import argcheck
23 | 
24 | FunctionOnSample: TypeAlias = Callable[[torch.Tensor], torch.Tensor]
25 | 
26 | 
27 | class Estimator(metaclass=abc.ABCMeta):
28 |     r"""Computes an estimate of an expectation
29 | 
30 |     An estimator estimates the value of a function :math:`f` integrated over a
31 |     probability density :math:`P`
32 | 
33 |     .. math::
34 | 
35 |         v = \mathbb{E}_{b \sim P}\left[f(b)\right]
36 |             = \int_{b \in \mathrm{supp}(P)} f(b) \mathrm{d}P(b).
37 | 
38 |     The value of :math:`v` can be estimated in many ways. This base class serves as the
39 |     common foundation for those estimators. The usage pattern is as follows:
40 | 
41 |     .. code-block:: python
42 | 
43 |         def func(b):
44 |             # return the value of f(b) here
45 | 
46 |         # ...
47 |         # training loop
48 |         for epoch in range(num_epochs):
49 |             # ...
50 |             # 1. Determine parameterization (e.g. logits) from inputs.
51 |             # 2. Initialize the distribution and estimator in the training loop.
52 |             dist = torch.distributions.SomeDistribution(logits=logits)
53 |             estimator = pydrobert.torch.estimators.SomeEstimator(dist, func, ...)
54 |             v = estimator()  # of shape dist.batch_shape
55 |             # 3. calculate loss as a function of v
56 |             loss.backward()
57 |             # ...
58 | 
59 |     Parameters
60 |     ----------
61 |     proposal
62 |         The distribution over which the expectation is taken. This is usually but not
63 |         always :math:`P` (see :class:`ImportanceSamplingEstimator` for a
64 |         counterexample).
65 |     func
66 |         The function :math:`f`. A callable (such as a :class:`pydrobert.torch.Module`)
67 |         which accepts a sample tensor as input of shape ``(num_samples,) +
68 |         proposal.batch_shape + proposal.event_shape`` and returns a tensor of shape
69 |         ``(num_samples,) + proposal.batch_shape``.
70 |     is_log
71 |         If :obj:`True`, the estimator operates in log space. `func` defines :math:`\log
72 |         f` instead of :math:`f` and the return value `v` represents an estimate of
73 |         :math:`\log v`. Estimators will often be more numerically stable in log space.
74 | 
75 |     Returns
76 |     -------
77 |     v : torch.Tensor
78 |         An estimate of :math:`v`. Of shape ``proposal.batch_shape``.
79 | 
80 |     Notes
81 |     -----
82 |     An estimator is not a :class:`torch.nn.Module` and is not in general safe to be
83 |     JIT scripted or traced. The parameterization of the proposal distribution is usually
84 |     the output of a network and changes between steps, hence the re-initialization above.
85 |     """
86 | 
87 |     proposal: torch.distributions.distribution.Distribution
88 |     func: FunctionOnSample
89 |     is_log: bool
90 | 
91 |     def __init__(
92 |         self,
93 |         proposal: torch.distributions.distribution.Distribution,
94 |         func: FunctionOnSample,
95 |         is_log: bool = False,
96 |     ):
97 |         proposal = argcheck.is_a(
98 |             proposal, torch.distributions.distribution.Distribution, "proposal"
99 |         )
100 |         is_log = argcheck.is_bool(is_log, "is_log")
101 |         super().__init__()
102 |         self.proposal, self.func, self.is_log = proposal, func, is_log
103 | 
104 |     @abc.abstractmethod
105 |     def __call__(self) -> torch.Tensor:
106 |         raise NotImplementedError
107 | 
--------------------------------------------------------------------------------
/src/pydrobert/torch/_rl.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2022 Sean Robertson
 2 | 
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | 
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | 
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import torch
16 | 
17 | from . import argcheck
18 | from ._compat import script
19 | from ._wrappers import functional_wrapper, proxy
20 | 
21 | 
22 | @script
23 | @functional_wrapper("TimeDistributedReturn")
24 | def time_distributed_return(
25 |     r: torch.Tensor, gamma: float, batch_first: bool = False
26 | ) -> torch.Tensor:
27 |     if r.dim() != 2:
28 |         raise RuntimeError("r must be 2 dimensional")
29 |     if not gamma:
30 |         return r
31 |     if batch_first:
32 |         exp = torch.arange(r.size(1), device=r.device, dtype=r.dtype)
33 |         discount = torch.pow(gamma, exp)
34 |         discount = (discount.unsqueeze(1) / discount.unsqueeze(0)).tril()
35 |         R = torch.matmul(r, discount)
36 |     else:
37 |         exp = torch.arange(r.size(0), device=r.device, dtype=r.dtype)
38 |         discount = torch.pow(gamma, exp)
39 |         discount = (discount.unsqueeze(0) / discount.unsqueeze(1)).triu()
40 |         R = torch.matmul(discount, r)
41 |     return R
42 | 
43 | 
44 | class TimeDistributedReturn(torch.nn.Module):
45 |     r"""Accumulate future local rewards at every time step
46 | 
47 |     In `reinforcement learning
48 |     <https://en.wikipedia.org/wiki/Reinforcement_learning>`__, the return is defined as
49 |     the sum of discounted future rewards. This module calculates the return for a
50 |     given time step :math:`t` as
51 | 
52 |     .. math::
53 | 
54 |         R_t = \sum_{t' \geq t} \gamma^{t' - t} r_{t'}
55 | 
56 |     where :math:`r_{t'}` gives the (local) reward at time :math:`t'` and :math:`\gamma`
57 |     is the discount factor. :math:`\gamma \in [0, 1)` implies convergence, but this is
58 |     not enforced here.
59 | 
60 |     Parameters
61 |     ----------
62 |     gamma
63 |         The discount factor :math:`\gamma`.
64 |     batch_first
65 |         Transposes the dimensions of `r` and `R` if :obj:`True`.
66 | 
67 |     Call Parameters
68 |     ---------------
69 |     r : torch.Tensor
70 |         The local rewards :math:`r`, a tensor of shape ``(T, N)``, where ``T`` is the
71 |         sequence size and ``N`` is the batch size.
72 | 
73 |     Returns
74 |     -------
75 |     R : torch.Tensor
76 |         A tensor of shape ``(T, N)`` of the time-distributed rewards.
77 |     """
78 | 
79 |     __constants__ = "gamma", "batch_first"
80 | 
81 |     gamma: float
82 |     batch_first: bool
83 | 
84 |     def __init__(self, gamma: float, batch_first: bool):
85 |         gamma = argcheck.is_float(gamma, "gamma")
86 |         batch_first = argcheck.is_bool(batch_first, "batch_first")
87 |         super().__init__()
88 |         self.gamma, self.batch_first = gamma, batch_first
89 | 
90 |     def extra_repr(self) -> str:
91 |         return f"gamma={self.gamma},batch_first={self.batch_first}"
92 | 
93 |     def forward(self, r: torch.Tensor) -> torch.Tensor:
94 |         return time_distributed_return(r, self.gamma, self.batch_first)
95 | 
96 |     __call__ = proxy(forward)
97 | 
--------------------------------------------------------------------------------
/src/pydrobert/torch/_wrappers.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2022 Sean Robertson
 2 | #
 3 | # proxy(...) is from
 4 | # https://medium.com/@ppeetteerrs/adding-type-hints-to-pytorch-call-function-30728a972392
 5 | 
 6 | # Licensed under the Apache License, Version 2.0 (the "License");
 7 | # you may not use this file except in compliance with the License.
 8 | # You may obtain a copy of the License at
 9 | 
10 | #     http://www.apache.org/licenses/LICENSE-2.0
11 | 
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
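# Example usage of `functional_wrapper` (as in pydrobert.torch._rl above):
#
#     @functional_wrapper("TimeDistributedReturn")
#     def time_distributed_return(r, gamma, batch_first=False):
#         ...
#
# The decorator fills the function's docstring with the template below,
# pointing readers back to the module class of the same name.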
17 | 
18 | from typing import Callable, TypeVar, cast
19 | from inspect import signature
20 | 
21 | _FUNCTIONAL_DOC_TEMPLATE = """Functional version of {module_name}
22 | 
23 | This function accepts both the arguments initializing a :class:`{module_name}` instance
24 | and the inputs to its call and outputs the return value of the call.
25 | 
26 | Parameters
27 | ----------
28 | {parameters}
29 | {returns}
30 | See Also
31 | --------
32 | pydrobert.torch.modules.{module_name}
33 |     For a description of what this does, its inputs, and its outputs.
34 | """
35 | 
36 | C = TypeVar("C", bound=Callable)
37 | 
38 | 
39 | def functional_wrapper(module_name: str) -> Callable[[C], C]:
40 |     def decorator(func):
41 |         sig = signature(func)
42 |         parameters = "\n".join(sig.parameters)
43 |         returns = ""
44 |         if sig.return_annotation is not sig.empty:
45 |             returns = f"""
46 | 
47 | Returns
48 | -------
49 | {str(sig.return_annotation)}
50 | 
51 | """
52 |         func.__doc__ = _FUNCTIONAL_DOC_TEMPLATE.format(
53 |             parameters=parameters, module_name=decorator.__modname, returns=returns
54 |         )
55 |         return func
56 | 
57 |     decorator.__modname = module_name
58 | 
59 |     return decorator
60 | 
61 | 
62 | def proxy(f: C) -> C:
63 |     return cast(C, lambda self, *x, **y: super(self.__class__, self).__call__(*x, **y))
--------------------------------------------------------------------------------
/src/pydrobert/torch/config.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2022 Sean Robertson
 2 | 
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | 
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | 
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Package constants used throughout pydrobert.torch
16 | 
17 | If this submodule is imported first in :mod:`pydrobert.torch` and any values are
18 | changed, the changes will propagate to any submodules.
19 | 
20 | This list is non-exhaustive; types or functions may have their own defaults.
21 | """
22 | 
23 | import os
24 | import math
25 | 
26 | __all__ = [
27 |     "DEFT_ALI_SUBDIR",
28 |     "DEFT_CHUNK_SIZE",
29 |     "DEFT_CTM_CHANNEL",
30 |     "DEFT_DEL_COST",
31 |     "DEFT_FEAT_SUBDIR",
32 |     "DEFT_FILE_PREFIX",
33 |     "DEFT_FILE_SUFFIX",
34 |     "DEFT_FLOAT_PRINT_PRECISION",
35 |     "DEFT_FRAME_SHIFT_MS",
36 |     "DEFT_HYP_SUBDIR",
37 |     "DEFT_INS_COST",
38 |     "DEFT_NUM_WORKERS",
39 |     "DEFT_PAD_VALUE",
40 |     "DEFT_PDFS_SUBDIR",
41 |     "DEFT_REF_SUBDIR",
42 |     "DEFT_SUB_COST",
43 |     "DEFT_TEXTGRID_SUFFIX",
44 |     "DEFT_TEXTGRID_TIER_ID",
45 |     "DEFT_TEXTGRID_TIER_NAME",
46 |     "EPS_0",
47 |     "EPS_INF",
48 |     "EPS_NINF",
49 |     "INDEX_PAD_VALUE",
50 |     "TINY",
51 |     "USE_JIT",
52 | ]
53 | 
54 | 
55 | INDEX_PAD_VALUE = -100
56 | """The value to pad index-based tensors with
57 | 
58 | Batched operations often involve variable-width input. This value is used to
59 | right-pad index-based tensors, indicating that the padded elements should be
60 | ignored.
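For example (an illustrative use, not specific to this package):
``torch.nn.CrossEntropyLoss(ignore_index=INDEX_PAD_VALUE)`` will skip targets
padded with this value when computing the loss.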
61 | 62 | The default value (:obj:`-100`) was chosen to coincide with the PyTorch 1.0 default 63 | for ``ignore_index`` in the likelihood losses 64 | """ 65 | 66 | TINY = 1.1754943508222875e-38 67 | """Smallest representable floating-point integer""" 68 | 69 | 70 | USE_JIT = os.environ.get("PYTORCH_JIT", None) == "1" 71 | """Whether to eagerly compile functions with JIT 72 | 73 | If :obj:`True`, :mod:`pydrobert.torch` compile all functions it can with JIT on import. 74 | Otherwise, if using PyTorch >= 1.8.0, relevant items will be decorated with 75 | :func:`torch.jit.script_if_tracing`. The default is :obj:`True` if and only if the 76 | environment variable ``PYTORCH_JIT=1``. 77 | """ 78 | 79 | EPS_NINF = math.log(1.1754943508222875e-38) / 2 80 | """A small enough value in log space that exponentiating it is very close to zero 81 | 82 | This number is sometimes used in place of -infinity in log-space values to avoid 83 | masking. Increasing it will decrease the accuracy of computations, but may avoid NaNs. 84 | """ 85 | 86 | EPS_0 = math.log1p(-2 * 1.1920928955078125e-07) 87 | """A large enough value in log space that exponentiating it is very close to 1 88 | 89 | This number is sometimes used in place of 0 in log-space values to avoid masking. 90 | Decreasing it will decrease the accuracy of computations, but may avoid NaNs. 91 | """ 92 | 93 | EPS_INF = math.log(3.4028234663852886e38) / 2 94 | """A large enough value in log space that exponentiating it is near infinity 95 | 96 | This number is sometimes used in place of infinity in log-space values to avoid masking. 97 | Decreasing it will decrease the accuracy of computations, but may avoid NaNs. 98 | """ 99 | 100 | DEFT_FRAME_SHIFT_MS = 10.0 101 | """The default frame shift in milliseconds for commands""" 102 | 103 | DEFT_TEXTGRID_SUFFIX = ".TextGrid" 104 | """The default suffix indicating TextGrid files for commands""" 105 | 106 | DEFT_CHUNK_SIZE = 1000 107 | """Default number of units to process at once when performing multiprocessing""" 108 | 109 | 110 | def _cpu_count() -> int: 111 | if hasattr(os, "sched_getaffinity"): 112 | return len(os.sched_getaffinity(0)) 113 | cpu_count = os.cpu_count() 114 | return 0 if cpu_count is None else cpu_count 115 | 116 | 117 | DEFT_NUM_WORKERS = _cpu_count() 118 | """Default number of workers when performing multiprocessing""" 119 | 120 | DEFT_FILE_PREFIX = "" 121 | """Default prefix of a torch data file""" 122 | 123 | DEFT_FILE_SUFFIX = ".pt" 124 | """Default suffix of a torch data file""" 125 | 126 | DEFT_FLOAT_PRINT_PRECISION = 3 127 | """Default precision to write floating point values to file with""" 128 | 129 | DEFT_CTM_CHANNEL = "A" 130 | """Default channel to write to CTM files""" 131 | 132 | DEFT_TEXTGRID_TIER_ID = 0 133 | """Default TextGrid tier to read transcripts from""" 134 | 135 | DEFT_TEXTGRID_TIER_NAME = "transcript" 136 | """Default TextGrid tiear to write transcripts to""" 137 | 138 | DEFT_FEAT_SUBDIR = "feat" 139 | """Default subdirectory of a torch data directory containing features""" 140 | 141 | DEFT_ALI_SUBDIR = "ali" 142 | """Default subdirectory of a torch data directory containing alignments""" 143 | 144 | DEFT_REF_SUBDIR = "ref" 145 | """Default subdirectory of a torch data directory containing reference tokens""" 146 | 147 | DEFT_PDFS_SUBDIR = "pdfs" 148 | """Default subdirectory of a torch data directory to write pdfs to""" 149 | 150 | DEFT_HYP_SUBDIR = "hyp" 151 | """Default subdirectory of a torch data directory to write hypothesis tokens to""" 152 | 153 | 
DEFT_PAD_VALUE = 0.0 154 | """Default value to pad floating-point tensors with""" 155 | 156 | DEFT_INS_COST = 1.0 157 | """Default insertion cost in error rate/distance computations""" 158 | 159 | DEFT_DEL_COST = 1.0 160 | """Default deletion cost in error rate/distance computations""" 161 | 162 | DEFT_SUB_COST = 1.0 163 | """Default substitution cost in error rate/distance computations""" 164 | -------------------------------------------------------------------------------- /src/pydrobert/torch/data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Sean Robertson 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Classes and functions related to storing/retrieving speech data""" 16 | 17 | import functools 18 | import warnings 19 | 20 | from ._datasets import ( 21 | ContextWindowDataParams, 22 | ContextWindowDataSet, 23 | extract_window, 24 | LangDataParams, 25 | LangDataSet, 26 | SpectDataParams, 27 | SpectDataSet, 28 | validate_spect_data_set, 29 | ) 30 | from ._dataloaders import ( 31 | AbstractEpochSampler, 32 | BucketBatchSampler, 33 | context_window_seq_to_batch, 34 | ContextWindowDataLoader, 35 | ContextWindowDataLoaderParams, 36 | ContextWindowEvaluationDataLoader, # deprecated 37 | ContextWindowTrainingDataLoader, # deprecated 38 | DataLoaderParams, 39 | DynamicLengthDataLoaderParams, 40 | EpochRandomSampler, 41 | EpochSequentialSampler, 42 | lang_seq_to_batch, 43 | LangDataLoader, 44 | LangDataLoaderParams, 45 | spect_seq_to_batch, 46 | SpectDataLoader, 47 | SpectDataLoaderParams, 48 | SpectEvaluationDataLoader, # deprecated 49 | SpectTrainingDataLoader, # deprecated 50 | ) 51 | from ._parsing import ( 52 | parse_arpa_lm, 53 | read_ctm, 54 | read_textgrid, 55 | read_trn_iter, 56 | read_trn, 57 | token_to_transcript, 58 | transcript_to_token, 59 | write_ctm, 60 | write_textgrid, 61 | write_trn, 62 | ) 63 | 64 | __all__ = [ 65 | "AbstractEpochSampler", 66 | "BucketBatchSampler", 67 | "context_window_seq_to_batch", 68 | "ContextWindowDataLoader", 69 | "ContextWindowDataLoaderParams", 70 | "ContextWindowDataParams", 71 | "ContextWindowDataSet", 72 | "DataLoaderParams", 73 | "DynamicLengthDataLoaderParams", 74 | "EpochRandomSampler", 75 | "EpochSequentialSampler", 76 | "extract_window", 77 | "lang_seq_to_batch", 78 | "LangDataLoader", 79 | "LangDataLoaderParams", 80 | "LangDataParams", 81 | "LangDataSet", 82 | "parse_arpa_lm", 83 | "read_ctm", 84 | "read_textgrid", 85 | "read_trn_iter", 86 | "read_trn", 87 | "spect_seq_to_batch", 88 | "SpectDataLoader", 89 | "SpectDataLoaderParams", 90 | "SpectDataParams", 91 | "SpectDataSet", 92 | "token_to_transcript", 93 | "transcript_to_token", 94 | "validate_spect_data_set", 95 | "write_ctm", 96 | "write_textgrid", 97 | "write_trn", 98 | ] 99 | 100 | 101 | def import_and_deprecate(cls): 102 | from . 
import _dataloaders
103 | 
104 |     old_name = cls.__name__
105 |     new_name = old_name.replace("DataSet", "DataLoader")
106 |     cls = getattr(_dataloaders, new_name)
107 | 
108 |     @functools.wraps(cls)
109 |     def wraps(*args, **kwargs):
110 |         warnings.warn(
111 |             f"The name '{wraps.__old}' is deprecated. Please switch to '{wraps.__new}'",
112 |             DeprecationWarning,
113 |         )
114 |         return wraps.__cls(*args, **kwargs)
115 | 
116 |     wraps.__old = old_name
117 |     wraps.__new = new_name
118 |     wraps.__cls = cls
119 | 
120 |     return wraps
121 | 
122 | 
123 | @import_and_deprecate
124 | class DataSetParams:
125 |     pass
126 | 
127 | 
128 | @import_and_deprecate
129 | class SpectDataSetParams:
130 |     pass
131 | 
132 | 
133 | @import_and_deprecate
134 | class ContextWindowDataSetParams:
135 |     pass
136 | 
--------------------------------------------------------------------------------
/src/pydrobert/torch/distributions.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2022 Sean Robertson
 2 | 
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | 
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | 
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """PyTorch distributions and interfaces
16 | 
17 | Warnings
18 | --------
19 | Distributions cannot be JIT scripted or traced.
20 | """
21 | 
22 | from ._combinatorics import (
23 |     BinaryCardinalityConstraint,
24 |     SimpleRandomSamplingWithoutReplacement,
25 | )
26 | from ._decoding import SequentialLanguageModelDistribution, TokenSequenceConstraint
27 | from ._straight_through import (
28 |     ConditionalStraightThrough,
29 |     Density,
30 |     GumbelOneHotCategorical,
31 |     LogisticBernoulli,
32 |     StraightThrough,
33 | )
34 | 
35 | __all__ = [
36 |     "BinaryCardinalityConstraint",
37 |     "ConditionalStraightThrough",
38 |     "Density",
39 |     "GumbelOneHotCategorical",
40 |     "LogisticBernoulli",
41 |     "SequentialLanguageModelDistribution",
42 |     "SimpleRandomSamplingWithoutReplacement",
43 |     "StraightThrough",
44 |     "TokenSequenceConstraint",
45 | ]
46 | 
--------------------------------------------------------------------------------
/src/pydrobert/torch/functional.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2023 Sean Robertson
 2 | 
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | 
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | 
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 15 | """Pytorch functions""" 16 | 17 | from ._combinatorics import ( 18 | binomial_coefficient, 19 | enumerate_vocab_sequences, 20 | enumerate_binary_sequences, 21 | enumerate_binary_sequences_with_cardinality, 22 | simple_random_sampling_without_replacement, 23 | ) 24 | from ._decoding import ( 25 | beam_search_advance, 26 | ctc_greedy_search, 27 | ctc_prefix_search_advance, 28 | random_walk_advance, 29 | sequence_log_probs, 30 | ) 31 | from ._feats import ( 32 | chunk_token_sequences_by_slices, 33 | feat_deltas, 34 | mean_var_norm, 35 | slice_spect_data, 36 | ) 37 | from ._img import ( 38 | dense_image_warp, 39 | polyharmonic_spline, 40 | random_shift, 41 | sparse_image_warp, 42 | spec_augment_apply_parameters, 43 | spec_augment_draw_parameters, 44 | spec_augment, 45 | warp_1d_grid, 46 | ) 47 | from ._pad import chunk_by_slices, pad_masked_sequence, pad_variable 48 | from ._rl import time_distributed_return 49 | from ._string import ( 50 | edit_distance, 51 | error_rate, 52 | fill_after_eos, 53 | hard_optimal_completion_distillation_loss, 54 | minimum_error_rate_loss, 55 | optimal_completion, 56 | prefix_edit_distances, 57 | prefix_error_rates, 58 | ) 59 | 60 | __all__ = [ 61 | "beam_search_advance", 62 | "binomial_coefficient", 63 | "chunk_by_slices", 64 | "chunk_token_sequences_by_slices", 65 | "ctc_greedy_search", 66 | "ctc_prefix_search_advance", 67 | "dense_image_warp", 68 | "edit_distance", 69 | "enumerate_binary_sequences_with_cardinality", 70 | "enumerate_binary_sequences", 71 | "enumerate_vocab_sequences", 72 | "error_rate", 73 | "feat_deltas", 74 | "fill_after_eos", 75 | "hard_optimal_completion_distillation_loss", 76 | "mean_var_norm", 77 | "minimum_error_rate_loss", 78 | "optimal_completion", 79 | "pad_masked_sequence", 80 | "pad_variable", 81 | "polyharmonic_spline", 82 | "prefix_edit_distances", 83 | "prefix_error_rates", 84 | "random_shift", 85 | "random_walk_advance", 86 | "sequence_log_probs", 87 | "simple_random_sampling_without_replacement", 88 | "slice_spect_data", 89 | "sparse_image_warp", 90 | "spec_augment_apply_parameters", 91 | "spec_augment_draw_parameters", 92 | "spec_augment", 93 | "time_distributed_return", 94 | "warp_1d_grid", 95 | ] 96 | -------------------------------------------------------------------------------- /src/pydrobert/torch/layers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Sean Robertson 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import warnings 16 | 17 | warnings.warn( 18 | "pydrobert.torch.layers is deprecated. 
Use pydrobert.torch.functional for " 19 | "functions and pydrobert.torch.modules for modules", 20 | DeprecationWarning, 21 | 2, 22 | ) 23 | 24 | from ._attn import ( 25 | ConcatSoftAttention, 26 | DotProductSoftAttention, 27 | GeneralizedDotProductSoftAttention, 28 | GlobalSoftAttention, 29 | MultiHeadedAttention, 30 | ) 31 | from ._decoding import BeamSearch, CTCPrefixSearch, SequenceLogProbabilities 32 | from ._img import ( 33 | DenseImageWarp, 34 | PadVariable, 35 | PolyharmonicSpline, 36 | random_shift, 37 | RandomShift, 38 | SparseImageWarp, 39 | spec_augment_apply_parameters, 40 | spec_augment_draw_parameters, 41 | spec_augment, 42 | SpecAugment, 43 | Warp1DGrid, 44 | ) 45 | from ._lm import ( 46 | ExtractableSequentialLanguageModel, 47 | LookupLanguageModel, 48 | MixableSequentialLanguageModel, 49 | SequentialLanguageModel, 50 | ) 51 | from ._rl import TimeDistributedReturn 52 | from ._string import ( 53 | EditDistance, 54 | ErrorRate, 55 | hard_optimal_completion_distillation_loss, 56 | HardOptimalCompletionDistillationLoss, 57 | minimum_error_rate_loss, 58 | MinimumErrorRateLoss, 59 | OptimalCompletion, 60 | PrefixEditDistances, 61 | PrefixErrorRates, 62 | ) 63 | -------------------------------------------------------------------------------- /src/pydrobert/torch/lightning.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Sean Robertson 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Functions and classes which interface with :mod:`pytorch_lightning` 16 | 17 | This functionality is WIP. 18 | 19 | See `scpc `_ for a working example. 20 | 21 | Raises 22 | ------ 23 | ImportError 24 | If :mod:`pytorch_lightning` is not installed. 25 | """ 26 | 27 | from ._pl_data import ( 28 | LitDataModule, 29 | LitDataModuleParams, 30 | LitDataModuleParamsMetaclass, 31 | LitSpectDataModule, 32 | LitSpectDataModuleParams, 33 | ) 34 | 35 | __all__ = [ 36 | "LitDataModule", 37 | "LitDataModuleParams", 38 | "LitDataModuleParamsMetaclass", 39 | "LitSpectDataModule", 40 | "LitSpectDataModuleParams", 41 | ] 42 | -------------------------------------------------------------------------------- /src/pydrobert/torch/modules.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Sean Robertson 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
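# Editor's note: a concrete instance of the "Parameters"/"Call Parameters"
# convention described in the docstring below, using DotProductSoftAttention as
# exercised in tests/test_attn.py (the sizes are arbitrary; construction fixes
# the module's configuration, while tensors are passed per call):
#
#   import torch
#   from pydrobert.torch.modules import DotProductSoftAttention
#
#   attention = DotProductSoftAttention(64, 0)  # Parameters: feature size, dim
#   query = torch.randn(16, 64)                 # key minus the attended dim
#   key = torch.randn(12, 16, 64)               # dim 0 is attended over
#   out = attention(query, key, key)            # Call Parameters -> Returns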
14 | 15 | """Custom PyTorch modules 16 | 17 | Notes 18 | ----- 19 | To document :class:`torch.nn.Module` subclasses, we add a special heading called "Call 20 | Parameters" to the docstring which, along with "Returns", specify the signature of the 21 | module's :func:`__call__` method. The header "Parameters" refers to what values the 22 | module are initialized with. The general usage pattern is: 23 | 24 | >>> module = Module(*params) 25 | >>> returns = module(*call_params) 26 | """ 27 | 28 | __all__ = [ 29 | "BeamSearch", 30 | "ChunkBySlices", 31 | "ChunkTokenSequencesBySlices", 32 | "ConcatSoftAttention", 33 | "CTCGreedySearch", 34 | "CTCPrefixSearch", 35 | "DenseImageWarp", 36 | "DotProductSoftAttention", 37 | "EditDistance", 38 | "ErrorRate", 39 | "ExtractableSequentialLanguageModel", 40 | "ExtractableShallowFusionLanguageModel", 41 | "FeatureDeltas", 42 | "FillAfterEndOfSequence", 43 | "GeneralizedDotProductSoftAttention", 44 | "GlobalSoftAttention", 45 | "GumbelOneHotCategoricalRebarControlVariate", 46 | "HardOptimalCompletionDistillationLoss", 47 | "LogisticBernoulliRebarControlVariate", 48 | "LookupLanguageModel", 49 | "MeanVarianceNormalization", 50 | "MinimumErrorRateLoss", 51 | "MixableSequentialLanguageModel", 52 | "MixableShallowFusionLanguageModel", 53 | "MultiHeadedAttention", 54 | "OptimalCompletion", 55 | "PadMaskedSequence", 56 | "PadVariable", 57 | "PolyharmonicSpline", 58 | "PrefixEditDistances", 59 | "PrefixErrorRates", 60 | "RandomShift", 61 | "RandomWalk", 62 | "SequenceLogProbabilities", 63 | "SequentialLanguageModel", 64 | "ShallowFusionLanguageModel", 65 | "SliceSpectData", 66 | "SparseImageWarp", 67 | "SpecAugment", 68 | "TimeDistributedReturn", 69 | "Warp1DGrid", 70 | ] 71 | 72 | from ._attn import ( 73 | ConcatSoftAttention, 74 | DotProductSoftAttention, 75 | GeneralizedDotProductSoftAttention, 76 | GlobalSoftAttention, 77 | MultiHeadedAttention, 78 | ) 79 | from ._decoding import ( 80 | BeamSearch, 81 | CTCGreedySearch, 82 | CTCPrefixSearch, 83 | RandomWalk, 84 | SequenceLogProbabilities, 85 | ) 86 | from ._feats import ( 87 | ChunkTokenSequencesBySlices, 88 | FeatureDeltas, 89 | MeanVarianceNormalization, 90 | SliceSpectData, 91 | ) 92 | from ._img import ( 93 | DenseImageWarp, 94 | PolyharmonicSpline, 95 | Warp1DGrid, 96 | SparseImageWarp, 97 | RandomShift, 98 | SpecAugment, 99 | ) 100 | from ._lm import ( 101 | ExtractableSequentialLanguageModel, 102 | ExtractableShallowFusionLanguageModel, 103 | LookupLanguageModel, 104 | MixableSequentialLanguageModel, 105 | MixableShallowFusionLanguageModel, 106 | SequentialLanguageModel, 107 | ShallowFusionLanguageModel, 108 | ) 109 | from ._mc import ( 110 | LogisticBernoulliRebarControlVariate, 111 | GumbelOneHotCategoricalRebarControlVariate, 112 | ) 113 | from ._pad import ChunkBySlices, PadVariable, PadMaskedSequence 114 | from ._rl import TimeDistributedReturn 115 | from ._string import ( 116 | EditDistance, 117 | ErrorRate, 118 | FillAfterEndOfSequence, 119 | HardOptimalCompletionDistillationLoss, 120 | MinimumErrorRateLoss, 121 | OptimalCompletion, 122 | PrefixEditDistances, 123 | PrefixErrorRates, 124 | ) 125 | -------------------------------------------------------------------------------- /src/pydrobert/torch/util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Sean Robertson 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import warnings 16 | 17 | warnings.warn( 18 | "pydrobert.torch.util is deprecated. Use pydrobert.torch.functional for " 19 | "functions. parse_arpa_lm has been moved to pydrobert.torch.data.", 20 | DeprecationWarning, 21 | 2, 22 | ) 23 | 24 | 25 | 26 | from .functional import ( 27 | beam_search_advance, 28 | ctc_greedy_search, 29 | ctc_prefix_search_advance, 30 | dense_image_warp, 31 | edit_distance, 32 | error_rate, 33 | optimal_completion, 34 | pad_variable, 35 | polyharmonic_spline, 36 | prefix_edit_distances, 37 | prefix_error_rates, 38 | random_walk_advance, 39 | sequence_log_probs, 40 | sparse_image_warp, 41 | time_distributed_return, 42 | warp_1d_grid, 43 | ) 44 | from .data import parse_arpa_lm 45 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Sean Robertson 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import pytest 16 | import os 17 | import math 18 | import socket 19 | 20 | from zlib import adler32 21 | from contextlib import closing 22 | from shutil import rmtree 23 | 24 | import torch 25 | 26 | import pydrobert.torch.config as config 27 | 28 | import pydrobert.torch._compat as compat 29 | 30 | # command-line tests usually use relatively few utterances. 31 | # Make sure we're not putting them all on one thread.
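# Editor's note: DEFT_CHUNK_SIZE and DEFT_NUM_WORKERS are defaults read from
# pydrobert.torch.config; shrinking them below forces the multi-worker code
# paths to be exercised even with only a handful of utterances. The block that
# follows, for torch < 1.8, forces JIT scripting and patches torch.jit.script
# so that already-scripted objects are not re-scripted (see the linked
# PyTorch issue in the code).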
32 | config.DEFT_CHUNK_SIZE = 10 33 | config.DEFT_NUM_WORKERS = 2 34 | 35 | if compat._v < "1.8.0": 36 | config.USE_JIT = True # "trace" tests won't work otherwise 37 | compat.script = torch.jit.script 38 | compat.unflatten = torch.jit.script(compat.unflatten) 39 | 40 | # don't re-script anything 41 | # https://github.com/pytorch/pytorch/issues/51140 42 | def script(obj, *args, **kwargs): 43 | if isinstance(obj, torch.ScriptFunction) or isinstance(obj, torch.ScriptModule): 44 | return obj 45 | else: 46 | return compat.script(obj) 47 | 48 | torch.jit.script = script 49 | 50 | 51 | @pytest.fixture 52 | def temp_dir(tmp_path): 53 | dir_ = tmp_path / "pytest" 54 | dir_.mkdir() 55 | yield os.fspath(dir_) 56 | rmtree(os.fspath(dir_)) 57 | 58 | 59 | @pytest.fixture( 60 | params=[ 61 | pytest.param("cpu", marks=pytest.mark.cpu), 62 | pytest.param("cuda", marks=pytest.mark.gpu), 63 | ], 64 | scope="session", 65 | ) 66 | def device(request): 67 | if request.param == "cuda": 68 | return torch.device(torch.cuda.current_device()) 69 | else: 70 | return torch.device(request.param) 71 | 72 | 73 | CUDA_AVAIL = torch.cuda.is_available() 74 | 75 | 76 | def find_free_port(): 77 | with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: 78 | s.bind(("localhost", 0)) 79 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 80 | return s.getsockname()[1] 81 | 82 | 83 | def pytest_runtest_setup(item): 84 | if any(mark.name == "gpu" for mark in item.iter_markers()): 85 | if not CUDA_AVAIL: 86 | pytest.skip("cuda is not available") 87 | torch.cuda.empty_cache() 88 | # implicitly seeds all tests for the sake of reproducibility 89 | torch.manual_seed(abs(adler32(bytes(item.name, "utf-8")))) 90 | 91 | # for distributed training (doesn't overwrite test) 92 | os.environ.setdefault("MASTER_ADDR", "localhost") 93 | os.environ.setdefault("MASTER_PORT", str(find_free_port())) 94 | 95 | 96 | @pytest.fixture(scope="session") 97 | def populate_torch_dir(): 98 | def _populate_torch_dir( 99 | dr, 100 | num_utts, 101 | min_width=1, 102 | max_width=10, 103 | num_filts=5, 104 | max_ali_class=9, 105 | max_ref_class=99, 106 | include_ali=True, 107 | include_ref=True, 108 | file_prefix="", 109 | file_suffix=".pt", 110 | include_frame_shift=True, 111 | feat_dtype=torch.float, 112 | ): 113 | feat_dir = os.path.join(dr, "feat") 114 | ali_dir = os.path.join(dr, "ali") 115 | ref_dir = os.path.join(dr, "ref") 116 | if not os.path.isdir(feat_dir): 117 | os.makedirs(feat_dir) 118 | if include_ali and not os.path.isdir(ali_dir): 119 | os.makedirs(ali_dir) 120 | if include_ref and not os.path.isdir(ref_dir): 121 | os.makedirs(ref_dir) 122 | feats, feat_sizes, utt_ids = [], [], [] 123 | alis = [] if include_ali else None 124 | refs, ref_sizes = ([], []) if include_ref else (None, None) 125 | utt_id_fmt_str = "{{:0{}d}}".format(int(math.log10(num_utts)) + 1) 126 | for utt_idx in range(num_utts): 127 | utt_id = utt_id_fmt_str.format(utt_idx) 128 | feat_size = torch.randint(min_width, max_width + 1, (1,)).long() 129 | feat_size = feat_size.item() 130 | feat = (torch.rand(feat_size, num_filts) * 1000).to(dtype=feat_dtype) 131 | torch.save(feat, os.path.join(feat_dir, file_prefix + utt_id + file_suffix)) 132 | feats.append(feat) 133 | feat_sizes.append(feat_size) 134 | utt_ids.append(utt_id) 135 | if include_ali: 136 | ali = torch.randint(max_ali_class + 1, (feat_size,)).long() 137 | torch.save( 138 | ali, os.path.join(ali_dir, file_prefix + utt_id + file_suffix) 139 | ) 140 | alis.append(ali) 141 | if include_ref: 142 | ref_size = 
torch.randint(1, feat_size + 1, (1,)).long().item() 143 | max_ref_length = torch.randint(1, feat_size + 1, (1,)).long() 144 | max_ref_length = max_ref_length.item() 145 | ref = torch.randint(max_ref_class + 1, (ref_size,)).long() 146 | if include_frame_shift: 147 | ref_starts = torch.randint( 148 | feat_size - max_ref_length + 1, (ref_size,) 149 | ).long() 150 | ref_lengths = torch.randint( 151 | 1, max_ref_length + 1, (ref_size,) 152 | ).long() 153 | ref = torch.stack( 154 | [ref, ref_starts, ref_starts + ref_lengths], dim=-1 155 | ) 156 | torch.save( 157 | ref, os.path.join(ref_dir, file_prefix + utt_id + file_suffix) 158 | ) 159 | ref_sizes.append(ref_size) 160 | refs.append(ref) 161 | return feats, alis, refs, feat_sizes, ref_sizes, utt_ids 162 | 163 | return _populate_torch_dir 164 | 165 | 166 | @pytest.fixture( 167 | params=[ 168 | pytest.param("nojit", marks=pytest.mark.nojit), 169 | pytest.param("trace", marks=pytest.mark.trace), 170 | pytest.param("script", marks=pytest.mark.script), 171 | ] 172 | ) 173 | def jit_type(request): 174 | return request.param 175 | -------------------------------------------------------------------------------- /tests/dense_image_warp/flow.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdrobert/pydrobert-pytorch/2d2823e03385a0fa471ff666bdd0c7c9083a9094/tests/dense_image_warp/flow.npy -------------------------------------------------------------------------------- /tests/dense_image_warp/img.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdrobert/pydrobert-pytorch/2d2823e03385a0fa471ff666bdd0c7c9083a9094/tests/dense_image_warp/img.npy -------------------------------------------------------------------------------- /tests/dense_image_warp/warped.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdrobert/pydrobert-pytorch/2d2823e03385a0fa471ff666bdd0c7c9083a9094/tests/dense_image_warp/warped.npy -------------------------------------------------------------------------------- /tests/polyharmonic_spline/o1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdrobert/pydrobert-pytorch/2d2823e03385a0fa471ff666bdd0c7c9083a9094/tests/polyharmonic_spline/o1.npy -------------------------------------------------------------------------------- /tests/polyharmonic_spline/o2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdrobert/pydrobert-pytorch/2d2823e03385a0fa471ff666bdd0c7c9083a9094/tests/polyharmonic_spline/o2.npy -------------------------------------------------------------------------------- /tests/polyharmonic_spline/o3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdrobert/pydrobert-pytorch/2d2823e03385a0fa471ff666bdd0c7c9083a9094/tests/polyharmonic_spline/o3.npy -------------------------------------------------------------------------------- /tests/polyharmonic_spline/q.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdrobert/pydrobert-pytorch/2d2823e03385a0fa471ff666bdd0c7c9083a9094/tests/polyharmonic_spline/q.npy -------------------------------------------------------------------------------- /tests/polyharmonic_spline/x.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdrobert/pydrobert-pytorch/2d2823e03385a0fa471ff666bdd0c7c9083a9094/tests/polyharmonic_spline/x.npy -------------------------------------------------------------------------------- /tests/polyharmonic_spline/y.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdrobert/pydrobert-pytorch/2d2823e03385a0fa471ff666bdd0c7c9083a9094/tests/polyharmonic_spline/y.npy -------------------------------------------------------------------------------- /tests/republic/README: -------------------------------------------------------------------------------- 1 | This test data is based off two sources. 2 | 3 | First, the Project Gutenberg record of Plato's Republic 4 | (https://www.gutenberg.org/cache/epub/1497/pg1497.txt). The license can be 5 | found at the end of the file. 6 | 7 | Second, we use KenLM (https://kheafield.com/code/kenlm/) to generate the 8 | language model and come up with expected sentence-level log probabilities. 9 | KenLM is licensed under the LGPL, but I'm only distributing the output here 10 | (https://www.gnu.org/licenses/gpl-faq.en.html#WhatCaseIsOutputGPL). If you 11 | plan on using the KenLM code, please be aware of the license. 12 | 13 | We preprocess the text with: 14 | 15 | awk ' 16 | BEGIN {begun=0} 17 | /INTRODUCTION AND ANALYSIS/ {begun=1} 18 | /End of the Project Guten/ {begun=0} 19 | begun {print}' pg1497.txt | \ 20 | tr '\n' ' ' | \ 21 | tr --delete '\r[\200-\377]' | \ 22 | tr '[:upper:]' '[:lower:]' | \ 23 | sed 's/([^)]*)//g' | \ 24 | tr '"'"'"';,:/\-=+*)(' ' ' | \ 25 | tr '?!' '.' | \ 26 | sed 's/\.\.*/\./g' | \ 27 | sed 's/\. /\n/g' | \ 28 | sed 's/\.+/\./g' | \ 29 | tr -s '[:blank:]' | \ 30 | sed 's/^ *//g' > republic.txt 31 | 32 | We restrict the vocabulary to the words that aren't hapax legomena: 33 | 34 | cat republic.txt | \ 35 | tr ' ' '\n' | \ 36 | sort | \ 37 | uniq -c | \ 38 | sort -bgr | \ 39 | awk '$1 > 1 {print $2}' > vocab.txt 40 | 41 | Then convert that to a token2id map: 42 | 43 | cat <(echo ''; echo ''; echo '') vocab.txt | \ 44 | sort -u | \ 45 | awk 'NF {print}' | \ 46 | awk '{print $1, NR-1}' > token2id.map 47 | 48 | Now we use KenLM to generate the language model in ARPA format: 49 | 50 | bin/lmplz \ 51 | -o 5 \ 52 | --limit_vocab_file vocab.txt \ 53 | --text republic.txt > republic.arpa 54 | 55 | And pick some random queries: 56 | 57 | sort -R -u republic.txt | head -n 5 > queries.txt 58 | 59 | Get the expected sentence-level probs: 60 | 61 | bin/query -v sentence republic.arpa < queries.txt | \ 62 | awk '/Total/ {print $2}' > exp.txt 63 | -------------------------------------------------------------------------------- /tests/republic/exp.txt: -------------------------------------------------------------------------------- 1 | -20.089949 2 | -21.444225 3 | -27.39266 4 | -7.997661 5 | -22.825064 6 | -------------------------------------------------------------------------------- /tests/republic/queries.txt: -------------------------------------------------------------------------------- 1 | it is the peremptory military spirit which prevails in the government of honour 2 | or the reaction from the sublime to the ridiculous when glaucon describes the manner in which the new truth will be received by mankind 3 | but with the rich man this is otherwise of him we do not say that he has any specially appointed work which he must perform if he would live 4 | but is such a 
community possible 5 | he knows that this latter institution is not more than four or five thousand years old may not the end revert to the beginning 6 | -------------------------------------------------------------------------------- /tests/sclite/README: -------------------------------------------------------------------------------- 1 | Hypothesis and reference text are sanitized Lorem Ipsum 2 | (https://en.wikipedia.org/wiki/Lorem_ipsum) 3 | 4 | SCLITE is part of the SCTK toolkit (https://github.com/usnistgov/SCTK) 5 | developed by NIST. The license can be found at 6 | http://www.nist.gov/open/license.cfm. 7 | 8 | The following calls were used to extract the relevant results. Divisions 9 | convert the "percentage" to fractions. 10 | 11 | sclite -r ref.trn -h hyp.trn -i swb > sclite_out.txt 12 | cat sclite_out.txt | grep '^ *| [0-9]' | tr '|' ' ' | awk '{print $1"-"$1,$8/100}' > per_utt.txt 13 | cat sclite_out.txt | grep 'Sum/Avg' | awk '{print $10/100}' > total.txt -------------------------------------------------------------------------------- /tests/sclite/hyp.trn: -------------------------------------------------------------------------------- 1 | sit natoque est orci felis nullam etiam diam senectus mattis mi nec tincidunt (1-1) 2 | lacus (2-2) 3 | platea blandit congue congue varius accumsan malesuada pellentesque (3-3) 4 | ornare venenatis sem dis odio sodales urna tempus pulvinar (4-4) 5 | dolor penatibus cubilia mi sociis sapien platea tempor (5-5) 6 | egestas et ad faucibus magna aenean malesuada fermentum luctus quis lorem sed metus ut nascetur integer iaculis (6-6) 7 | lacus pretium tristique metus sapien blandit adipiscing sapien lacinia netus (7-7) 8 | sollicitudin (8-8) 9 | placerat mus feugiat (9-9) 10 | blandit tincidunt fringilla fringilla felis ultricies (10-10) 11 | a (11-11) 12 | ad sem cubilia nec vitae euismod mus aptent facilisi rutrum cubilia convallis leo donec vehicula varius massa lectus arcu nonummy porta (12-12) 13 | ac (13-13) 14 | dictum (14-14) 15 | nec leo (15-15) 16 | taciti taciti diam senectus purus hendrerit quisque varius feugiat sapien ultrices at luctus (16-16) 17 | sapien orci potenti aliquam placerat (17-17) 18 | a adipiscing molestie quisque curabitur amet penatibus vel nascetur ultricies elit (18-18) 19 | rutrum conubia consequat risus hac ullamcorper gravida torquent commodo gravida at convallis iaculis cras (19-19) 20 | non eleifend inceptos gravida justo nisl (20-20) 21 | convallis sapien auctor praesent sociis non urna (21-21) 22 | laoreet senectus (22-22) 23 | dictumst mi scelerisque amet quis viverra (23-23) 24 | sociosqu elit dui cubilia enim pulvinar (24-24) 25 | sapien ut nec (25-25) 26 | mollis sagittis tristique nostra suspendisse elementum mattis magnis aptent sollicitudin tempor posuere (26-26) 27 | taciti iaculis (27-27) 28 | dignissim magnis (28-28) 29 | quisque consequat metus velit (29-29) 30 | vitae blandit euismod inceptos (30-30) 31 | integer ultrices commodo et erat (31-31) 32 | ante fusce euismod molestie urna integer elit mattis (32-32) 33 | dictum viverra aliquam cubilia molestie mauris luctus lobortis (33-33) 34 | donec euismod tellus vehicula mus pellentesque aenean litora habitant ornare sagittis nibh hymenaeos auctor nisl habitasse faucibus purus suscipit fames eu curabitur vitae auctor ultrices (34-34) 35 | netus dignissim blandit nunc vivamus (35-35) 36 | placerat montes ac adipiscing amet (36-36) 37 | facilisi orci lacus facilisis curabitur (37-37) 38 | sociosqu penatibus praesent (38-38) 39 | lectus scelerisque 
per semper (39-39) 40 | metus per (40-40) 41 | class platea dolor sociis facilisi habitasse tincidunt penatibus scelerisque per fames tempus (41-41) 42 | congue justo (42-42) 43 | maecenas aliquam iaculis blandit (43-43) 44 | fermentum laoreet nostra cubilia tincidunt sagittis netus nec justo enim faucibus (44-44) 45 | ultricies lacinia magna sagittis gravida nisl platea nulla aenean convallis taciti parturient (45-45) 46 | arcu ultrices consequat (46-46) 47 | maecenas ligula (47-47) 48 | vel suspendisse (48-48) 49 | etiam praesent vestibulum facilisis dapibus magnis potenti massa lorem in fringilla class lectus sem sodales imperdiet convallis taciti nam at pede consequat litora pharetra ad hendrerit nascetur curabitur platea et nisl elementum justo vitae at (49-49) 50 | gravida nonummy magna habitant justo cubilia ornare quam porta elementum (50-50) 51 | -------------------------------------------------------------------------------- /tests/sclite/per_utt.txt: -------------------------------------------------------------------------------- 1 | 1-1 4.333 2 | 2-2 1 3 | 3-3 4 4 | 4-4 2.25 5 | 5-5 7 6 | 6-6 4.25 7 | 7-7 2.5 8 | 8-8 1 9 | 9-9 1 10 | 10-10 1.5 11 | 11-11 1 12 | 12-12 2.222 13 | 13-13 1 14 | 14-14 1 15 | 15-15 1 16 | 16-16 1.333 17 | 17-17 1 18 | 18-18 11 19 | 19-19 1.75 20 | 20-20 1 21 | 21-21 7 22 | 22-22 1 23 | 23-23 0.857 24 | 24-24 1 25 | 25-25 1 26 | 26-26 1.2 27 | 27-27 1 28 | 28-28 1 29 | 29-29 2 30 | 30-30 1 31 | 31-31 1.667 32 | 32-32 4 33 | 33-33 8 34 | 34-34 0.951 35 | 35-35 1 36 | 36-36 1 37 | 37-37 4 38 | 38-38 0.75 39 | 39-39 0.971 40 | 40-40 2 41 | 41-41 0.95 42 | 42-42 1 43 | 43-43 2 44 | 44-44 1 45 | 45-45 2.4 46 | 46-46 1 47 | 47-47 1 48 | 48-48 1 49 | 49-49 17.5 50 | 50-50 0.923 51 | -------------------------------------------------------------------------------- /tests/sclite/ref.trn: -------------------------------------------------------------------------------- 1 | curae torquent ut (1-1) 2 | sollicitudin interdum leo feugiat enim bibendum leo semper laoreet faucibus pharetra (2-2) 3 | enim mattis (3-3) 4 | luctus tempus dapibus tortor (4-4) 5 | penatibus (5-5) 6 | suspendisse convallis odio ultrices (6-6) 7 | tortor ridiculus fusce eros (7-7) 8 | ultrices potenti augue justo nisl nec nostra augue aptent neque natoque sed egestas at phasellus aliquam auctor justo quis sodales commodo porttitor orci (8-8) 9 | rutrum mollis nisl ridiculus malesuada (9-9) 10 | per sociosqu arcu lacinia (10-10) 11 | sapien pellentesque nibh non senectus luctus (11-11) 12 | blandit quisque mauris dapibus semper vivamus et lectus rhoncus (12-12) 13 | neque etiam hendrerit inceptos sed ornare tempus orci natoque faucibus cubilia (13-13) 14 | erat (14-14) 15 | suscipit curae (15-15) 16 | primis morbi accumsan ipsum dapibus dignissim donec sapien cras (16-16) 17 | pretium rhoncus urna nibh pulvinar tellus (17-17) 18 | pellentesque (18-18) 19 | a vehicula interdum aenean integer turpis leo mauris (19-19) 20 | quam vitae ornare curabitur feugiat urna quis dictumst (20-20) 21 | parturient (21-21) 22 | pharetra dolor proin mi consequat (22-22) 23 | nec dictumst tortor etiam auctor taciti ornare ultrices hymenaeos consectetuer sit dolor ultrices viverra (23-23) 24 | enim torquent sollicitudin habitasse justo nascetur nostra curae (24-24) 25 | fames vel pharetra phasellus pede sit dapibus hymenaeos nam (25-25) 26 | cras accumsan lacinia mattis interdum tortor curae sapien facilisi dictum (26-26) 27 | lobortis montes pellentesque netus imperdiet parturient ut volutpat ligula porta duis 
tortor quis purus cursus ac (27-27) 28 | eleifend proin class eget euismod facilisi senectus egestas tristique primis libero aliquet (28-28) 29 | at mollis (29-29) 30 | vivamus nostra aliquet suscipit dictum penatibus molestie sociosqu cras erat interdum lectus orci porta (30-30) 31 | interdum amet ullamcorper (31-31) 32 | morbi est (32-32) 33 | blandit (33-33) 34 | porta erat luctus semper etiam molestie commodo erat est posuere tortor fusce mi cras integer eros et torquent condimentum elementum lobortis justo primis nam nostra aptent imperdiet habitasse ad faucibus cum donec vitae diam mollis bibendum diam imperdiet nam nisl morbi (34-34) 35 | inceptos quisque at vel aliquet (35-35) 36 | pellentesque a augue fringilla bibendum non (36-36) 37 | orci (37-37) 38 | cubilia vitae praesent lacus (38-38) 39 | vivamus fames urna blandit feugiat lobortis ut penatibus commodo tristique class sed mi iaculis dictum id diam nisi integer venenatis eu a lectus mi neque senectus senectus pede ridiculus feugiat pulvinar lectus leo pharetra suspendisse (39-39) 40 | sapien (40-40) 41 | faucibus sem augue ultricies urna ut parturient ad ornare eget est libero varius arcu lacinia morbi arcu fames ac proin luctus libero nostra tempus eu nibh dolor scelerisque phasellus quam lobortis integer faucibus fames praesent vitae est nonummy facilisis natoque (41-41) 42 | donec magna conubia vivamus potenti fames interdum dis convallis placerat mus curabitur fusce ultricies urna hac rhoncus purus (42-42) 43 | eu nec (43-43) 44 | imperdiet placerat quam et erat dui et velit at praesent dignissim pretium ut lacus vulputate purus fames facilisis aenean id rhoncus (44-44) 45 | id tellus eleifend at senectus (45-45) 46 | nullam facilisis donec vehicula urna pede platea inceptos nam dui iaculis cubilia mollis cubilia semper nisi vitae posuere dapibus mollis taciti nibh quisque vitae blandit fusce vehicula felis conubia semper duis (46-46) 47 | integer sagittis mauris (47-47) 48 | posuere est mi euismod mus rutrum est (48-48) 49 | pretium ultricies (49-49) 50 | nullam mi interdum bibendum quam auctor euismod sagittis parturient magnis rutrum rhoncus elementum (50-50) 51 | -------------------------------------------------------------------------------- /tests/sclite/sclite_out.txt: -------------------------------------------------------------------------------- 1 | sclite: 2.10 TK Version 1.3 2 | Begin alignment of Ref File: 'ref.trn' and Hyp File: 'hyp.trn' 3 | Alignment# 1 for speaker 1 4 | Alignment# 1 for speaker 2 5 | Alignment# 1 for speaker 3 6 | Alignment# 1 for speaker 4 7 | Alignment# 1 for speaker 5 8 | Alignment# 1 for speaker 6 9 | Alignment# 1 for speaker 7 10 | Alignment# 1 for speaker 8 11 | Alignment# 1 for speaker 9 12 | Alignment# 1 for speaker 10 13 | Alignment# 1 for speaker 11 14 | Alignment# 1 for speaker 12 15 | Alignment# 1 for speaker 13 16 | Alignment# 1 for speaker 14 17 | Alignment# 1 for speaker 15 18 | Alignment# 1 for speaker 16 19 | Alignment# 1 for speaker 17 20 | Alignment# 1 for speaker 18 21 | Alignment# 1 for speaker 19 22 | Alignment# 1 for speaker 20 23 | Alignment# 1 for speaker 21 24 | Alignment# 1 for speaker 22 25 | Alignment# 1 for speaker 23 26 | Alignment# 1 for speaker 24 27 | Alignment# 1 for speaker 25 28 | Alignment# 1 for speaker 26 29 | Alignment# 1 for speaker 27 30 | Alignment# 1 for speaker 28 31 | Alignment# 1 for speaker 29 32 | Alignment# 1 for speaker 30 33 | Alignment# 1 for speaker 31 34 | Alignment# 1 for speaker 32 35 | Alignment# 1 for speaker 33 36 | Alignment# 
1 for speaker 34 37 | Alignment# 1 for speaker 35 38 | Alignment# 1 for speaker 36 39 | Alignment# 1 for speaker 37 40 | Alignment# 1 for speaker 38 41 | Alignment# 1 for speaker 39 42 | Alignment# 1 for speaker 40 43 | Alignment# 1 for speaker 41 44 | Alignment# 1 for speaker 42 45 | Alignment# 1 for speaker 43 46 | Alignment# 1 for speaker 44 47 | Alignment# 1 for speaker 45 48 | Alignment# 1 for speaker 46 49 | Alignment# 1 for speaker 47 50 | Alignment# 1 for speaker 48 51 | Alignment# 1 for speaker 49 52 | Alignment# 1 for speaker 50 53 | 54 | 55 | 56 | 57 | SYSTEM SUMMARY PERCENTAGES by SPEAKER 58 | 59 | ,------------------------------------------------------------------. 60 | | hyp.trn | 61 | |------------------------------------------------------------------| 62 | | SPKR | # Snt # Wrd | Corr Sub Del Ins Err S.Err | 63 | |--------+-------------+-------------------------------------------| 64 | | 1 | 1 3 | 0.0 100.0 0.0 333.3 433.3 100.0 | 65 | |--------+-------------+-------------------------------------------| 66 | | 2 | 1 11 | 0.0 9.1 90.9 0.0 100.0 100.0 | 67 | |--------+-------------+-------------------------------------------| 68 | | 3 | 1 2 | 0.0 100.0 0.0 300.0 400.0 100.0 | 69 | |--------+-------------+-------------------------------------------| 70 | | 4 | 1 4 | 25.0 50.0 25.0 150.0 225.0 100.0 | 71 | |--------+-------------+-------------------------------------------| 72 | | 5 | 1 1 |100.0 0.0 0.0 700.0 700.0 100.0 | 73 | |--------+-------------+-------------------------------------------| 74 | | 6 | 1 4 | 0.0 100.0 0.0 325.0 425.0 100.0 | 75 | |--------+-------------+-------------------------------------------| 76 | | 7 | 1 4 | 0.0 100.0 0.0 150.0 250.0 100.0 | 77 | |--------+-------------+-------------------------------------------| 78 | | 8 | 1 23 | 0.0 4.3 95.7 0.0 100.0 100.0 | 79 | |--------+-------------+-------------------------------------------| 80 | | 9 | 1 5 | 0.0 60.0 40.0 0.0 100.0 100.0 | 81 | |--------+-------------+-------------------------------------------| 82 | | 10 | 1 4 | 0.0 100.0 0.0 50.0 150.0 100.0 | 83 | |--------+-------------+-------------------------------------------| 84 | | 11 | 1 6 | 0.0 16.7 83.3 0.0 100.0 100.0 | 85 | |--------+-------------+-------------------------------------------| 86 | | 12 | 1 9 | 11.1 88.9 0.0 133.3 222.2 100.0 | 87 | |--------+-------------+-------------------------------------------| 88 | | 13 | 1 11 | 0.0 9.1 90.9 0.0 100.0 100.0 | 89 | |--------+-------------+-------------------------------------------| 90 | | 14 | 1 1 | 0.0 100.0 0.0 0.0 100.0 100.0 | 91 | |--------+-------------+-------------------------------------------| 92 | | 15 | 1 2 | 0.0 100.0 0.0 0.0 100.0 100.0 | 93 | |--------+-------------+-------------------------------------------| 94 | | 16 | 1 9 | 11.1 88.9 0.0 44.4 133.3 100.0 | 95 | |--------+-------------+-------------------------------------------| 96 | | 17 | 1 6 | 0.0 83.3 16.7 0.0 100.0 100.0 | 97 | |--------+-------------+-------------------------------------------| 98 | | 18 | 1 1 | 0.0 100.0 0.0 1000.0 1100.0 100.0 | 99 | |--------+-------------+-------------------------------------------| 100 | | 19 | 1 8 | 0.0 100.0 0.0 75.0 175.0 100.0 | 101 | |--------+-------------+-------------------------------------------| 102 | | 20 | 1 8 | 0.0 75.0 25.0 0.0 100.0 100.0 | 103 | |--------+-------------+-------------------------------------------| 104 | | 21 | 1 1 | 0.0 100.0 0.0 600.0 700.0 100.0 | 105 | |--------+-------------+-------------------------------------------| 106 | | 22 | 1 5 | 
0.0 40.0 60.0 0.0 100.0 100.0 | 107 | |--------+-------------+-------------------------------------------| 108 | | 23 | 1 14 | 14.3 28.6 57.1 0.0 85.7 100.0 | 109 | |--------+-------------+-------------------------------------------| 110 | | 24 | 1 8 | 0.0 75.0 25.0 0.0 100.0 100.0 | 111 | |--------+-------------+-------------------------------------------| 112 | | 25 | 1 9 | 0.0 33.3 66.7 0.0 100.0 100.0 | 113 | |--------+-------------+-------------------------------------------| 114 | | 26 | 1 10 | 10.0 80.0 10.0 30.0 120.0 100.0 | 115 | |--------+-------------+-------------------------------------------| 116 | | 27 | 1 16 | 0.0 12.5 87.5 0.0 100.0 100.0 | 117 | |--------+-------------+-------------------------------------------| 118 | | 28 | 1 12 | 0.0 16.7 83.3 0.0 100.0 100.0 | 119 | |--------+-------------+-------------------------------------------| 120 | | 29 | 1 2 | 0.0 100.0 0.0 100.0 200.0 100.0 | 121 | |--------+-------------+-------------------------------------------| 122 | | 30 | 1 14 | 0.0 28.6 71.4 0.0 100.0 100.0 | 123 | |--------+-------------+-------------------------------------------| 124 | | 31 | 1 3 | 0.0 100.0 0.0 66.7 166.7 100.0 | 125 | |--------+-------------+-------------------------------------------| 126 | | 32 | 1 2 | 0.0 100.0 0.0 300.0 400.0 100.0 | 127 | |--------+-------------+-------------------------------------------| 128 | | 33 | 1 1 | 0.0 100.0 0.0 700.0 800.0 100.0 | 129 | |--------+-------------+-------------------------------------------| 130 | | 34 | 1 41 | 4.9 56.1 39.0 0.0 95.1 100.0 | 131 | |--------+-------------+-------------------------------------------| 132 | | 35 | 1 5 | 0.0 100.0 0.0 0.0 100.0 100.0 | 133 | |--------+-------------+-------------------------------------------| 134 | | 36 | 1 6 | 0.0 83.3 16.7 0.0 100.0 100.0 | 135 | |--------+-------------+-------------------------------------------| 136 | | 37 | 1 1 |100.0 0.0 0.0 400.0 400.0 100.0 | 137 | |--------+-------------+-------------------------------------------| 138 | | 38 | 1 4 | 25.0 50.0 25.0 0.0 75.0 100.0 | 139 | |--------+-------------+-------------------------------------------| 140 | | 39 | 1 35 | 2.9 8.6 88.6 0.0 97.1 100.0 | 141 | |--------+-------------+-------------------------------------------| 142 | | 40 | 1 1 | 0.0 100.0 0.0 100.0 200.0 100.0 | 143 | |--------+-------------+-------------------------------------------| 144 | | 41 | 1 40 | 5.0 25.0 70.0 0.0 95.0 100.0 | 145 | |--------+-------------+-------------------------------------------| 146 | | 42 | 1 18 | 0.0 11.1 88.9 0.0 100.0 100.0 | 147 | |--------+-------------+-------------------------------------------| 148 | | 43 | 1 2 | 0.0 100.0 0.0 100.0 200.0 100.0 | 149 | |--------+-------------+-------------------------------------------| 150 | | 44 | 1 21 | 0.0 52.4 47.6 0.0 100.0 100.0 | 151 | |--------+-------------+-------------------------------------------| 152 | | 45 | 1 5 | 0.0 100.0 0.0 140.0 240.0 100.0 | 153 | |--------+-------------+-------------------------------------------| 154 | | 46 | 1 31 | 0.0 9.7 90.3 0.0 100.0 100.0 | 155 | |--------+-------------+-------------------------------------------| 156 | | 47 | 1 3 | 0.0 66.7 33.3 0.0 100.0 100.0 | 157 | |--------+-------------+-------------------------------------------| 158 | | 48 | 1 7 | 0.0 28.6 71.4 0.0 100.0 100.0 | 159 | |--------+-------------+-------------------------------------------| 160 | | 49 | 1 2 | 0.0 100.0 0.0 1650.0 1750.0 100.0 | 161 | |--------+-------------+-------------------------------------------| 162 | | 50 | 1 13 | 
7.7 69.2 23.1 0.0 92.3 100.0 | 163 | |==================================================================| 164 | | Sum/Avg| 50 454 | 3.3 42.3 54.4 34.1 130.8 100.0 | 165 | |==================================================================| 166 | | Mean | 1.0 9.1 | 6.3 63.2 30.4 149.0 242.6 100.0 | 167 | | S.D. | 0.0 9.9 | 20.2 37.3 35.4 307.6 303.7 0.0 | 168 | | Median | 1.0 5.5 | 0.0 75.0 16.7 0.0 100.0 100.0 | 169 | `------------------------------------------------------------------' 170 | 171 | Successful Completion 172 | -------------------------------------------------------------------------------- /tests/sclite/token2id.txt: -------------------------------------------------------------------------------- 1 | a 0 2 | ac 1 3 | accumsan 2 4 | ad 3 5 | adipiscing 4 6 | aenean 5 7 | aliquam 6 8 | aliquet 7 9 | amet 8 10 | ante 9 11 | aptent 10 12 | arcu 11 13 | at 12 14 | auctor 13 15 | augue 14 16 | bibendum 15 17 | blandit 16 18 | class 17 19 | commodo 18 20 | condimentum 19 21 | congue 20 22 | consectetuer 21 23 | consequat 22 24 | conubia 23 25 | convallis 24 26 | cras 25 27 | cubilia 26 28 | cum 27 29 | curabitur 28 30 | curae 29 31 | cursus 30 32 | dapibus 31 33 | diam 32 34 | dictum 33 35 | dictumst 34 36 | dignissim 35 37 | dis 36 38 | dolor 37 39 | donec 38 40 | dui 39 41 | duis 40 42 | egestas 41 43 | eget 42 44 | eleifend 43 45 | elementum 44 46 | elit 45 47 | enim 46 48 | erat 47 49 | eros 48 50 | est 49 51 | et 50 52 | etiam 51 53 | eu 52 54 | euismod 53 55 | facilisi 54 56 | facilisis 55 57 | fames 56 58 | faucibus 57 59 | felis 58 60 | fermentum 59 61 | feugiat 60 62 | fringilla 61 63 | fusce 62 64 | gravida 63 65 | habitant 64 66 | habitasse 65 67 | hac 66 68 | hendrerit 67 69 | hymenaeos 68 70 | iaculis 69 71 | id 70 72 | imperdiet 71 73 | in 72 74 | inceptos 73 75 | integer 74 76 | interdum 75 77 | ipsum 76 78 | justo 77 79 | lacinia 78 80 | lacus 79 81 | laoreet 80 82 | lectus 81 83 | leo 82 84 | libero 83 85 | ligula 84 86 | litora 85 87 | lobortis 86 88 | lorem 87 89 | luctus 88 90 | maecenas 89 91 | magna 90 92 | magnis 91 93 | malesuada 92 94 | massa 93 95 | mattis 94 96 | mauris 95 97 | metus 96 98 | mi 97 99 | molestie 98 100 | mollis 99 101 | montes 100 102 | morbi 101 103 | mus 102 104 | nam 103 105 | nascetur 104 106 | natoque 105 107 | nec 106 108 | neque 107 109 | netus 108 110 | nibh 109 111 | nisi 110 112 | nisl 111 113 | non 112 114 | nonummy 113 115 | nostra 114 116 | nulla 115 117 | nullam 116 118 | nunc 117 119 | odio 118 120 | orci 119 121 | ornare 120 122 | parturient 121 123 | pede 122 124 | pellentesque 123 125 | penatibus 124 126 | per 125 127 | pharetra 126 128 | phasellus 127 129 | placerat 128 130 | platea 129 131 | porta 130 132 | porttitor 131 133 | posuere 132 134 | potenti 133 135 | praesent 134 136 | pretium 135 137 | primis 136 138 | proin 137 139 | pulvinar 138 140 | purus 139 141 | quam 140 142 | quis 141 143 | quisque 142 144 | rhoncus 143 145 | ridiculus 144 146 | risus 145 147 | rutrum 146 148 | sagittis 147 149 | sapien 148 150 | scelerisque 149 151 | sed 150 152 | sem 151 153 | semper 152 154 | senectus 153 155 | sit 154 156 | sociis 155 157 | sociosqu 156 158 | sodales 157 159 | sollicitudin 158 160 | suscipit 159 161 | suspendisse 160 162 | taciti 161 163 | tellus 162 164 | tempor 163 165 | tempus 164 166 | tincidunt 165 167 | torquent 166 168 | tortor 167 169 | tristique 168 170 | turpis 169 171 | ullamcorper 170 172 | ultrices 171 173 | ultricies 172 174 | urna 173 175 | ut 174 176 | varius 175 177 | vehicula 176 178 | vel 177 
179 | velit 178 180 | venenatis 179 181 | vestibulum 180 182 | vitae 181 183 | vivamus 182 184 | viverra 183 185 | volutpat 184 186 | vulputate 185 187 | -------------------------------------------------------------------------------- /tests/sclite/total.txt: -------------------------------------------------------------------------------- 1 | 1.308 2 | -------------------------------------------------------------------------------- /tests/sparse_image_warp/dst.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdrobert/pydrobert-pytorch/2d2823e03385a0fa471ff666bdd0c7c9083a9094/tests/sparse_image_warp/dst.npy -------------------------------------------------------------------------------- /tests/sparse_image_warp/flow_0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdrobert/pydrobert-pytorch/2d2823e03385a0fa471ff666bdd0c7c9083a9094/tests/sparse_image_warp/flow_0.npy -------------------------------------------------------------------------------- /tests/sparse_image_warp/flow_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdrobert/pydrobert-pytorch/2d2823e03385a0fa471ff666bdd0c7c9083a9094/tests/sparse_image_warp/flow_2.npy -------------------------------------------------------------------------------- /tests/sparse_image_warp/img.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdrobert/pydrobert-pytorch/2d2823e03385a0fa471ff666bdd0c7c9083a9094/tests/sparse_image_warp/img.npy -------------------------------------------------------------------------------- /tests/sparse_image_warp/src.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdrobert/pydrobert-pytorch/2d2823e03385a0fa471ff666bdd0c7c9083a9094/tests/sparse_image_warp/src.npy -------------------------------------------------------------------------------- /tests/sparse_image_warp/warped_0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdrobert/pydrobert-pytorch/2d2823e03385a0fa471ff666bdd0c7c9083a9094/tests/sparse_image_warp/warped_0.npy -------------------------------------------------------------------------------- /tests/sparse_image_warp/warped_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdrobert/pydrobert-pytorch/2d2823e03385a0fa471ff666bdd0c7c9083a9094/tests/sparse_image_warp/warped_2.npy -------------------------------------------------------------------------------- /tests/test_argcheck.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Sean Robertson 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
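# Editor's note: the parametrized tables below pin down the argcheck contract:
# each ``is_*`` check returns the (possibly coerced) value on success and
# raises ValueError on failure, the ``as_*`` converters coerce or raise
# TypeError, and every check accepts ``allow_none=True`` to pass None through
# untouched. A minimal sketch of the behaviour the tables assert:
#
#   from pydrobert.torch import argcheck
#
#   assert argcheck.is_posi(1) == 1                     # valid: echoed back
#   assert argcheck.as_nat("100") == 100                # coerced from str
#   assert argcheck.is_posi(None, allow_none=True) is None
#   argcheck.is_posi(0)                                 # raises ValueError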
14 | 15 | import torch 16 | import pytest 17 | import numpy as np 18 | 19 | from pathlib import Path 20 | 21 | from pydrobert.torch import argcheck 22 | 23 | 24 | @pytest.mark.cpu 25 | @pytest.mark.parametrize( 26 | "check,val,exp", 27 | [ 28 | (argcheck.is_str, "a", "a"), 29 | (argcheck.is_str, "", ""), 30 | (argcheck.is_str, bytes("a", "utf-8"), None), 31 | (argcheck.is_int, 1, 1), 32 | (argcheck.is_int, -1, -1), 33 | (argcheck.is_int, np.uint32(1), 1), 34 | (argcheck.is_int, 1.0, None), 35 | (argcheck.is_int, np.float32(1.0), None), 36 | (argcheck.is_bool, True, True), 37 | (argcheck.is_bool, False, False), 38 | (argcheck.is_bool, 1, None), 39 | (argcheck.is_bool, "", None), 40 | (argcheck.is_float, 1.0, 1.0), 41 | (argcheck.is_float, np.inf, np.inf), 42 | (argcheck.is_float, 1, 1.0), 43 | (argcheck.is_float, np.float64(3.14), 3.14), 44 | (argcheck.is_float, np.uint8(255), 255.0), 45 | (argcheck.is_tensor, torch.ones(5), torch.ones(5)), 46 | (argcheck.is_tensor, 1, None), 47 | (argcheck.is_path, Path("."), Path(".")), 48 | (argcheck.is_path, ".", Path(".")), 49 | (argcheck.is_path, b".", None), 50 | (argcheck.is_numlike, 1, 1), 51 | (argcheck.is_numlike, 1.0, 1.0), 52 | (argcheck.is_numlike, "", None), 53 | (argcheck.is_token, "foo", "foo"), 54 | (argcheck.is_token, "", None), 55 | (argcheck.is_token, "foo bar", None), 56 | (argcheck.is_posi, 1, 1), 57 | (argcheck.is_posi, np.uint8(2), 2), 58 | (argcheck.is_posi, 1.0, None), 59 | (argcheck.is_posi, 0, None), 60 | (argcheck.is_nonposf, -1.0, -1.0), 61 | (argcheck.is_nonposf, -1, -1.0), 62 | (argcheck.is_nonposf, np.uint8(0), 0.0), 63 | (argcheck.is_nonposf, 1.0, None), 64 | (argcheck.is_closed01t, torch.arange(2), torch.arange(2)), 65 | (argcheck.is_closed01t, 1, None), 66 | (argcheck.is_closed01t, -torch.arange(2), None), 67 | (argcheck.is_open01t, torch.full((5,), 0.5), torch.full((5,), 0.5)), 68 | (argcheck.is_open01t, 0.5, None), 69 | (argcheck.is_open01t, torch.arange(2), None), 70 | (argcheck.is_file, __file__, __file__), 71 | (argcheck.is_file, Path.cwd(), None), 72 | (argcheck.is_dir, Path.cwd(), Path.cwd()), 73 | (argcheck.is_dir, __file__, None), 74 | (argcheck.is_nonempty, torch.ones(1), torch.ones(1)), 75 | (argcheck.is_nonempty, torch.ones(0), None), 76 | ], 77 | ) 78 | def test_is_type(check, val, exp): 79 | if exp is None: 80 | with pytest.raises(ValueError): 81 | check(val) 82 | else: 83 | act = check(val) 84 | assert type(exp) is type(act) 85 | if isinstance(act, torch.Tensor): 86 | assert (exp == act).all() 87 | else: 88 | assert exp == act 89 | assert check(None, allow_none=True) is None 90 | 91 | 92 | @pytest.mark.parametrize( 93 | "check,val,rest,good", 94 | [ 95 | (argcheck.is_a, 1, (int,), True), 96 | (argcheck.is_a, 1, (float,), False), 97 | (argcheck.is_in, 2, (range(10),), True), 98 | (argcheck.is_in, "1", (range(10),), False), 99 | (argcheck.is_exactly, 1, (1,), True), 100 | (argcheck.is_exactly, 1, (1.0,), False), 101 | (argcheck.is_equal, 1, (1.0,), True), 102 | (argcheck.is_equal, torch.ones(2, 1), (torch.ones(1, 2),), True), 103 | (argcheck.is_equal, torch.ones(2, 1), (torch.arange(2),), False), 104 | (argcheck.is_lt, 1, (torch.arange(2) + 2,), True), 105 | (argcheck.is_lt, 1.0, (torch.arange(2) + 1,), False), 106 | (argcheck.is_gte, torch.full((5,), np.inf), (10_000,), True), 107 | (argcheck.is_gte, 1, (1.0,), True), 108 | (argcheck.is_gte, 0, (torch.arange(2),), False), 109 | (argcheck.is_btw, 30.5, (0.1, np.inf), True), 110 | (argcheck.is_btw, 30.5, (0.1, -np.inf), False), 111 | (argcheck.is_btw_open, 1, 
(0.999, 1.001), True), 112 | (argcheck.is_btw_open, 1, (0.999, 1), False), 113 | (argcheck.is_btw_closed, 1, (1, 1), True), 114 | (argcheck.is_btw_closed, 1.001, (1, 1), False), 115 | (argcheck.has_ndim, torch.empty(1, 2, 3), (3,), True), 116 | (argcheck.has_ndim, torch.empty(0), (3,), False), 117 | ], 118 | ) 119 | def test_comparative(check, val, rest, good): 120 | if good: 121 | assert check(val, *rest) is val 122 | else: 123 | with pytest.raises(ValueError): 124 | check(val, *rest) 125 | assert check(None, *rest, allow_none=True) is None 126 | 127 | 128 | @pytest.mark.parametrize( 129 | "check,val,exp", 130 | [ 131 | (argcheck.as_str, 1, "1"), 132 | (argcheck.as_int, "1", 1), 133 | (argcheck.as_int, "1.1", None), 134 | (argcheck.as_bool, 0, False), 135 | (argcheck.as_bool, -1, True), 136 | (argcheck.as_tensor, 0.0, torch.tensor(0.0)), 137 | (argcheck.as_tensor, (0, 1), torch.arange(2)), 138 | (argcheck.as_tensor, "foo", None), 139 | (argcheck.as_posf, "3e10", 3e10), 140 | (argcheck.as_nat, "100", 100), 141 | (argcheck.as_nat, "1.0", None), 142 | (argcheck.as_nat, "0", None), 143 | (argcheck.as_nonnegi, "0", 0), 144 | (argcheck.as_nonnegi, "-1", None), 145 | (argcheck.as_open01, "-1", None), 146 | (argcheck.as_open01, "1e-5", 1e-5), 147 | (argcheck.as_closed01, "0", 0.0), 148 | (argcheck.as_path, ".", Path(".")), 149 | (argcheck.as_path_file, ".", None), 150 | (argcheck.as_path_dir, __file__, None), 151 | (argcheck.as_dir, Path("."), "."), 152 | (argcheck.as_file, __file__, __file__), 153 | ], 154 | ) 155 | def test_as(check, val, exp): 156 | if exp is None: 157 | with pytest.raises(TypeError): 158 | check(val) 159 | else: 160 | act = check(val) 161 | assert type(exp) is type(act) 162 | if isinstance(act, torch.Tensor): 163 | assert (exp == act).all() 164 | else: 165 | assert exp == act 166 | -------------------------------------------------------------------------------- /tests/test_attn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Sean Robertson 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import warnings 16 | 17 | import torch 18 | import pytest 19 | 20 | from pydrobert.torch.modules import ( 21 | ConcatSoftAttention, 22 | DotProductSoftAttention, 23 | GeneralizedDotProductSoftAttention, 24 | GlobalSoftAttention, 25 | MultiHeadedAttention, 26 | ) 27 | 28 | 29 | @pytest.mark.parametrize("dim", [0, 1]) 30 | def test_global_soft_attention(device, dim): 31 | class FirstIsBest(GlobalSoftAttention): 32 | def score(self, query, key): 33 | e = torch.full_like(key[..., 0], -float("inf")) 34 | e.narrow(self.dim, 0, 1).fill_(0.0) 35 | return e 36 | 37 | class ILoveEveryoneEqually(GlobalSoftAttention): 38 | def score(self, query, key): 39 | return torch.zeros_like(key[..., 0]) 40 | 41 | T, max_dim, max_dim_size = 12, 10, 10 42 | num_dim = torch.randint(dim + 2, max_dim + 1, (1,), device=device).item() 43 | key_shape = torch.randint( 44 | 1, max_dim_size + 1, (num_dim + 1,), device=device 45 | ).tolist() 46 | key_shape[dim] = T 47 | query_shape = key_shape[:dim] + key_shape[dim + 1 : -1] 48 | del key_shape[-2] 49 | key_lens = torch.randint(1, T + 1, query_shape[:-1], device=device) 50 | query = torch.randn(*query_shape, device=device) 51 | key = torch.randn(*key_shape, device=device) 52 | query_size = query_shape[-1] 53 | key_size = key_shape[-1] 54 | arange_shape = [1] * (num_dim - 1) 55 | arange_shape[dim] = T 56 | mask = torch.arange(T, device=device).view(*arange_shape) 57 | mask = mask < key_lens.unsqueeze(dim) 58 | key.requires_grad_(True) 59 | first_attention = FirstIsBest(query_size, key_size, dim).to(device) 60 | equal_attention = ILoveEveryoneEqually(query_size, key_size, dim).to(device) 61 | out1 = first_attention(query, key, key) 62 | assert torch.allclose(out1, key.narrow(dim, 0, 1).squeeze(dim)) 63 | out2 = first_attention(query, key, key, mask) 64 | assert torch.allclose(out1, out2) 65 | (g,) = torch.autograd.grad([out1], [key], grad_outputs=torch.ones_like(out1)) 66 | assert g.narrow(dim, 0, 1).eq(1).all() 67 | assert g.narrow(dim, 1, T - 1).eq(0).all() 68 | out1 = equal_attention(query, key, key) 69 | # the softmax introduces a slight numeric instability 70 | assert torch.allclose(out1, key.mean(dim), atol=1e-5) 71 | out2 = equal_attention(query, key, key, mask) 72 | assert not torch.allclose(out1, out2) 73 | exp = key.masked_fill(mask.unsqueeze(-1).eq(0), 0.0) 74 | exp = exp.sum(dim) 75 | exp = exp / key_lens.float().unsqueeze(-1) 76 | assert torch.allclose(out2, exp, atol=1e-5) 77 | (g,) = torch.autograd.grad([out2], [key], grad_outputs=torch.ones_like(out2)) 78 | assert g.masked_select(mask.eq(0).unsqueeze(-1)).eq(0).all() 79 | assert torch.allclose(g.sum(dim), torch.tensor(1.0, device=device), atol=1e-5) 80 | 81 | 82 | @pytest.mark.parametrize("dim", [0, 1, 2]) 83 | def test_dot_product_soft_attention(device, dim, jit_type): 84 | dim1, dim2, dim3, dim4 = 50, 30, 12, 100 85 | key_shape = (dim1, dim2, dim3, dim4) 86 | key = torch.randn(*key_shape, device=device) 87 | query_shape = key_shape[:dim] + key_shape[dim + 1 :] 88 | query = torch.zeros(*query_shape, device=device) 89 | query[..., 0] = 2.0 90 | exp = torch.nn.functional.softmax(key[..., 0], dim).unsqueeze(-1) * key 91 | exp = exp.sum(dim) 92 | attention = DotProductSoftAttention(dim4, dim, scale_factor=0.5) 93 | if jit_type == "script": 94 | attention = torch.jit.script(attention) 95 | elif jit_type == "trace": 96 | with warnings.catch_warnings(): 97 | warnings.simplefilter("ignore") 98 | attention = torch.jit.trace( 99 | attention, 100 | ( 101 | torch.empty((1,), device=device).expand(1, 1, 1, dim4), 
102 | torch.empty((1,), device=device).expand(1, 1, 1, 1, dim4), 103 | torch.empty((1,), device=device).expand(1, 1, 1, 1, dim4), 104 | ), 105 | ) 106 | act = attention(query, key, key) 107 | assert torch.allclose(exp, act) 108 | 109 | 110 | @pytest.mark.cpu 111 | def test_dot_product_soft_attention_on_transformer_input(): 112 | class MatrixVersion(torch.nn.Module): 113 | """Scaled dot product attention, specifically for transformers 114 | 115 | This was blatantly ripped from `speech transformers 116 | `__. 117 | 118 | This is a more straightforward implementation of the scaled dot product 119 | attention for transformer networks. We're showing that our implementation yields 120 | the same output and gradient as this. 121 | """ 122 | 123 | def __init__(self, temperature): 124 | super(MatrixVersion, self).__init__() 125 | self.temperature = temperature 126 | self.softmax = torch.nn.Softmax(dim=2) 127 | 128 | def forward(self, q, k, v, mask=None): 129 | attn = torch.bmm(q, k.transpose(1, 2)) 130 | attn = attn / self.temperature 131 | if mask is not None: 132 | attn = attn.masked_fill(mask, -float("inf")) 133 | attn = self.softmax(attn) 134 | output = torch.bmm(attn, v) 135 | return output 136 | 137 | num_batch, len_q, len_k, d_k, d_v = 30, 40, 20, 10, 50 138 | temp = 2.0 139 | query = torch.randn(num_batch, len_q, d_k, requires_grad=True) 140 | key = torch.randn(num_batch, len_k, d_k, requires_grad=True) 141 | value = torch.randn(num_batch, len_k, d_v, requires_grad=True) 142 | matrix_attention = MatrixVersion(temp) 143 | our_attention = DotProductSoftAttention(d_k, 1, 1 / temp) 144 | out1 = matrix_attention(query, key, value) 145 | out2 = our_attention(query, key.unsqueeze(2), value.unsqueeze(2)) 146 | assert torch.allclose(out1, out2, atol=1e-5) 147 | g_q1, g_k1, g_v1 = torch.autograd.grad( 148 | [out1], [query, key, value], grad_outputs=torch.ones_like(out1) 149 | ) 150 | g_q2, g_k2, g_v2 = torch.autograd.grad( 151 | [out2], [query, key, value], grad_outputs=torch.ones_like(out2) 152 | ) 153 | assert torch.allclose(g_q1, g_q2, atol=1e-5) 154 | assert torch.allclose(g_k1, g_k2, atol=1e-5) 155 | assert torch.allclose(g_v1, g_v2, atol=1e-5) 156 | mask = torch.randint(2, (num_batch, len_q, len_k), dtype=torch.bool) 157 | out1 = matrix_attention(query, key, value, mask) 158 | out2 = our_attention( 159 | query, 160 | key.unsqueeze(2), 161 | value.unsqueeze(2), 162 | ~mask.transpose(1, 2), # we use the inverse of mask 163 | ) 164 | assert torch.allclose(out1, out2, atol=1e-5) 165 | 166 | 167 | @pytest.mark.parametrize("dim", [0, 1, 2]) 168 | @pytest.mark.parametrize("bias", [True, False]) 169 | @pytest.mark.parametrize( 170 | "layer", ["general", "concat", "multihead_general", "multihead_concat"] 171 | ) 172 | def test_learnable_soft_attention(device, dim, bias, layer, jit_type): 173 | max_dim, max_dim_size, max_num_heads = 5, 5, 10 174 | num_dim = torch.randint(dim + 2, max_dim + 1, (1,), device=device).item() 175 | # dim size must be at least 2. 
Otherwise, softmax will have only one 176 | # element and gradient will be zero through it 177 | key_shape = torch.randint( 178 | 2, max_dim_size + 1, (num_dim + 1,), device=device 179 | ).tolist() 180 | query_shape = key_shape[:dim] + key_shape[dim + 1 : -1] 181 | del key_shape[-2] 182 | key = torch.randn(*key_shape, device=device) 183 | query = torch.randn(*query_shape, device=device) 184 | key_size = key_shape[-1] 185 | query_size = query_shape[-1] 186 | if layer == "general": 187 | attention = GeneralizedDotProductSoftAttention(query_size, key_size, dim, bias) 188 | elif layer == "concat": 189 | attention = ConcatSoftAttention(query_size, key_size, dim, bias) 190 | elif layer.startswith("multihead_"): 191 | num_heads = torch.randint(1, max_num_heads + 1, (1,), device=device).item() 192 | d_q = max(1, query_size // num_heads) 193 | d_k = max(1, key_size // num_heads) 194 | if layer.endswith("general"): 195 | single_head_attention = GeneralizedDotProductSoftAttention( 196 | d_q, d_k, dim, bias 197 | ) 198 | elif layer.endswith("concat"): 199 | single_head_attention = ConcatSoftAttention(query_size, key_size, dim, bias) 200 | attention = MultiHeadedAttention( 201 | query_size, 202 | key_size, 203 | key_size, 204 | num_heads, 205 | single_head_attention, 206 | bias_WQ=bias, 207 | bias_WK=bias, 208 | bias_WV=bias, 209 | bias_WC=bias, 210 | ) 211 | attention = attention.to(device) 212 | torch.manual_seed(1) 213 | attention.reset_parameters() 214 | optim = torch.optim.Adam(attention.parameters(), lr=1.0) 215 | optim.zero_grad() 216 | if jit_type == "trace": 217 | with warnings.catch_warnings(): 218 | warnings.simplefilter("ignore") 219 | attention_trace = torch.jit.trace( 220 | attention, 221 | ( 222 | torch.empty(1, device=device).expand(1, 1, 1, 1, 1, query_size), 223 | torch.empty(1, device=device).expand(1, 1, 1, 1, 1, 1, key_size), 224 | torch.empty(1, device=device).expand(1, 1, 1, 1, 1, 1, key_size), 225 | ), 226 | ) 227 | elif jit_type == "script": 228 | attention_trace = torch.jit.script(attention) 229 | else: 230 | attention_trace = attention 231 | out1 = attention_trace(query, key, key) 232 | out1.mean().backward() 233 | optim.step() 234 | optim.zero_grad() 235 | out2 = attention_trace(query, key, key) 236 | assert not torch.allclose(out1, out2, atol=1e-5) 237 | torch.manual_seed(1) 238 | attention.reset_parameters() 239 | out3 = attention_trace(query, key, key) 240 | assert torch.allclose(out1, out3, atol=1e-5) 241 | -------------------------------------------------------------------------------- /tests/test_combinatorics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Sean Robertson 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
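# Note: the tests below cross-check pydrobert.torch.functional's counting and
# enumeration routines against closed-form values computed with
# math.factorial. For instance, the binomial coefficient is
#
#     C(n, k) = n! / (k! * (n - k)!),
#
# so binomial_coefficient(torch.tensor(5), torch.tensor(2)) is expected to
# equal 120 // (2 * 6) == 10.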
14 | 15 | import math 16 | 17 | import torch 18 | import pytest 19 | 20 | import pydrobert.torch.distributions as distributions 21 | import pydrobert.torch.functional as functional 22 | 23 | 24 | @pytest.mark.parametrize("tmax", [20, 66]) 25 | def test_binomial_coefficient(device, jit_type, tmax): 26 | T = torch.arange(tmax, device=device) 27 | binomial_coefficient = functional.binomial_coefficient 28 | if jit_type == "script": 29 | binomial_coefficient = torch.jit.script(binomial_coefficient) 30 | elif jit_type == "trace": 31 | binomial_coefficient = torch.jit.trace( 32 | binomial_coefficient, (torch.tensor(0), torch.tensor(0)), 33 | ) 34 | binom = binomial_coefficient(T.unsqueeze(1), T) 35 | for length in range(tmax): 36 | for count in range(tmax): 37 | if count > length: 38 | N_exp = 0 39 | else: 40 | N_exp = math.factorial(length) // ( 41 | math.factorial(count) * math.factorial(length - count) 42 | ) 43 | assert binom[length, count] == N_exp, (length, count) 44 | 45 | 46 | def test_enumerate_binary_sequences(device, jit_type): 47 | tmax = 10 48 | enumerate_binary_sequences = functional.enumerate_binary_sequences 49 | if jit_type == "script": 50 | enumerate_binary_sequences = torch.jit.script(enumerate_binary_sequences) 51 | elif jit_type == "trace": 52 | pytest.xfail("trace unsupported for enumerate_binary_sequences") 53 | support = enumerate_binary_sequences(tmax, device) 54 | assert support.shape == (2 ** tmax, tmax) 55 | assert (support.sum(0) == 2 ** (tmax - 1)).all() 56 | half = tmax // 2 57 | assert (support[: 2 ** half, half:] == 0).all() 58 | assert (support[: 2 ** half, :half].sum(0) == 2 ** (half - 1)).all() 59 | 60 | 61 | def test_enumerate_vocab_sequences(device, jit_type): 62 | tmax, vmax = 5, 4 63 | enumerate_vocab_sequences = functional.enumerate_vocab_sequences 64 | if jit_type == "script": 65 | enumerate_vocab_sequences = torch.jit.script(enumerate_vocab_sequences) 66 | elif jit_type == "trace": 67 | pytest.xfail("trace unsupported for enumerate_vocab_sequences") 68 | support = enumerate_vocab_sequences(tmax, vmax, device=device) 69 | assert support.shape == (vmax ** tmax, tmax) 70 | support_ = torch.unique(support, sorted=True, dim=0) 71 | assert support.shape == support_.shape 72 | nrange_exp = torch.arange(vmax, device=device) 73 | nrange_act, counts = support.flatten().unique(sorted=True, return_counts=True) 74 | assert counts.sum() == support.numel() 75 | assert (nrange_exp == nrange_act).all() 76 | assert (counts == support.numel() // vmax).all() 77 | for t in range(tmax): 78 | assert (support[: vmax ** t, t:] == 0).all() 79 | 80 | 81 | def test_enumerate_binary_sequences_with_cardinality(device, jit_type): 82 | tmax = 10 83 | T = torch.arange(tmax - 1, -1, -1, device=device) 84 | eb = eb_ = functional.enumerate_binary_sequences_with_cardinality 85 | if jit_type == "script": 86 | eb = eb_ = torch.jit.script(eb) 87 | elif jit_type == "trace": 88 | eb = torch.jit.trace(eb, (torch.tensor(1), torch.tensor(1))) 89 | batched, binom = eb(T.unsqueeze(-1), T) 90 | for length in range(tmax): 91 | for count in range(tmax): 92 | nonbatched = eb_(length, count).to(device) 93 | if count > length: 94 | N_exp = M_exp = 0 95 | else: 96 | if count == 0: 97 | M_exp, N_exp = 0, 1 98 | else: 99 | M_exp = math.factorial(length - 1) // ( 100 | math.factorial(count - 1) * math.factorial(length - count) 101 | ) 102 | N_exp = M_exp * length // count 103 | assert nonbatched.shape == (N_exp, length) 104 | assert (nonbatched.sum(1) == count).all() 105 | assert (nonbatched.sum(0) == 
M_exp).all() 106 | assert binom[tmax - length - 1, tmax - count - 1] == N_exp 107 | batched_elem = batched[tmax - length - 1, tmax - count - 1, :N_exp, :length] 108 | assert batched_elem.shape == nonbatched.shape 109 | assert (batched_elem == nonbatched).all() 110 | 111 | 112 | def test_simple_random_sampling_without_replacement(device, jit_type): 113 | tmax_max, nmax, mmax = 16, 8, 2 ** 15 114 | tmax = torch.randint(tmax_max + 1, size=(nmax,), dtype=torch.float, device=device) 115 | lmax = (torch.rand(nmax, device=device) * (tmax + 1)).floor_() 116 | 117 | srswor = distributions.SimpleRandomSamplingWithoutReplacement( 118 | lmax, tmax, tmax_max, True 119 | ) 120 | if jit_type == "script": 121 | srswor_ = torch.jit.script( 122 | functional.simple_random_sampling_without_replacement 123 | ) 124 | b = srswor_(tmax.expand(mmax, nmax), lmax.expand(mmax, nmax), tmax_max) 125 | elif jit_type == "trace": 126 | # trace doesn't support integer parameters, so we'll redefine tmax_max to the 127 | # computed default 128 | tmax_max = int(tmax.max().item()) 129 | srswor = distributions.SimpleRandomSamplingWithoutReplacement( 130 | lmax, tmax, tmax_max, True 131 | ) 132 | srswor_ = torch.jit.trace( 133 | functional.simple_random_sampling_without_replacement, 134 | [torch.ones(1), torch.zeros(1)], 135 | ) 136 | b = srswor_(tmax.expand(mmax, nmax), lmax.expand(mmax, nmax)) 137 | else: 138 | b = srswor.sample([mmax]) 139 | assert ((b == 0) | (b == 1)).all() 140 | assert (b.sum(-1) == lmax).all() 141 | tmax_mask = tmax.unsqueeze(1) > torch.arange(tmax_max, device=device) 142 | b = b * tmax_mask 143 | assert (b.sum(-1) == lmax).all() 144 | assert torch.allclose(b.float().mean(0), srswor.mean, atol=1e-2) 145 | 146 | lp_exp = [] 147 | for n in range(nmax): 148 | tmax_n, lmax_n = int(tmax[n].item()), int(lmax[n].item()) 149 | lp_exp.append( 150 | math.log( 151 | (math.factorial(tmax_n - lmax_n) * math.factorial(lmax_n)) 152 | / math.factorial(tmax_n) 153 | ) 154 | ) 155 | lp_exp = torch.tensor(lp_exp, device=device).expand(mmax, nmax) 156 | lp_act = srswor.log_prob(b) 157 | assert torch.allclose(lp_exp, lp_act) 158 | 159 | 160 | def test_simple_random_sampling_without_replacement_enumerate_support(device): 161 | tmax = 5 162 | given_count = 2 163 | total_count = torch.arange(1, tmax + 1, device=device).clamp_min_(given_count) 164 | dist = distributions.SimpleRandomSamplingWithoutReplacement( 165 | given_count, total_count 166 | ) 167 | assert not dist.has_enumerate_support 168 | total_count.fill_(tmax) 169 | dist = distributions.SimpleRandomSamplingWithoutReplacement( 170 | given_count, total_count, tmax + 1 171 | ) 172 | assert dist.has_enumerate_support 173 | support = dist.enumerate_support(True) 174 | M_exp = math.factorial(tmax - 1) // ( 175 | math.factorial(given_count - 1) * math.factorial(tmax - given_count) 176 | ) 177 | N_exp = M_exp * tmax // given_count 178 | assert support.shape == (N_exp, tmax, tmax + 1) 179 | assert (support[..., -1] == 0).all() 180 | support = support[..., :-1] 181 | assert (support.sum(-1) == given_count).all() 182 | assert (support.sum(0) == M_exp).all() 183 | -------------------------------------------------------------------------------- /tests/test_enumerate_estimator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Sean Robertson 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import pytest 17 | 18 | import pydrobert.torch.estimators as estimators 19 | 20 | 21 | @pytest.fixture(params=["log", "exp"]) 22 | def is_log(request): 23 | return request.param == "log" 24 | 25 | 26 | def test_enumerate_estimator(device, is_log): 27 | T, V = 10, 6 28 | logits = torch.randn((T, V), device=device, requires_grad=True) 29 | mask = torch.zeros_like(logits, dtype=torch.bool) 30 | mask[..., 0] = True 31 | target = torch.randint(1, V, (T,), device=device) 32 | logits_ = logits.masked_fill(mask, -float("inf")) 33 | probs = logits_.softmax(-1) 34 | 35 | def func(b: torch.Tensor) -> torch.Tensor: 36 | target_ = target.expand(b.shape[:-1]).unsqueeze(-1) 37 | probs_ = b.gather(-1, target_).squeeze(-1) 38 | if is_log: 39 | return probs_.log() 40 | else: 41 | return probs_ 42 | 43 | exp_loss = func(probs).mean() 44 | assert exp_loss != 0 45 | (exp_g,) = torch.autograd.grad(exp_loss, [logits]) 46 | 47 | logits_ = logits.masked_fill(mask, -float("inf")) 48 | probs = logits_.softmax(-1) 49 | dist = torch.distributions.OneHotCategorical(probs=probs) 50 | estimator = estimators.EnumerateEstimator(dist, func, is_log) 51 | act_loss = estimator().mean() 52 | assert torch.allclose(exp_loss, act_loss) 53 | (act_g,) = torch.autograd.grad(act_loss, [logits]) 54 | assert torch.allclose(exp_g, act_g) 55 | -------------------------------------------------------------------------------- /tests/test_feats.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Sean Robertson 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
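# Note: test_mean_var_norm below verifies MeanVarianceNormalization by
# construction: x is synthesized as x = y * std + mean along the feature axis,
# so normalizing x with the same statistics (whether given directly, estimated
# from the batch, or accumulated via accumulate()/store()) must recover y up
# to tolerance. test_feat_deltas treats pydrobert-speech's post.Deltas as the
# reference implementation and is skipped when that package is unavailable.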
14 | 15 | import torch 16 | import pytest 17 | 18 | from pydrobert.torch.modules import ChunkTokenSequencesBySlices, MeanVarianceNormalization, FeatureDeltas, SliceSpectData 19 | 20 | 21 | @pytest.mark.parametrize("style", ["given", "sample", "accum"]) 22 | def test_mean_var_norm(device, jit_type, style): 23 | N1, N2, N3, N4, eps = 100, 200, 5, 50, 1e-5 24 | mean = torch.randn(N3, device=device) 25 | std = torch.rand(N3, device=device).clamp_min_(eps) 26 | y_exp = torch.randn(N1, N2, N3, N4, device=device) 27 | x = y_exp * std.unsqueeze(1) + mean.unsqueeze(1) 28 | mvn = MeanVarianceNormalization( 29 | -2, mean if style == "given" else None, std if style == "given" else None, eps 30 | ) 31 | if jit_type == "script": 32 | mvn = torch.jit.script(mvn) 33 | if style == "accum": 34 | for x_n in x: 35 | mvn.accumulate(x_n) 36 | mvn.store() 37 | assert torch.allclose(mean, mvn.mean.float(), atol=1e-2) 38 | assert torch.allclose(std, mvn.std.float(), atol=1e-2) 39 | if jit_type == "trace": 40 | mvn = torch.jit.trace(mvn, (torch.empty(1, 1, N3, 1, device=device),)) 41 | y_act = mvn(x) 42 | assert torch.allclose(y_exp, y_act, atol=1e-2) 43 | 44 | 45 | @pytest.mark.parametrize("order, width", [(0, 10), (1, 3), (2, 2)]) 46 | @pytest.mark.parametrize("dim", [-3, 0, 3]) 47 | def test_feat_deltas(device, jit_type, order, width, dim): 48 | N1, N2, N3, N4 = 10, 5, 4, 2 49 | post = pytest.importorskip("pydrobert.speech.post") 50 | x = torch.randn(N1, N2, N3, N4, device=device) 51 | op = post.Deltas(order, target_axis=dim, context_window=width) 52 | exp = torch.tensor(op.apply(x.numpy(), axis=-2, in_place=True)).to(device) 53 | exp_shape = [N1, N2, N3, N4] 54 | exp_shape[dim] *= order + 1 55 | assert exp.shape == tuple(exp_shape) 56 | feat_deltas = FeatureDeltas(dim, -2, True, order, width) 57 | if jit_type == "script": 58 | feat_deltas = torch.jit.script(feat_deltas) 59 | elif jit_type == "trace": 60 | feat_deltas = torch.jit.trace(feat_deltas, (torch.empty(1, 1, 1, 1),)) 61 | act = feat_deltas(x) 62 | assert exp.shape == act.shape 63 | assert torch.allclose(exp, act, atol=1e-5) 64 | 65 | 72 | @pytest.mark.parametrize("policy", ["fixed", "ali", "ref"]) 73 | @pytest.mark.parametrize("window_type", ["symmetric", "causal", "future"]) 74 | @pytest.mark.parametrize("valid_only", [True, False], ids=["valid", "invalid"]) 75 | @pytest.mark.parametrize("lobe_size", [0, 2]) 76 | def test_slice_spect_data( 77 | device, policy, window_type, valid_only, jit_type, lobe_size 78 | ): 79 | if policy == "fixed": 80 | in_lens = other_lens = torch.tensor([0, 8, 5], device=device) 81 | in_ = torch.empty((3, 11), device=device) 82 | if lobe_size == 0: 83 | # fmt: off 84 | slices_exp = torch.tensor([ 85 | [0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 8], # n=1 86 | [0, 1], [1, 2], [2, 3], [3, 4], [4, 5], # n=2 87 | ], device=device) 88 | # fmt: on 89 | sources_exp = torch.tensor( 90 | [1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2], device=device 91 | ) 92 | else: 93 | assert lobe_size == 2 94 | if valid_only and window_type == "symmetric": 95 | # fmt: off 96 | slices_exp = torch.tensor([ 97 | [0, 5], [3, 8], # n=1 98 | [0, 5], # n=2 99 | ], device=device) 100 | # fmt: on 101 | sources_exp = torch.tensor([1, 1, 2], device=device) 102 | elif window_type == "symmetric": 103 | # fmt: off 104 | slices_exp = torch.tensor([ 105 | [-1, 4], [2, 7], [5, 10], # n=1 106 | [-1, 4], [2, 7], # n=2 107 | ], device=device) 108 | # 
fmt: on 109 | sources_exp = torch.tensor([1, 1, 1, 2, 2], device=device) 110 | elif valid_only: 111 | # fmt: off 112 | slices_exp = torch.tensor([ 113 | [0, 3], [3, 6], # n=1 114 | [0, 3], # n=2 115 | ], device=device) 116 | # fmt: on 117 | sources_exp = torch.tensor([1, 1, 2], device=device) 118 | elif window_type == "causal": 119 | # fmt: off 120 | slices_exp = torch.tensor([ 121 | [-2, 1], [1, 4], [4, 7], # n=1 122 | [-2, 1], [1, 4], # n=2 123 | ], device=device) 124 | # fmt: on 125 | sources_exp = torch.tensor([1, 1, 1, 2, 2], device=device) 126 | else: # future 127 | # fmt: off 128 | slices_exp = torch.tensor([ 129 | [0, 3], [3, 6], [6, 9], # n=1 130 | [0, 3], [3, 6], # n=2 131 | ], device=device) 132 | # fmt: on 133 | sources_exp = torch.tensor([1, 1, 1, 2, 2], device=device) 134 | elif policy == "ali": 135 | in_lens = other_lens = torch.tensor([7, 5, 9, 0], device=device) 136 | # fmt: off 137 | in_ = torch.tensor([ 138 | [0, 0, 0, 1, 1, 0, 0, 5, 5, 5], # n=0 t=7 139 | [1, 2, 2, 2, 2, 6, 6, 6, 6, 6], # n=1 t=5 140 | [3, 3, 3, 3, 1, 2, 3, 4, 4, 4], # n=2 t=9 141 | [1, 2, 3, 4, 5, 6, 7, 8, 9, 1], # n=3 t=0 142 | ], device=device) 143 | # fmt: on 144 | if lobe_size == 0: 145 | # fmt: off 146 | slices_exp = torch.tensor([ 147 | [0, 3], [3, 5], [5, 7], # n=0 148 | [0, 1], [1, 5], # n=1 149 | [0, 4], [4, 5], [5, 6], [6, 7], [7, 9], # n=2 150 | ], device=device) 151 | # fmt: on 152 | sources_exp = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 2], device=device) 153 | else: 154 | assert lobe_size == 2 155 | if valid_only and window_type == "symmetric": 156 | slices_exp = torch.tensor([[0, 9]], device=device) 157 | sources_exp = torch.tensor([2], device=device) 158 | elif window_type == "symmetric": 159 | # fmt: off 160 | slices_exp = torch.tensor([ 161 | [0, 7], [0, 7], [0, 7], # n=0 162 | [0, 5], [0, 5], # n=1 163 | [0, 6], [0, 7], [0, 9], [4, 9], [5, 9], # n=2 164 | ], device=device) 165 | # fmt: on 166 | sources_exp = torch.tensor( 167 | [0, 0, 0, 1, 1, 2, 2, 2, 2, 2], device=device 168 | ) 169 | elif valid_only: 170 | # fmt: off 171 | slices_exp = torch.tensor([ 172 | [0, 7], # n=0 173 | [0, 6], [4, 7], [5, 9], # n=2 174 | ], device=device) 175 | # fmt: on 176 | sources_exp = torch.tensor([0, 2, 2, 2], device=device) 177 | elif window_type == "causal": 178 | # fmt: off 179 | slices_exp = torch.tensor([ 180 | [0, 3], [0, 5], [0, 7], # n=0 181 | [0, 1], [0, 5], # n=1 182 | [0, 4], [0, 5], [0, 6], [4, 7], [5, 9], # n=2 183 | ], device=device) 184 | # fmt: on 185 | sources_exp = torch.tensor( 186 | [0, 0, 0, 1, 1, 2, 2, 2, 2, 2], device=device 187 | ) 188 | else: 189 | # fmt: off 190 | slices_exp = torch.tensor([ 191 | [0, 7], [3, 7], [5, 7], # n=0 192 | [0, 5], [1, 5], # n=1 193 | [0, 6], [4, 7], [5, 9], [6, 9], [7, 9], # n=2 194 | ], device=device) 195 | # fmt: on 196 | sources_exp = torch.tensor( 197 | [0, 0, 0, 1, 1, 2, 2, 2, 2, 2], device=device 198 | ) 199 | else: 200 | assert policy == "ref" 201 | in_lens = torch.tensor([3, 0, 3], device=device) 202 | other_lens = torch.tensor([3, 10, 4], device=device) 203 | # fmt: off 204 | in_ = torch.tensor([ 205 | [[0, 0, 1], [0, 0, 2], [1, 1, 3]], # n=0 r=3 t=3 206 | [[1, 2, 3], [4, 5, 6], [7, 8, 9]], # n=1 r=0 t=10 207 | [[1, 2, 2], [1, 2, 5], [1, 2, -1]], # n=2 r=3 t=4 208 | ], device=device) 209 | # fmt: on 210 | if lobe_size == 0 and valid_only: 211 | # fmt: off 212 | slices_exp = torch.tensor([ 213 | [0, 1], [0, 2], [1, 3], # n=0 214 | ], device=device) 215 | # fmt: on 216 | sources_exp = torch.tensor([0, 0, 0], device=device) 217 | elif 
lobe_size == 0: 218 | # fmt: off 219 | slices_exp = torch.tensor([ 220 | [0, 1], [0, 2], [1, 3], # n=0 221 | [2, 5], # n=2 222 | ], device=device) 223 | # fmt: on 224 | sources_exp = torch.tensor([0, 0, 0, 2], device=device) 225 | else: 226 | assert lobe_size == 2 227 | if valid_only and window_type == "symmetric": 228 | slices_exp = torch.tensor([[0, 4]], device=device) 229 | sources_exp = torch.tensor([2], device=device) 230 | elif window_type == "symmetric": 231 | # fmt: off 232 | slices_exp = torch.tensor([ 233 | [-2, 3], [-2, 4], [-1, 5], # n=0 234 | [0, 4], [0, 7], # n=2 235 | ], device=device) 236 | # fmt: on 237 | sources_exp = torch.tensor([0, 0, 0, 2, 2], device=device) 238 | elif valid_only and window_type == "causal": 239 | slices_exp = torch.tensor([[0, 2]], device=device) 240 | sources_exp = torch.tensor([2], device=device) 241 | elif window_type == "causal": 242 | # fmt: off 243 | slices_exp = torch.tensor([ 244 | [-2, 1], [-2, 2], [-1, 3], # n=0 245 | [0, 2], [0, 5], # n=2 246 | ], device=device) 247 | # fmt: on 248 | sources_exp = torch.tensor([0, 0, 0, 2, 2], device=device) 249 | elif valid_only: 250 | # fmt: off 251 | slices_exp = torch.tensor([ 252 | [0, 3], # n=0 253 | [2, 4], # n=2 254 | ], device=device) 255 | # fmt: on 256 | sources_exp = torch.tensor([0, 2], device=device) 257 | else: # future 258 | # fmt: off 259 | slices_exp = torch.tensor([ 260 | [0, 3], [0, 4], [1, 5], # n=0 261 | [2, 4], [2, 7], # n=2 262 | ], device=device) 263 | # fmt: on 264 | sources_exp = torch.tensor([0, 0, 0, 2, 2], device=device) 265 | extract_chunk_slices = SliceSpectData( 266 | policy, window_type, valid_only, lobe_size 267 | ) 268 | if jit_type == "script": 269 | extract_chunk_slices = torch.jit.script(extract_chunk_slices) 270 | elif jit_type == "trace": 271 | extract_chunk_slices = torch.jit.trace( 272 | extract_chunk_slices, 273 | ( 274 | torch.zeros_like(in_), 275 | torch.zeros_like(in_lens), 276 | torch.zeros_like(other_lens), 277 | ), 278 | ) 279 | slices_act, sources_act = extract_chunk_slices(in_, in_lens, other_lens) 280 | assert slices_exp.shape == slices_act.shape 281 | assert sources_exp.shape == sources_act.shape 282 | assert (slices_exp == slices_act).all() 283 | assert (sources_exp == sources_act).all() 284 | 285 | 286 | @pytest.mark.parametrize("partial", [True, False], ids=["partial", "full"]) 287 | @pytest.mark.parametrize("retain", [True, False], ids=["absolute", "relative"]) 288 | def test_chunk_token_sequences_by_slices(device, partial, jit_type, retain): 289 | ref_lens = torch.tensor([0, 5, 2], device=device) 290 | # fmt: off 291 | refs = torch.tensor([ 292 | [[0, 0, 1], [1, 0, 1], [2, 0, 1], [3, 0, 1], [4, 0, 1]], # n=0 293 | [[0, 0, 2], [-1, 2, 4], [1, 4, 6], [2, -1, 7], [3, 5, 8]], # n=1 294 | [[0, 5, 4], [0, 2, 2], [0, 2, 2], [1, 2, 2], [2, 2, 2]], # n=2 295 | ], device=device) 296 | # fmt: on 297 | slices = torch.tensor([[0, 1], [3, 7], [-1, 3]], device=device) 298 | if partial: 299 | exp_chunks = [ 300 | torch.empty((0, 3), device=device, dtype=torch.long), 301 | torch.tensor([[-1, 2, 4], [1, 4, 6], [3, 5, 8]], device=device), 302 | torch.tensor([[0, 2, 2]], device=device), 303 | ] 304 | else: 305 | exp_chunks = [ 306 | torch.empty((0, 3), device=device, dtype=torch.long), 307 | torch.tensor([[1, 4, 6]], device=device), 308 | torch.tensor([[0, 2, 2]], device=device), 309 | ] 310 | if not retain: 311 | exp_chunks[1][:, 1:] += slices[1, 0] 312 | exp_chunks[2][:, 1:] += slices[2, 0] 313 | chunk_token_sequences_by_slices = ChunkTokenSequencesBySlices(partial, 
retain) 314 | if jit_type == "script": 315 | chunk_token_sequences_by_slices = torch.jit.script(chunk_token_sequences_by_slices) 316 | elif jit_type == "trace": 317 | chunk_token_sequences_by_slices = torch.jit.trace( 318 | chunk_token_sequences_by_slices, 319 | ( 320 | torch.empty((1, 0, 3), dtype=torch.long), 321 | torch.zeros((1, 2), dtype=torch.long), 322 | torch.zeros((1,), dtype=torch.long), 323 | ), 324 | ) 325 | act_chunks, act_lens = chunk_token_sequences_by_slices(refs, slices, ref_lens) 326 | assert len(act_lens) == len(exp_chunks) 327 | for act_chunk_n, act_lens_n, exp_chunk_n in zip(act_chunks, act_lens, exp_chunks): 328 | assert act_lens_n == exp_chunk_n.size(0) 329 | act_chunk_n = act_chunk_n[:act_lens_n] 330 | assert (act_chunk_n == exp_chunk_n).all() 331 | -------------------------------------------------------------------------------- /tests/test_metadata.py: -------------------------------------------------------------------------------- 1 | """Test package metadata""" 2 | 3 | import pydrobert.torch 4 | 5 | 6 | def test_version(): 7 | assert pydrobert.torch.__version__ != "inplace" 8 | -------------------------------------------------------------------------------- /tests/test_pad.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Sean Robertson 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
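# Note: PadVariable is checked against a per-sequence reference built with
# torch.nn.functional.pad: each batch element is truncated to its own length,
# padded on the left/right by its own amounts, and compared to the batched
# module output. PadMaskedSequence packs the unmasked elements of each
# sequence to the front and fills the remainder with the padding value.
# ChunkBySlices reuses the PadVariable idea with slices that may start before
# 0 or end past the sequence length; the overhang is filled according to the
# padding mode ("constant", "reflect", or "replicate").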
14 | 15 | import torch 16 | import pytest 17 | from pydrobert.torch.modules import ChunkBySlices, PadVariable, PadMaskedSequence 18 | 19 | 20 | @pytest.mark.parametrize("mode", ["constant", "reflect", "replicate"]) 21 | @pytest.mark.parametrize("another_dim", [True, False]) 22 | def test_pad_variable(device, mode, another_dim, jit_type): 23 | N, Tmax, Tmin, F = 10, 50, 5, 30 if another_dim else 1 24 | x = torch.rand((N, Tmax, F), device=device) 25 | lens = torch.randint(Tmin, Tmax + 1, (N,), device=device) 26 | pad = torch.randint(Tmin - 1, size=(2, N), device=device) 27 | exp_padded = [] 28 | for x_n, lens_n, pad_n in zip(x, lens, pad.t()): 29 | x_n = x_n[:lens_n] 30 | padded_n = torch.nn.functional.pad( 31 | x_n.unsqueeze(0).unsqueeze(0), [0, 0] + pad_n.tolist(), mode 32 | ).view(-1, F) 33 | exp_padded.append(padded_n) 34 | pad_variable = PadVariable(mode) 35 | if jit_type == "script": 36 | pad_variable = torch.jit.script(pad_variable) 37 | elif jit_type == "trace": 38 | pad_variable = torch.jit.trace( 39 | pad_variable, 40 | ( 41 | torch.ones(1, 2), 42 | torch.full((1,), 2, dtype=torch.long), 43 | torch.ones(2, 1, dtype=torch.long), 44 | ), 45 | ) 46 | act_padded = pad_variable(x, lens, pad) 47 | for exp_padded_n, act_padded_n in zip(exp_padded, act_padded): 48 | assert torch.allclose(exp_padded_n, act_padded_n[: len(exp_padded_n)]) 49 | # quick double-check that other types work 50 | for type_ in (torch.long, torch.bool): 51 | assert pad_variable(x.to(type_), lens, pad).dtype == type_ 52 | 53 | 54 | @pytest.mark.parametrize("batch_first", [True, False]) 55 | def test_pad_masked_sequence(device, batch_first, jit_type): 56 | N1, N2, N3, N4, p = 15, 3, 11, 17, -1 57 | x = torch.rand((N1, N2, N3, N4), device=device) 58 | mask = torch.randint(2, (N1, N2), device=device, dtype=torch.bool) 59 | T, N = (N2, N1) if batch_first else (N1, N2) 60 | exp_lens = torch.empty(N, dtype=torch.long, device=device) 61 | exp_x = torch.full_like(x, p) 62 | for n in range(N): 63 | if batch_first: 64 | x_n, mask_n, ex_n = x[n], mask[n], exp_x[n] 65 | else: 66 | x_n, mask_n, ex_n = x[:, n], mask[:, n], exp_x[:, n] 67 | i = 0 68 | for j in range(T): 69 | if mask_n[j]: 70 | ex_n[i] = x_n[j] 71 | i += 1 72 | exp_lens[n] = i 73 | pad_masked_sequence = PadMaskedSequence(batch_first, float(p)) 74 | if jit_type == "script": 75 | pad_masked_sequence = torch.jit.script(pad_masked_sequence) 76 | elif jit_type == "trace": 77 | pad_masked_sequence = torch.jit.trace( 78 | pad_masked_sequence, (torch.ones(1, 1), torch.ones(1, 1, dtype=torch.bool)) 79 | ) 80 | act_x, act_lens = pad_masked_sequence(x, mask) 81 | assert act_x.shape == exp_x.shape 82 | assert act_lens.shape == exp_lens.shape 83 | assert (act_lens == exp_lens).all() 84 | assert (act_x == exp_x).all() 85 | 86 | 87 | @pytest.mark.parametrize("mode", ["constant", "reflect", "replicate"]) 88 | @pytest.mark.parametrize("another_dim", [True, False]) 89 | def test_chunk_by_slice(device, mode, another_dim, jit_type): 90 | N, Tmax, Tmin, F = 30, 20, 5, 7 if another_dim else 1 91 | lens = torch.randint(Tmin, Tmax + 1, (N,), device=device) 92 | starts = torch.randint(-Tmax + 1, Tmax, (N,), device=device) 93 | starts = torch.max(starts, -lens + 1) 94 | ends = starts + torch.randint(-1, Tmax - 1, (N,), device=device) 95 | ends = torch.min(ends, 2 * lens - 1) 96 | slices = torch.stack([starts, ends], 1) 97 | x = torch.arange(N * Tmax * F, device=device, dtype=torch.float).view(N, Tmax, F) 98 | exp_chunks = [] 99 | exp_chunk_lens = [] 100 | for x_n, lens_n, starts_n, ends_n in 
zip(x, lens, starts, ends): 101 | chunk_lens_n = (ends_n - starts_n).clamp_min_(0).view(1) 102 | exp_chunk_lens.append(chunk_lens_n) 103 | if chunk_lens_n == 0: 104 | exp_chunks.append(x_n[:0]) 105 | continue 106 | pad_left = (-starts_n).clamp_min_(0).item() 107 | pad_right = (ends_n - lens_n).clamp_min_(0).item() 108 | x_n = x_n[:lens_n] 109 | x_n = torch.nn.functional.pad( 110 | x_n.unsqueeze(0).unsqueeze(0), [0, 0] + [pad_left, pad_right], mode 111 | ).view(-1, F) 112 | chunks_n = x_n[starts_n + pad_left : ends_n + pad_left] 113 | exp_chunks.append(chunks_n) 114 | exp_chunk_lens = torch.cat(exp_chunk_lens) 115 | chunk_by_slices = ChunkBySlices(mode) 116 | if jit_type == "script": 117 | chunk_by_slices = torch.jit.script(chunk_by_slices) 118 | elif jit_type == "trace": 119 | chunk_by_slices = torch.jit.trace( 120 | chunk_by_slices, 121 | ( 122 | torch.ones(1, 1), 123 | torch.zeros(1, 2, dtype=torch.long), 124 | torch.full((1,), 2, dtype=torch.long), 125 | ), 126 | ) 127 | act_chunks, act_chunk_lens = chunk_by_slices(x, slices, lens) 128 | assert (exp_chunk_lens == act_chunk_lens).all() 129 | for n, (exp_chunks_n, act_chunks_n, chunk_lens_n) in enumerate( 130 | zip(exp_chunks, act_chunks, exp_chunk_lens) 131 | ): 132 | act_chunks_n = act_chunks_n[:chunk_lens_n] 133 | assert exp_chunks_n.shape == act_chunks_n.shape 134 | exp_chunks_n, act_chunks_n = exp_chunks_n.squeeze(1), act_chunks_n.squeeze(1) 135 | assert torch.allclose(exp_chunks_n, act_chunks_n[:chunk_lens_n]), ( 136 | n, 137 | x[n].squeeze(1), 138 | starts[n], 139 | ends[n], 140 | lens[n], 141 | ) 142 | -------------------------------------------------------------------------------- /tests/test_pl_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Sean Robertson 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
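# Note: these tests exercise the optional pytorch_lightning integration and
# are skipped at the module level when pytorch_lightning or pydrobert-param
# cannot be imported. The populate_lit_dir fixture lays out train/dev/test
# (and, optionally, predict) data directories plus an info.ark summary file,
# which is the on-disk layout LitSpectDataModule is configured to read.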
14 | 15 | import os 16 | import argparse 17 | 18 | import torch 19 | import pytest 20 | 21 | import pydrobert.torch.command_line as cmd 22 | 23 | try: 24 | import pytorch_lightning as pl 25 | import pydrobert.torch.lightning as plightning 26 | import pydrobert.param.serialization as serial 27 | except ImportError: 28 | pytest.skip( 29 | "no pytorch_lightning, pydrobert.params not available", allow_module_level=True 30 | ) 31 | 32 | 33 | @pytest.fixture(scope="session") 34 | def populate_lit_dir(request, populate_torch_dir): 35 | def _populate_lit_dir( 36 | root_dir, 37 | num_filts=5, 38 | max_ali_class=9, 39 | max_ref_class=99, 40 | train_utts=100, 41 | dev_utts=10, 42 | test_utts=20, 43 | predict_utts=None, 44 | include_ali=False, 45 | include_ref=True, 46 | with_mvn=True, 47 | **kwargs, 48 | ): 49 | params = dict( 50 | train_dir=f"{root_dir}/train", 51 | val_dir=f"{root_dir}/dev", 52 | test_dir=f"{root_dir}/test", 53 | info_path=f"{root_dir}/info.ark", 54 | ) 55 | x = [("train", train_utts), ("dev", dev_utts), ("test", test_utts)] 56 | if predict_utts is not None: 57 | params["predict_dir"] = f"{root_dir}/predict" 58 | x.append(("predict", predict_utts)) 59 | if not include_ali: 60 | max_ali_class = -1 61 | if not include_ref: 62 | max_ref_class = -1 63 | for part, num_utts in x: 64 | dir_ = os.path.join(root_dir, part) 65 | os.makedirs(dir_, exist_ok=True) 66 | populate_torch_dir( 67 | dir_, 68 | num_utts, 69 | num_filts=num_filts, 70 | max_ali_class=max_ali_class, 71 | max_ref_class=max_ref_class, 72 | include_ali=include_ali, 73 | include_ref=include_ref, 74 | **kwargs, 75 | ) 76 | with open(os.path.join(root_dir, "info.ark"), "w") as f: 77 | f.write(f"num_filts {num_filts}\n") 78 | f.write(f"max_ali_class {max_ali_class}\n") 79 | f.write(f"max_ref_class {max_ref_class}\n") 80 | 81 | if with_mvn: 82 | assert not cmd.compute_mvn_stats_for_torch_feat_data_dir( 83 | [f"{root_dir}/train/feat", f"{root_dir}/mvn.pt"] 84 | ) 85 | params["mvn_path"] = f"{root_dir}/mvn.pt" 86 | return params 87 | 88 | return _populate_lit_dir 89 | 90 | 91 | @pytest.mark.cpu 92 | def test_lit_spect_data_module_basic(temp_dir, populate_lit_dir): 93 | tN, VN, TN, N, F, A, V = 101, 11, 21, 10, 5, 9, 10 94 | params = plightning.LitSpectDataModuleParams( 95 | **populate_lit_dir(f"{temp_dir}/data", F, A, V - 1, tN, VN, TN) 96 | ) 97 | params.prefer_split = False 98 | params.initialize_missing() 99 | params.train_params.batch_size = N 100 | params.train_params.drop_last = True 101 | data = plightning.LitSpectDataModule(params) 102 | data.prepare_data() 103 | assert data.vocab_size is None 104 | assert data.test_set is None 105 | assert data.val_set is None 106 | assert data.train_set is None 107 | assert data.predict_set is None 108 | data.setup() 109 | assert data.vocab_size == V 110 | assert data.feat_size == F 111 | assert len(data.train_set) == tN 112 | assert len(data.val_set) == VN 113 | assert len(data.test_set) == TN 114 | assert len(data.predict_set) == TN 115 | pl.seed_everything(0) 116 | feat_lens, ref_lens = [], [] 117 | for feat, ref, feat_len, ref_len in data.train_dataloader(): 118 | assert feat.shape[1:] == (N, F) 119 | assert ref.shape[1:] == (N,) 120 | feat_lens.append(feat_len) 121 | ref_lens.append(ref_len) 122 | assert feat_len.shape == ref_len.shape == (N,) 123 | feat_lens_0, feat_lens = torch.cat(feat_lens), [] 124 | ref_lens_0, ref_lens = torch.cat(ref_lens), [] 125 | pl.seed_everything(0) 126 | for _, _, feat_len, ref_len in data.train_dataloader(): 127 | feat_lens.append(feat_len) 128 | 
ref_lens.append(ref_len) 129 | feat_lens_1 = torch.cat(feat_lens) 130 | ref_lens_1 = torch.cat(ref_lens) 131 | assert (feat_lens_0 == feat_lens_1).all() 132 | assert (ref_lens_0 == ref_lens_1).all() 133 | 134 | 135 | @pytest.mark.cpu 136 | def test_lit_spect_data_module_argparse(temp_dir, populate_lit_dir): 137 | tNN, VNN, TNN, PNN, tN, TN = 50, 40, 30, 20, 5, 10 138 | assert tNN % tN == VNN % tN == TNN % TN == PNN % TN == 0 139 | params = plightning.LitSpectDataModuleParams( 140 | **populate_lit_dir( 141 | f"{temp_dir}/data", 142 | train_utts=tNN, 143 | dev_utts=VNN, 144 | test_utts=TNN, 145 | predict_utts=PNN, 146 | ) 147 | ) 148 | params.initialize_missing() 149 | params.train_params.batch_size = params.val_params.batch_size = tN 150 | params.test_params.batch_size = TN 151 | cfg = f"{temp_dir}/conf.json" 152 | 153 | serial.register_serializer("reckless_json") 154 | json_ = params.param.serialize_parameters(mode="reckless_json") 155 | with open(cfg, "w") as f: 156 | f.write(json_) 157 | 158 | parser = argparse.ArgumentParser() 159 | plightning.LitSpectDataModule.add_argparse_args(parser) 160 | args = ["--read-data-json", cfg] 161 | namespace = parser.parse_args(args) 162 | dm = plightning.LitSpectDataModule.from_argparse_args(namespace) 163 | assert dm.params.param.pprint() == params.param.pprint() 164 | dm.prepare_data() 165 | dm.setup() 166 | assert len(dm.train_dataloader()) == tNN // tN 167 | assert len(dm.val_dataloader()) == VNN // tN 168 | assert len(dm.test_dataloader()) == TNN // TN 169 | assert len(dm.predict_dataloader()) == PNN // TN 170 | 171 | args += ["--predict-dir", f"{temp_dir}/data/test"] 172 | namespace = parser.parse_args(args) 173 | dm = plightning.LitSpectDataModule.from_argparse_args(namespace) 174 | assert dm.params.param.pprint() != params.param.pprint() 175 | dm.prepare_data() 176 | dm.setup() 177 | assert len(dm.train_dataloader()) == tNN // tN 178 | assert len(dm.val_dataloader()) == VNN // tN 179 | assert len(dm.test_dataloader()) == TNN // TN 180 | assert len(dm.predict_dataloader()) == TNN // TN 181 | -------------------------------------------------------------------------------- /tests/test_rl.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Sean Robertson 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
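# Note: TimeDistributedReturn computes discounted returns. The expected
# values below are built with the usual backward recursion
#
#     G_t = r_t + gamma * G_{t+1},    G_{T-1} = r_{T-1},
#
# and the module output must match it whether time is the first or the
# second axis (batch_first).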
14 | 15 | import torch 16 | import pytest 17 | 18 | from pydrobert.torch.modules import TimeDistributedReturn 19 | 20 | 21 | @pytest.mark.parametrize("batch_first", [True, False]) 22 | @pytest.mark.parametrize("gamma", [0.0, 0.95]) 23 | def test_time_distributed_return(device, batch_first, gamma, jit_type): 24 | steps, batch_size = 1000, 30 25 | r = torch.randn(steps, batch_size, device=device) 26 | exp = torch.empty_like(r) 27 | exp[-1] = r[-1] 28 | for step in range(steps - 2, -1, -1): 29 | exp[step] = r[step] + gamma * exp[step + 1] 30 | if batch_first: 31 | r = r.t().contiguous() 32 | exp = exp.t().contiguous() 33 | time_distributed_return = TimeDistributedReturn(gamma, batch_first) 34 | if jit_type == "script": 35 | time_distributed_return = torch.jit.script(time_distributed_return) 36 | elif jit_type == "trace": 37 | time_distributed_return = torch.jit.trace( 38 | time_distributed_return, (torch.empty(1, 1),) 39 | ) 40 | act = time_distributed_return(r) 41 | assert torch.allclose(exp, act, atol=1e-5) 42 | -------------------------------------------------------------------------------- /tests/test_straight_through.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Sean Robertson 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import pytest 17 | 18 | from pydrobert.torch.distributions import ( 19 | GumbelOneHotCategorical, 20 | LogisticBernoulli, 21 | StraightThrough, 22 | ConditionalStraightThrough, 23 | Density, 24 | ) 25 | 26 | 27 | @pytest.mark.cpu 28 | def test_interfaces(): 29 | class GoodST1(StraightThrough): 30 | pass 31 | 32 | class GoodST2(torch.distributions.RelaxedBernoulli): 33 | def rsample(self): 34 | ... 35 | 36 | def threshold(self): 37 | ... 38 | 39 | def tlog_prob(self): 40 | ... 41 | 42 | class BadST1(torch.distributions.RelaxedBernoulli): # missing a method 43 | def rsample(self): 44 | ... 45 | 46 | def threshold(self): 47 | ... 48 | 49 | class BadST2(object): # not a distribution 50 | def rsample(self): 51 | ... 52 | 53 | def threshold(self): 54 | ... 55 | 56 | def tlog_prob(self): 57 | ... 58 | 59 | class BadST3(torch.distributions.Bernoulli): # not a relaxed distribution 60 | def rsample(self): 61 | ... 62 | 63 | def threshold(self): 64 | ... 65 | 66 | def tlog_prob(self): 67 | ... 68 | 69 | class GoodCT1(ConditionalStraightThrough): 70 | pass 71 | 72 | class GoodCT2(GoodST1): 73 | def clog_prob(self): 74 | ... 75 | 76 | def csample(self): 77 | ... 78 | 79 | class BadCT1(BadST1): 80 | def clog_prob(self): 81 | ... 82 | 83 | def csample(self): 84 | ... 85 | 86 | class BadCT2(BadST2): 87 | def clog_prob(self): 88 | ... 89 | 90 | def csample(self): 91 | ... 92 | 93 | class BadCT3(BadST3): 94 | def clog_prob(self): 95 | ... 96 | 97 | def csample(self): 98 | ... 99 | 100 | class BadCT4(GoodST1): # missing methods 101 | def clog_prob(self): 102 | ... 
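    # Note: the issubclass assertions at the end of this test only hold if
    # StraightThrough, ConditionalStraightThrough, and Density are checked
    # structurally (by the presence of the required methods) rather than by
    # explicit inheritance: e.g., GoodST2 never inherits from StraightThrough,
    # yet it must register as a subclass because it defines rsample,
    # threshold, and tlog_prob on a relaxed distribution.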
103 | 104 | class GoodD1(Density): 105 | pass 106 | 107 | class GoodD2(torch.distributions.Bernoulli): 108 | pass 109 | 110 | class BadD1(object): # missing methods 111 | pass 112 | 113 | assert issubclass(GoodST1, StraightThrough) 114 | assert issubclass(GoodST2, StraightThrough) 115 | assert not issubclass(BadST1, StraightThrough) 116 | assert not issubclass(BadST2, StraightThrough) 117 | assert not issubclass(BadST3, StraightThrough) 118 | assert issubclass(GoodCT1, ConditionalStraightThrough) 119 | assert issubclass(GoodCT2, ConditionalStraightThrough) 120 | assert not issubclass(BadCT1, ConditionalStraightThrough) 121 | assert not issubclass(BadCT2, ConditionalStraightThrough) 122 | assert not issubclass(BadCT3, ConditionalStraightThrough) 123 | assert not issubclass(BadCT4, ConditionalStraightThrough) 124 | assert issubclass(GoodD1, Density) 125 | assert issubclass(GoodD2, Density) 126 | assert not issubclass(BadD1, Density) 127 | 128 | 129 | def test_logistic_bernoulli(device): 130 | N, T = int(1e6), 10 131 | probs = torch.rand(T, device=device) 132 | probs[0] = 0.0 # make sure it doesn't NaN 133 | dist = LogisticBernoulli(probs=probs) 134 | z = dist.rsample([N]) 135 | assert torch.allclose(z.mean(0), dist.mean, atol=1) 136 | assert torch.allclose(z.std(0), dist.stddev, atol=1e-2) 137 | b = dist.threshold(z) 138 | assert torch.allclose(b.mean(0), probs, atol=1e-3) 139 | zz = dist.csample(b) 140 | assert torch.allclose(dist.threshold(zz), b) 141 | # E_b[E_{z|b}[z]] = E_z[z] 142 | assert torch.allclose(zz.mean(0), dist.mean, atol=1) 143 | assert torch.allclose(zz.std(0), dist.stddev, atol=1e-2) 144 | exp_log_prob = dist.log_prob(zz) 145 | act_log_prob = dist.tlog_prob(b) 146 | assert exp_log_prob.shape == act_log_prob.shape 147 | act_log_prob += dist.clog_prob(zz, b) 148 | assert exp_log_prob.shape == act_log_prob.shape 149 | assert torch.allclose(exp_log_prob, act_log_prob), ( 150 | (exp_log_prob - act_log_prob).abs().max() 151 | ) 152 | 153 | 154 | def test_gumbel_one_hot_categorical(device): 155 | N, T, V = int(1e6), 4, 3 156 | probs = torch.rand(T, V, device=device) 157 | probs[0, 0] = 0.0 # make sure it doesn't NaN 158 | probs /= probs.sum(-1, keepdim=True) 159 | dist = GumbelOneHotCategorical(probs=probs) 160 | z = dist.rsample([N]) 161 | assert torch.allclose(z.mean(0), dist.mean, atol=1) 162 | assert torch.allclose(z.std(0), dist.stddev, atol=1e-2) 163 | b = dist.threshold(z) 164 | assert torch.allclose(b.mean(0), probs, atol=1e-3) 165 | zz = dist.csample(b) 166 | assert torch.allclose(dist.threshold(zz), b) 167 | assert torch.allclose(zz.mean(0), dist.mean, atol=1) 168 | assert torch.allclose(zz.std(0), dist.stddev, atol=1e-2) 169 | exp_log_prob = dist.log_prob(zz) 170 | act_log_prob = dist.tlog_prob(b) 171 | assert exp_log_prob.shape == act_log_prob.shape 172 | act_log_prob += dist.clog_prob(zz, b) 173 | assert exp_log_prob.shape == act_log_prob.shape 174 | assert torch.allclose(exp_log_prob, act_log_prob), ( 175 | (exp_log_prob - act_log_prob).abs().max() 176 | ) 177 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py3{7{,-t151,-t170,-t181},8{,-t170,-t181},9{,-t181},10,11} 3 | isolated_build = True 4 | 5 | [gh] 6 | python = 7 | 3.7 = py37-t151 8 | 3.8 = py38-t170 9 | 3.9 = py39-t181 10 | 3.10 = py310 11 | 3.11 = py311 12 | 13 | [testenv] 14 | setenv = 15 | t151: PYTORCH_JIT = 0 16 | install_command = pip install --find-links 
https://download.pytorch.org/whl/cpu/torch_stable.html {opts} {packages} 17 | deps = 18 | pytest 19 | webdataset 20 | pydrobert-speech>=0.2.0 21 | t151: torch==1.5.1 22 | t170: torch==1.7.0 23 | t181: torch==1.8.1 24 | !t151-!t181: torch 25 | !t151-!t181: pytorch-lightning 26 | !t151-!t181: pydrobert-param>=0.4.0 27 | commands = 28 | chunk-torch-spect-data-dir --help 29 | compute-mvn-stats-for-torch-feat-data-dir --help 30 | compute-torch-token-data-dir-error-rates --help 31 | ctm-to-torch-token-data-dir --help 32 | get-torch-spect-data-dir-info --help 33 | print-torch-ali-data-dir-length-moments --help 34 | print-torch-ref-data-dir-length-moments --help 35 | subset-torch-spect-data-dir --help 36 | textgrids-to-torch-token-data-dir --help 37 | torch-ali-data-dir-to-torch-token-data-dir --help 38 | torch-spect-data-dir-to-wds --help 39 | torch-token-data-dir-to-ctm --help 40 | torch-token-data-dir-to-textgrids --help 41 | torch-token-data-dir-to-torch-ali-data-dir --help 42 | torch-token-data-dir-to-trn --help 43 | trn-to-torch-token-data-dir --help 44 | !t151-!t170: python -c 'import pydrobert.torch.config as c; c.USE_JIT=True; from pydrobert.torch import *' 45 | !t151-!t170: pytest --basetemp="{envtmpdir}" {posargs} 46 | t151: pytest --basetemp="{envtmpdir}" -m 'not trace and not script' {posargs} 47 | t170: pytest --basetemp="{envtmpdir}" -m 'not trace and not nojit' {posargs} 48 | --------------------------------------------------------------------------------