├── .github └── workflows │ └── publish-to-pypi.yaml ├── CMakeLists.txt ├── LICENSE ├── MANIFEST.in ├── README.md ├── docs └── images │ └── texterrors_example.png ├── libs ├── stringvector.cc ├── stringvector.h └── texterrors_align.cc ├── pytest.ini ├── requirements.txt ├── setup.py ├── tests ├── hyptext ├── reftext └── test_functions.py └── texterrors ├── __init__.py └── texterrors.py /.github/workflows/publish-to-pypi.yaml: -------------------------------------------------------------------------------- 1 | name: Publish Python Package to PyPi 2 | on: push 3 | 4 | jobs: 5 | build_wheels: 6 | name: Build wheels on ${{ matrix.os }} 7 | runs-on: ${{ matrix.os }} 8 | if: startsWith(github.ref, 'refs/tags') 9 | strategy: 10 | matrix: 11 | os: [ubuntu-latest, ubuntu-24.04, macos-13, macos-14, windows-latest, macos-latest] 12 | 13 | steps: 14 | - uses: actions/checkout@v4 15 | 16 | - uses: actions/setup-python@v5 17 | 18 | - name: Install cibuildwheel 19 | run: python -m pip install cibuildwheel==2.19.1 20 | 21 | - name: Build wheels 22 | run: python -m cibuildwheel --output-dir wheelhouse 23 | env: 24 | CIBW_SKIP: "pp3*" 25 | 26 | - uses: actions/upload-artifact@v4 27 | with: 28 | name: artifact-cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }} 29 | path: ./wheelhouse/*.whl 30 | 31 | build_sdist: 32 | name: Build source distribution 33 | runs-on: ubuntu-latest 34 | if: startsWith(github.ref, 'refs/tags') 35 | steps: 36 | - uses: actions/checkout@v4 37 | 38 | - name: Build sdist 39 | run: pipx run build --sdist 40 | 41 | - uses: actions/upload-artifact@v4 42 | with: 43 | name: artifact-sdist 44 | path: dist/*.tar.gz 45 | 46 | test-wheels: 47 | name: Test wheels on ${{ matrix.os }} 48 | needs: [build_wheels] 49 | runs-on: ${{ matrix.os }} 50 | if: startsWith(github.ref, 'refs/tags') 51 | strategy: 52 | matrix: 53 | python-version: ["3.10", "3.12"] 54 | os: [ubuntu-latest, ubuntu-24.04, macos-13, macos-14, windows-latest, macos-latest] 55 | 56 | steps: 57 | - uses: actions/checkout@v4 58 | with: 59 | submodules: recursive 60 | 61 | - name: Set up Python ${{ matrix.python-version }} 62 | uses: actions/setup-python@v5 63 | with: 64 | python-version: ${{ matrix.python-version }} 65 | 66 | - name: Download Python wheels 67 | uses: actions/download-artifact@v4 68 | with: 69 | pattern: artifact-cibw-wheels-${{ matrix.os }}-* 70 | merge-multiple: true 71 | path: ./wheels 72 | 73 | - name: Install wheel 74 | shell: bash 75 | run: | 76 | python -m pip install texterrors --find-links ./wheels/ 77 | 78 | - name: Run tests 79 | shell: bash 80 | run: | 81 | pytest -v . 82 | 83 | upload_pypi: 84 | needs: [build_wheels, build_sdist, test-wheels] 85 | runs-on: ubuntu-latest 86 | if: startsWith(github.ref, 'refs/tags') 87 | steps: 88 | - uses: actions/download-artifact@v4 89 | with: 90 | pattern: artifact-* 91 | path: dist 92 | merge-multiple: true 93 | 94 | - uses: pypa/gh-action-pypi-publish@v1.5.0 95 | with: 96 | user: __token__ 97 | password: ${{ secrets.PYPI_API_TOKEN }} -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # For this to work there needs to be a symlink to pybind11 in the libs directory! 
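# A minimal sketch of creating that symlink (clone location is an assumption, adjust as needed):
#   git clone https://github.com/pybind/pybind11.git /opt/pybind11
#   ln -s /opt/pybind11 libs/pybind11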
2 | cmake_minimum_required(VERSION 3.9.5) 3 | SET(CMAKE_CXX_STANDARD 17) 4 | SET(CMAKE_CXX_FLAGS "-O3 ") 5 | 6 | project(condutor) 7 | 8 | include_directories("libs/") 9 | 10 | add_subdirectory(libs/pybind11) 11 | pybind11_add_module(texterrors_align libs/stringvector.cc libs/texterrors_align.cc) 12 | 13 | set_target_properties(texterrors_align PROPERTIES LIBRARY_OUTPUT_NAME "texterrors_align") 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE README.md requirements.txt 2 | recursive-include libs *.* 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # texterrors 3 | 4 | For calculating WER, CER, other metrics, getting detailed statistics and comparing outputs. 

Meant to replace older tools like `sclite` by being easy to use, modify and extend.

Features:
- Character-aware, standard (default) and CTM-based alignment
- Metrics by group (for example, per speaker)
- Comparing two hypothesis files to a reference
- Oracle WER
- Sorting the most common errors by frequency or count
- Measuring performance on keywords
- Measuring OOV-CER (see [https://arxiv.org/abs/2107.08091](https://arxiv.org/abs/2107.08091))
- Colored output to inspect errors

Example of colored output below (use the `-c` flag). Read the white and green words to read the reference; read the white and red words to read the hypothesis.

![Example](docs/images/texterrors_example.png)

See here for [background motivation](https://ruabraun.github.io/jekyll/update/2020/11/27/On-word-error-rates.html).


# Installing
Requires Python 3.6 or newer!
```
pip install texterrors
```
The package will be installed as `texterrors` and there will be a `texterrors` script in your path.

# Example

The `-s` option suppresses detailed output. Below, `ref` and `hyp` are files whose first field is the utterance ID (hence the `-isark` flag).
```
$ texterrors -isark -s ref hyp
WER: 83.33 (ins 1, del 1, sub 3 / 6)
```

You can specify an output file to save the results, which you will probably want when producing detailed output (not using `-s`).
Here we also calculate the CER and the OOV-CER (to measure performance on the OOV words inside the `oov_list` file), and use
colored output (hence the `-c` flag).
```
$ texterrors -c -isark -cer -oov-list-f oov_list ref hyp detailed_wer_output
```
**Use `less -R` to view the colored output. Omit the `-c` flag to disable color.**

Check `texterrors/__init__.py` to see the functions you may be interested in using from Python.

# Options you might want to use
Call `texterrors -h` to see all options.

`-cer`, `-isctm` - Calculate CER; use CTMs for alignment

`-utt-group-map` - A file which maps utterance IDs to groups; WER will be output per group (use this
to get per-speaker WER, for example).

`-second-hyp-f` - Compare the outputs of two different models to the reference.

`-freq-sort` - Sort errors by frequency rather than count.

`-oov-list-f` - The CER between words aligned to the OOV words will be calculated (the OOV-CER).

`-keywords-list-f` - Calculate precision and recall for the words in the file.

`-oracle-wer` - The hypothesis file should have multiple entries for each utterance; the oracle WER will be calculated.

# Why is the WER slightly higher than in kaldi if I use `-use_chardiff`?

**You can make it equal by not using the `-use_chardiff` argument.**

The difference exists because this tool can do character-aware alignment. Across a normal-sized test set the difference should be small.

In the example below, a standard WER calculation does a one-to-one mapping and arrives at a WER of 66.67%.

| test | sentence | okay | words | ending | now |
|------|----------|------|-------|--------|-----|
| test | a | sentenc | ok | endin | now |

But character-aware alignment would result in the following alignment:

| test | - | sentence | okay | words | ending | now |
|------|---|----------|------|-------|--------|-----|
| test | a | sentenc | ok | - | endin | now |

This results in a WER of 83.3% because of the extra insertion and deletion. One could argue that this is actually the more correct WER.

# Changelog

Recent changes:

- 26.02.25 Faster alignment, better multi-hypothesis support, fixed a multi-hypothesis bug.
- 22.06.22 Refactored internals to make them simpler; character-aware alignment is off by default; added more explanations.
- 20.05.22 Fixed bug: missing regex dependency.
- 16.05.22 Fixed bug causing wrong detailed output when an utterance has an empty reference; utterances with an empty reference are no longer ignored.
- 21.04.22 Insertion errors go on the lower line, and switched colors so green is the reference.
- 27.01.22 Oracle WER and small bug fixes.
- 26.01.22 Fixed bug causing the OOV-CER feature to not work.
- 22.11.21 New feature to compare two outputs to a reference; lots of small changes.
- 04.10.21 Fixed a bug, nocolor option, refactoring, keywords feature works properly, updated README.
- 22.08.21 Added oracle WER feature; cost matrix creation returns the cost now.
- 16.07.21 Improved alignment based on CTMs (much stricter now).

TODO: use nanobind
--------------------------------------------------------------------------------
/docs/images/texterrors_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RuABraun/texterrors/92a17bfad48d10e9a5eddb76161b22363ae281ff/docs/images/texterrors_example.png
--------------------------------------------------------------------------------
/libs/stringvector.cc:
--------------------------------------------------------------------------------
#include "stringvector.h"


StringVector::StringVector(const py::list& words) {
  int total_length = 0;
  for (py::handle obj : words) {
    const std::string word = obj.cast<std::string>();
    total_length += word.size();
    wordend_index_.push_back(total_length);
  }
  data_.resize(total_length);
  int start_index = 0;
  for (py::handle obj : words) {
    const std::string word = obj.cast<std::string>();
    std::copy(word.begin(), word.end(), data_.begin() + start_index);
    start_index += word.size();
  }
  current_index_ = 0;
}

StringVector::StringVector(const vector<std::string>& words) {
  int total_length = 0;
  for (const std::string& word : words) {
    total_length += word.size();
    wordend_index_.push_back(total_length);
  }
  // Bug fix: this constructor only recorded the word-end offsets and never
  // copied the characters, so operator[] would return views into an empty data_.
  data_.resize(total_length);
  int start_index = 0;
  for (const std::string& word : words) {
    std::copy(word.begin(), word.end(), data_.begin() + start_index);
    start_index += word.size();
  }
  current_index_ = 0;
}

const int StringVector::size() const {
  return wordend_index_.size();
}

const std::string_view StringVector::operator[](const int i) const {
  if (i < 0 || i >= size()) {
    throw std::runtime_error("Invalid index");
  }
  int start_index = 0;
  if (i > 0) {
    start_index = wordend_index_[i-1];
  }
  int length = wordend_index_[i] - start_index;
  return std::string_view(data_).substr(start_index, length);
}

StringVector StringVector::iter() {
  current_index_ = 0;
  return *this;
}

const std::string_view StringVector::next() {
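  // Together with iter(), this implements Python's iterator protocol:
  // throwing py::stop_iteration below is how the end of a for-loop is signalled.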
if (current_index_ == size()) { 53 | throw pybind11::stop_iteration(); 54 | } 55 | return (*this)[current_index_++]; 56 | } 57 | 58 | std::string StringVector::Str() const { 59 | std::string repr = ""; 60 | for (int i = 0; i < size(); i++) { 61 | repr += std::string{(*this)[i]} + " "; 62 | } 63 | return repr; 64 | } 65 | 66 | StringVector::~StringVector() {} 67 | 68 | 69 | void init_stringvector(py::module &m) { 70 | py::class_(m, "StringVector") 71 | .def(py::init()) 72 | .def("size", &StringVector::size) 73 | .def("__len__", &StringVector::size) 74 | .def("__getitem__", &StringVector::operator[]) 75 | .def("__iter__", &StringVector::iter) 76 | .def("__next__", &StringVector::next) 77 | .def("__str__", &StringVector::Str); 78 | } -------------------------------------------------------------------------------- /libs/stringvector.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace py = pybind11; 9 | 10 | using namespace std; 11 | 12 | 13 | class StringVector { 14 | public: 15 | StringVector(const py::list& words); 16 | StringVector(const vector& words); 17 | ~StringVector(); 18 | 19 | const int size() const; 20 | const std::string_view operator[](const int i) const; 21 | StringVector iter(); 22 | const std::string_view next(); 23 | std::string Str() const; 24 | 25 | std::string data_; 26 | std::vector wordend_index_; 27 | int current_index_; 28 | }; -------------------------------------------------------------------------------- /libs/texterrors_align.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include "stringvector.h" 9 | 10 | 11 | namespace py = pybind11; 12 | 13 | typedef int32_t int32; 14 | 15 | 16 | bool isclose(double a, double b) { 17 | return abs(a - b) < 0.0001; 18 | } 19 | 20 | struct Pair { 21 | Pair() {} 22 | Pair(int16_t f_, int16_t s_) { 23 | i = f_; 24 | j = s_; 25 | } 26 | int16_t i; 27 | int16_t j; 28 | }; 29 | 30 | 31 | int calc_edit_distance_fast(int32* cost_mat, const char* a, const char* b, 32 | const int32 M, const int32 N) { 33 | int row_length = N+1; 34 | // std::cout << "STARTING M="<< M<< " N="< 0) { 61 | cost_mat[N] = cost_mat[row_length + N]; 62 | } 63 | 64 | // std::cout << "row "< 81 | void create_lev_cost_mat(int32* cost_mat, const T* a, const T* b, 82 | const int32 M, const int32 N) { 83 | int row_length = N+1; 84 | for (int32 i = 0; i <= M; ++i) { 85 | for (int32 j = 0; j <= N; ++j) { 86 | 87 | if (i == 0 && j == 0) { 88 | cost_mat[0] = 0; 89 | continue; 90 | } 91 | if (i == 0) { 92 | int new_value = cost_mat[j - 1] + 3; 93 | cost_mat[j] = new_value; 94 | continue; 95 | } 96 | if (j == 0) { 97 | int new_value = cost_mat[(i-1) * row_length] + 3; 98 | cost_mat[i * row_length] = new_value; 99 | continue; 100 | } 101 | int32 transition_cost = a[i-1] == b[j-1] ? 
0 : 1; 102 | 103 | int32 upc = cost_mat[(i-1) * row_length + j] + 3; 104 | int32 leftc = cost_mat[i * row_length + j - 1] + 3; 105 | int32 diagc = cost_mat[(i-1) * row_length + j - 1] + 4 * transition_cost; 106 | int32 cost = std::min(upc, std::min(leftc, diagc) ); 107 | cost_mat[i * row_length + j] = cost; 108 | } 109 | } 110 | } 111 | 112 | template 113 | int levdistance(const T* a, const T* b, int32 M, int32 N) { 114 | if (!M) return N; 115 | if (!N) return M; 116 | std::vector cost_mat((M+1)*(N+1)); 117 | create_lev_cost_mat(cost_mat.data(), a, b, M, N); 118 | int cost = 0; 119 | int i = M, j = N; 120 | int row_length = N+1; 121 | while (i != 0 || j != 0) { 122 | if (i == 0) { 123 | j--; 124 | cost++; 125 | } else if (j == 0) { 126 | i--; 127 | cost++; 128 | } else { 129 | int current_cost = cost_mat[i * row_length + j]; 130 | int diagc = cost_mat[(i-1) * row_length + j - 1]; 131 | int upc = cost_mat[(i-1) * row_length + j]; 132 | int leftc = cost_mat[i * row_length + j - 1]; 133 | int32 transition_cost = a[i-1] == b[j-1] ? 0 : 1; 134 | if (diagc + 4 * transition_cost == current_cost) { 135 | i--, j--; 136 | if (current_cost != diagc) cost++; 137 | } else if (upc + 3 == current_cost) { 138 | i--; 139 | cost++; 140 | } else if (leftc + 3 == current_cost) { 141 | j--; 142 | cost++; 143 | } else { 144 | std::cerr < 154 | int lev_distance(std::vector a, std::vector b) { 155 | return levdistance(a.data(), b.data(), a.size(), b.size()); 156 | } 157 | 158 | int lev_distance_str(std::string a, std::string b) { 159 | return levdistance(a.data(), b.data(), a.size(), b.size()); 160 | } 161 | 162 | int calc_edit_distance_fast_str(std::string a, std::string b) { 163 | std::vector buffer(a.size() + b.size() + 2); 164 | return calc_edit_distance_fast(buffer.data(), a.data(), b.data(), a.size(), b.size()); 165 | } 166 | 167 | enum direction{diag, move_left, up}; 168 | 169 | std::vector > get_best_path(py::array_t array, 170 | const StringVector& words_a, 171 | const StringVector& words_b, const bool use_chardiff, const bool use_fast_edit_distance=true) { 172 | auto buf = array.request(); 173 | double* cost_mat = (double*) buf.ptr; 174 | int32_t numr = array.shape()[0], numc = array.shape()[1]; 175 | std::vector char_dist_buffer; 176 | if (use_chardiff) { 177 | char_dist_buffer.resize(100); 178 | } 179 | 180 | std::vector > bestpath; 181 | int i = numr - 1, j = numc - 1; 182 | while (i != 0 || j != 0) { 183 | double upc, leftc, diagc; 184 | direction direc; 185 | if (i == 0) { 186 | direc = move_left; 187 | } else if (j == 0) { 188 | direc = up; 189 | } else { 190 | float current_cost = cost_mat[i * numc + j]; 191 | upc = cost_mat[(i-1) * numc + j]; 192 | leftc = cost_mat[i * numc + j - 1]; 193 | diagc = cost_mat[(i-1) * numc + j - 1]; 194 | const std::string_view a = words_a[i-1]; 195 | const std::string_view b = words_b[j-1]; 196 | double up_trans_cost = 1.0; 197 | double left_trans_cost = 1.0; 198 | double diag_trans_cost; 199 | if (use_chardiff) { 200 | int alen = a.size(); 201 | int blen = b.size(); 202 | if (alen >= 50 || blen >= 50) { 203 | throw std::runtime_error("Word is too long! 
Increase buffer"); 204 | } 205 | if (use_fast_edit_distance) { 206 | diag_trans_cost = 207 | calc_edit_distance_fast(char_dist_buffer.data(), a.data(), b.data(), a.size(), b.size()) / static_cast(std::max(a.size(), b.size())) * 1.5; 208 | } else { 209 | diag_trans_cost = 210 | levdistance(a.data(), b.data(), a.size(), b.size()) / static_cast(std::max(a.size(), b.size())) * 1.5; 211 | } 212 | } else { 213 | diag_trans_cost = a == b ? 0. : 1.; 214 | } 215 | 216 | if (isclose(diagc + diag_trans_cost, current_cost)) { 217 | direc = diag; 218 | } else if (isclose(upc + up_trans_cost, current_cost)) { 219 | direc = up; 220 | } else if (isclose(leftc + left_trans_cost, current_cost)) { 221 | direc = move_left; 222 | } else { 223 | std::cout << a <<" "< > get_best_path_lists(py::array_t array, 246 | const std::vector& words_a, 247 | const std::vector& words_b, const bool use_chardiff, const bool use_fast_edit_distance=true) { 248 | auto buf = array.request(); 249 | double* cost_mat = (double*) buf.ptr; 250 | int32_t numr = array.shape()[0], numc = array.shape()[1]; 251 | std::vector char_dist_buffer; 252 | if (use_chardiff) { 253 | char_dist_buffer.resize(100); 254 | } 255 | 256 | std::vector > bestpath; 257 | int i = numr - 1, j = numc - 1; 258 | while (i != 0 || j != 0) { 259 | double upc, leftc, diagc; 260 | direction direc; 261 | if (i == 0) { 262 | direc = move_left; 263 | } else if (j == 0) { 264 | direc = up; 265 | } else { 266 | float current_cost = cost_mat[i * numc + j]; 267 | upc = cost_mat[(i-1) * numc + j]; 268 | leftc = cost_mat[i * numc + j - 1]; 269 | diagc = cost_mat[(i-1) * numc + j - 1]; 270 | const std::string& a = words_a[i-1]; 271 | const std::string& b = words_b[j-1]; 272 | double up_trans_cost = 1.0; 273 | double left_trans_cost = 1.0; 274 | double diag_trans_cost; 275 | if (use_chardiff) { 276 | int alen = a.size(); 277 | int blen = b.size(); 278 | if (alen >= 50 || blen >= 50) { 279 | throw std::runtime_error("Word is too long! Increase buffer"); 280 | } 281 | if (use_fast_edit_distance) { 282 | diag_trans_cost = 283 | calc_edit_distance_fast(char_dist_buffer.data(), a.data(), b.data(), a.size(), b.size()) / static_cast(std::max(a.size(), b.size())) * 1.5; 284 | } else { 285 | diag_trans_cost = 286 | levdistance(a.data(), b.data(), a.size(), b.size()) / static_cast(std::max(a.size(), b.size())) * 1.5; 287 | } 288 | } else { 289 | diag_trans_cost = a == b ? 0. 
: 1.; 290 | } 291 | 292 | if (isclose(diagc + diag_trans_cost, current_cost)) { 293 | direc = diag; 294 | } else if (isclose(upc + up_trans_cost, current_cost)) { 295 | direc = up; 296 | } else if (isclose(leftc + left_trans_cost, current_cost)) { 297 | direc = move_left; 298 | } else { 299 | std::cout << a <<" "< array, py::list& bestpath_lst, std::vector texta, 321 | std::vector textb, std::vector times_a, std::vector times_b, 322 | std::vector durs_a, std::vector durs_b) { 323 | auto buf = array.request(); 324 | double* cost_mat = (double*) buf.ptr; 325 | int32_t numr = array.shape()[0], numc = array.shape()[1]; 326 | 327 | if (numr > 32000 || numc > 32000) throw std::runtime_error("Input sequences are too large!"); 328 | 329 | std::vector bestpath; 330 | int i = numr - 1, j = numc - 1; 331 | bestpath.emplace_back(i, j); 332 | while (i != 0 || j != 0) { 333 | double upc, leftc, diagc; 334 | int idx; // 0 up, 1 left, 2 diagonal 335 | if (i == 0) { 336 | idx = 1; 337 | } else if (j == 0) { 338 | idx = 0; 339 | } else { 340 | float current_cost = cost_mat[i * numc + j]; 341 | upc = cost_mat[(i-1) * numc + j]; 342 | leftc = cost_mat[i * numc + j - 1]; 343 | diagc = cost_mat[(i-1) * numc + j - 1]; 344 | 345 | double time_cost; 346 | if (i == 0 || j == 0) { 347 | time_cost = 0.; 348 | } else { 349 | double start_a = times_a[i - 1]; 350 | double start_b = times_b[j - 1]; 351 | double end_a = start_a + durs_a[i - 1]; 352 | double end_b = start_b + durs_b[j - 1]; 353 | double overlap; 354 | if (start_a > end_b) { 355 | overlap = end_b - start_a; 356 | } else if (start_b > end_a) { 357 | overlap = end_a - start_b; 358 | } else if (start_a > start_b) { 359 | double min_end = std::min(end_a, end_b); 360 | overlap = min_end - start_a; 361 | } else { 362 | double min_end = std::min(end_a, end_b); 363 | overlap = min_end - start_b; 364 | } 365 | time_cost = -overlap; 366 | } 367 | 368 | double up_trans_cost = 1. + time_cost; 369 | double left_trans_cost = 1. + time_cost; 370 | double diag_trans_cost = texta[i] == textb[j] ? 0. + time_cost : 1. + time_cost; 371 | 372 | if (isclose(upc + up_trans_cost, current_cost)) { 373 | idx = 0; 374 | } else if (isclose(leftc + left_trans_cost, current_cost)) { 375 | idx = 1; 376 | } else if (isclose(diagc + diag_trans_cost, current_cost)) { 377 | idx = 2; 378 | } else { 379 | std::cout << texta[i] <<" "< array, const StringVector& words_a, 406 | const StringVector& words_b, const bool use_chardist, const bool use_fast_edit_distance=true) { 407 | if ( array.ndim() != 2 ) 408 | throw std::runtime_error("Input should be 2-D NumPy array"); 409 | 410 | int M1 = array.shape()[0], N1 = array.shape()[1]; 411 | if (M1 - 1 != words_a.size() || N1 - 1 != words_b.size()) throw std::runtime_error("Sizes do not match!"); 412 | auto buf = array.request(); 413 | double* ptr = (double*) buf.ptr; 414 | 415 | std::vector char_dist_buffer; 416 | if (use_chardist) { 417 | char_dist_buffer.resize(100); 418 | } 419 | 420 | ptr[0] = 0; 421 | for (int32 i = 1; i < M1; i++) ptr[i*N1] = ptr[(i-1)*N1] + 1; 422 | for (int32 j = 1; j < N1; j++) ptr[j] = ptr[j-1] + 1; 423 | for(int32 i = 1; i < M1; i++) { 424 | for(int32 j = 1; j < N1; j++) { 425 | double transition_cost; 426 | if (use_chardist) { 427 | const std::string_view a = words_a[i-1]; 428 | const std::string_view b = words_b[j-1]; 429 | int alen = a.size(); 430 | int blen = b.size(); 431 | if (alen >= 50 || blen >= 50) { 432 | throw std::runtime_error("Word is too long! 
Increase buffer"); 433 | } 434 | if (use_fast_edit_distance) { 435 | transition_cost = 436 | calc_edit_distance_fast(char_dist_buffer.data(), a.data(), b.data(), a.size(), b.size()) / static_cast(std::max(a.size(), b.size())) * 1.5; 437 | } else { 438 | transition_cost = 439 | levdistance(a.data(), b.data(), a.size(), b.size()) / static_cast(std::max(a.size(), b.size())) * 1.5; 440 | } 441 | } else { 442 | transition_cost = words_a[i-1] == words_b[j-1] ? 0. : 1.; 443 | } 444 | 445 | double upc = ptr[(i-1) * N1 + j] + 1.; 446 | double leftc = ptr[i * N1 + j - 1] + 1.; 447 | double diagc = ptr[(i-1) * N1 + j - 1] + transition_cost; 448 | double sum = std::min(upc, std::min(leftc, diagc)); 449 | ptr[i * N1 + j] = sum; 450 | } 451 | } 452 | return ptr[M1*N1 - 1]; 453 | } 454 | 455 | 456 | 457 | int calc_sum_cost_lists(py::array_t array, const std::vector& words_a, 458 | const std::vector& words_b, const bool use_chardist, const bool use_fast_edit_distance=true) { 459 | if ( array.ndim() != 2 ) 460 | throw std::runtime_error("Input should be 2-D NumPy array"); 461 | 462 | int M1 = array.shape()[0], N1 = array.shape()[1]; 463 | if (M1 - 1 != words_a.size() || N1 - 1 != words_b.size()) throw std::runtime_error("Sizes do not match!"); 464 | auto buf = array.request(); 465 | double* ptr = (double*) buf.ptr; 466 | 467 | std::vector char_dist_buffer; 468 | if (use_chardist) { 469 | char_dist_buffer.resize(100); 470 | } 471 | 472 | ptr[0] = 0; 473 | for (int32 i = 1; i < M1; i++) ptr[i*N1] = ptr[(i-1)*N1] + 1; 474 | for (int32 j = 1; j < N1; j++) ptr[j] = ptr[j-1] + 1; 475 | for(int32 i = 1; i < M1; i++) { 476 | for(int32 j = 1; j < N1; j++) { 477 | double transition_cost; 478 | if (use_chardist) { 479 | const std::string& a = words_a[i-1]; 480 | const std::string& b = words_b[j-1]; 481 | int alen = a.size(); 482 | int blen = b.size(); 483 | if (alen >= 50 || blen >= 50) { 484 | throw std::runtime_error("Word is too long! Increase buffer"); 485 | } 486 | if (use_fast_edit_distance) { 487 | transition_cost = 488 | calc_edit_distance_fast(char_dist_buffer.data(), a.data(), b.data(), a.size(), b.size()) / static_cast(std::max(a.size(), b.size())) * 1.5; 489 | } else { 490 | transition_cost = 491 | levdistance(a.data(), b.data(), a.size(), b.size()) / static_cast(std::max(a.size(), b.size())) * 1.5; 492 | } 493 | } else { 494 | transition_cost = words_a[i-1] == words_b[j-1] ? 0. 
: 1.; 495 | } 496 | 497 | double upc = ptr[(i-1) * N1 + j] + 1.; 498 | double leftc = ptr[i * N1 + j - 1] + 1.; 499 | double diagc = ptr[(i-1) * N1 + j - 1] + transition_cost; 500 | double sum = std::min(upc, std::min(leftc, diagc)); 501 | ptr[i * N1 + j] = sum; 502 | } 503 | } 504 | return ptr[M1*N1 - 1]; 505 | } 506 | 507 | 508 | int calc_sum_cost_ctm(py::array_t array, std::vector& texta, 509 | std::vector& textb, std::vector times_a, std::vector times_b, 510 | std::vector durs_a, std::vector durs_b) { 511 | if ( array.ndim() != 2 ) 512 | throw std::runtime_error("Input should be 2-D NumPy array"); 513 | 514 | int M = array.shape()[0], N = array.shape()[1]; 515 | if (M != texta.size() || N != textb.size()) throw std::runtime_error(" s do not match!"); 516 | auto buf = array.request(); 517 | double* ptr = (double*) buf.ptr; 518 | // std::cout << "STARTING"< end_b) { 532 | overlap = end_b - start_a; 533 | } else if (start_b > end_a) { 534 | overlap = end_a - start_b; 535 | } else if (start_a > start_b) { 536 | double min_end = std::min(end_a, end_b); 537 | overlap = min_end - start_a; 538 | } else { 539 | double min_end = std::min(end_a, end_b); 540 | overlap = min_end - start_b; 541 | } 542 | time_cost = -overlap; 543 | } 544 | 545 | a_cost = 1. + time_cost; 546 | b_cost = 1. + time_cost; 547 | transition_cost = texta[i] == textb[j] ? 0. + time_cost : 1. + time_cost; 548 | 549 | if (i == 0 && j == 0) { 550 | ptr[0] = 0; 551 | continue; 552 | } 553 | if (i == 0) { 554 | ptr[j] = ptr[j - 1] + b_cost; 555 | continue; 556 | } 557 | if (j == 0) { 558 | ptr[i * N] = ptr[(i-1) * N] + a_cost; 559 | continue; 560 | } 561 | 562 | double upc = ptr[(i-1) * N + j] + a_cost; 563 | double leftc = ptr[i * N + j - 1] + b_cost; 564 | double diagc = ptr[(i-1) * N + j - 1] + transition_cost; 565 | double sum = std::min(upc, std::min(leftc, diagc)); 566 | ptr[i * N + j] = sum; 567 | } 568 | } 569 | return ptr[(M-1) * N + N - 1]; 570 | } 571 | 572 | 573 | void init_stringvector(py::module_ &m); 574 | 575 | 576 | PYBIND11_MODULE(texterrors_align,m) { 577 | m.doc() = "pybind11 plugin"; 578 | m.def("calc_sum_cost", &calc_sum_cost, "Calculate summed cost matrix"); 579 | m.def("calc_sum_cost_lists", &calc_sum_cost_lists, "Calculate summed cost matrix"); 580 | m.def("calc_sum_cost_ctm", &calc_sum_cost_ctm, "Calculate summed cost matrix"); 581 | m.def("get_best_path", &get_best_path, "get_best_path"); 582 | m.def("get_best_path_ctm", &get_best_path_ctm, "get_best_path_ctm"); 583 | m.def("get_best_path_lists", &get_best_path_lists, "get_best_path_lists"); 584 | m.def("lev_distance", lev_distance); 585 | m.def("lev_distance", lev_distance); 586 | m.def("lev_distance_str", &lev_distance_str); 587 | m.def("calc_edit_distance_fast_str", &calc_edit_distance_fast_str); 588 | init_stringvector(m); 589 | } 590 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | norecursedirs = libs/* 3 | log_cli = 1 4 | log_cli_level = INFO -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pybind11 2 | plac 3 | numpy 4 | loguru 5 | termcolor 6 | Levenshtein 7 | regex 8 | pytest 9 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from 
setuptools import setup, Extension 2 | from setuptools.command.build_ext import build_ext 3 | import os 4 | import setuptools 5 | import sys 6 | 7 | __version__ = "1.0.9" 8 | 9 | 10 | class get_pybind_include(object): 11 | """Helper class to determine the pybind11 include path 12 | The purpose of this class is to postpone importing pybind11 13 | until it is actually installed, so that the ``get_include()`` 14 | method can be invoked. """ 15 | 16 | def __init__(self, user=False): 17 | self.user = user 18 | 19 | def __str__(self): 20 | import pybind11 21 | return pybind11.get_include(self.user) 22 | 23 | 24 | ext_modules = [ 25 | Extension( 26 | "texterrors_align", 27 | ["libs/texterrors_align.cc", "libs/stringvector.cc"], 28 | include_dirs=[ 29 | # Path to pybind11 headers 30 | get_pybind_include(), 31 | get_pybind_include(user=True), 32 | "libs/", 33 | ], 34 | language="c++", 35 | ) 36 | ] 37 | 38 | 39 | def has_flag(compiler, flagname): 40 | """Return a boolean indicating whether a flag name is supported on 41 | the specified compiler. 42 | """ 43 | import tempfile 44 | 45 | with tempfile.NamedTemporaryFile("w", suffix=".cpp") as f: 46 | f.write("int main (int argc, char **argv) { return 0; }") 47 | try: 48 | compiler.compile([f.name], extra_postargs=[flagname]) 49 | except setuptools.distutils.errors.CompileError: 50 | return False 51 | return True 52 | 53 | class BuildExt(build_ext): 54 | """A custom build extension for adding compiler-specific options.""" 55 | 56 | c_opts = {"msvc": ["/EHsc"], "unix": []} 57 | 58 | if sys.platform == "darwin": 59 | c_opts["unix"] += ["-stdlib=libc++", "-mmacosx-version-min=10.7"] 60 | 61 | def build_extensions(self): 62 | ct = self.compiler.compiler_type 63 | opts = self.c_opts.get(ct, []) 64 | if ct == "unix": 65 | opts.append('-DVERSION_INFO="%s"' % self.distribution.get_version()) 66 | opts.append('-std=c++17') 67 | if has_flag(self.compiler, "-fvisibility=hidden"): 68 | opts.append("-fvisibility=hidden") 69 | elif ct == "msvc": 70 | opts.append('/DVERSION_INFO=\\"%s\\"' % self.distribution.get_version()) 71 | opts.append('/std:c++17') 72 | for ext in self.extensions: 73 | ext.extra_compile_args = opts 74 | build_ext.build_extensions(self) 75 | 76 | 77 | base_dir = os.path.dirname(os.path.realpath(__file__)) 78 | def get_requires(): 79 | req_path = os.path.join(base_dir, 'requirements.txt') 80 | install_requires = open(req_path).read().splitlines() 81 | return install_requires 82 | 83 | 84 | with open(os.path.join(base_dir, "README.md")) as fh: 85 | long_description = fh.read() 86 | 87 | 88 | setup( 89 | name="texterrors", 90 | version=__version__, 91 | author="Rudolf A Braun", 92 | author_email="rab014@gmail.com", 93 | packages=["texterrors"], 94 | license="Apache-2.0 License", 95 | url='https://github.com/RuABraun/texterrors', 96 | description="For WER", 97 | long_description=long_description, 98 | long_description_content_type="text/markdown", 99 | ext_modules=ext_modules, 100 | cmdclass={"build_ext": BuildExt}, 101 | entry_points={'console_scripts': ['texterrors=texterrors.texterrors:cli']}, 102 | install_requires=get_requires(), 103 | setup_requires=['pybind11'], 104 | python_requires='>=3.6' 105 | ) 106 | -------------------------------------------------------------------------------- /tests/test_functions.py: -------------------------------------------------------------------------------- 1 | """ Run command: PYTHONPATH=. pytest . 
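Run a single test (the test name here is just an example): PYTHONPATH=. pytest tests/test_functions.py -k test_wer -v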
2 | """ 3 | import os 4 | import io 5 | import time 6 | import sys 7 | from loguru import logger 8 | 9 | import Levenshtein as levd 10 | from texterrors import texterrors 11 | from texterrors.texterrors import StringVector 12 | from dataclasses import dataclass 13 | import difflib 14 | from loguru import logger 15 | 16 | logger.remove() 17 | logger.add(sys.stderr, level="INFO") 18 | 19 | def show_diff(text1, text2): 20 | # Split the strings into lines to compare them line by line 21 | lines1 = text1.splitlines() 22 | lines2 = text2.splitlines() 23 | 24 | # Create a Differ object and calculate the differences 25 | differ = difflib.Differ() 26 | diff = list(differ.compare(lines1, lines2)) 27 | 28 | # Optionally, you can filter out lines that haven't changed 29 | diff = [line for line in diff if line[0] != ' '] 30 | 31 | # Join the result back into a single string and return it 32 | return '\n'.join(diff) 33 | 34 | 35 | def test_levd(): 36 | pairs = ['a', '', '', 'a', 'MOZILLA', 'MUSIAL', 'ARE', 'MOZILLA', 'TURNIPS', 'TENTH', 'POSTERS', 'POSTURE'] 37 | for a, b in zip(pairs[:-1:2], pairs[1::2]): 38 | d1 = texterrors.lev_distance(a, b) 39 | d2 = levd.distance(a, b) 40 | assert d1 == d2, f'{a} {b} {d1} {d2}' 41 | 42 | 43 | # def test_calc_edit_distance_fast(): 44 | # pairs = ['a', '', '', 'a', 'MOZILLA', 'MUSIAL', 'ARE', 'MOZILLA', 'TURNIPS', 'TENTH', 'POSTERS', 'POSTURE'] 45 | # for a, b in zip(pairs[:-1:2], pairs[1::2]): 46 | # d1 = texterrors.calc_edit_distance_fast(a, b) 47 | # d2 = levd.distance(a, b) 48 | # assert d1 == d2, f'{a} {b} fasteditdist={d1} ref={d2}' 49 | 50 | 51 | def calc_wer(ref, b): 52 | cnt = 0 53 | err = 0 54 | for w1, w2 in zip(ref, b): 55 | if w1 != '': 56 | cnt += 1 57 | if w1 != w2: 58 | err += 1 59 | return 100. * (err / cnt) 60 | 61 | 62 | def test_wer(): 63 | ref = StringVector('IN THE DISCOTHEQUE THE DJ PLAYED PROGRESSIVE HOUSE MUSIC AND TRANCE'.split()) 64 | hyp = StringVector('IN THE DISCO TAK THE D J PLAYED PROGRESSIVE HOUSE MUSIC AND TRANCE'.split()) 65 | ref_aligned, hyp_aligned, _ = texterrors.align_texts(ref, hyp, False) 66 | wer = calc_wer(ref_aligned, hyp_aligned) 67 | assert round(wer, 2) == 36.36, round(wer, 2) 68 | 69 | ref = StringVector('IT FORMS PART OF THE SOUTH EAST DORSET CONURBATION ALONG THE ENGLISH CHANNEL COAST'.split()) 70 | hyp = StringVector("IT FOLLOWS PARDOFELIS LOUSES DORJE THAT COMORE H O LONELY ENGLISH GENOME COTA'S".split()) 71 | ref_aligned, hyp_aligned, _ = texterrors.align_texts(ref, hyp, False) 72 | wer = calc_wer(ref_aligned, hyp_aligned) 73 | assert round(wer, 2) == 85.71, round(wer, 2) 74 | 75 | ref = StringVector('THE FILM WAS LOADED INTO CASSETTES IN A DARKROOM OR CHANGING BAG'.split()) 76 | hyp = StringVector("THE FILM WAS LOADED INTO CASSETTES IN A DARK ROOM OR CHANGING BAG".split()) 77 | ref_aligned, hyp_aligned, _ = texterrors.align_texts(ref, hyp, False) 78 | wer = calc_wer(ref_aligned, hyp_aligned) 79 | assert round(wer, 2) == 16.67, round(wer, 2) 80 | 81 | ref = StringVector('GEPHYRIN HAS BEEN SHOWN TO BE NECESSARY FOR GLYR CLUSTERING AT INHIBITORY SYNAPSES'.split()) 82 | hyp = StringVector("THE VIDEOS RISHIRI TUX BINOY CYSTIDIA PHU LIAM CHOLESTEROL ET INNIT PATRESE SYNAPSES".split()) 83 | ref_aligned, hyp_aligned, _ = texterrors.align_texts(ref, hyp, False, use_chardiff=True) 84 | wer = calc_wer(ref_aligned, hyp_aligned) 85 | assert round(wer, 2) == 100.0, round(wer, 2) # kaldi gets 92.31 ! 
but has worse alignment 86 | ref_aligned, hyp_aligned, _ = texterrors.align_texts(ref, hyp, False, use_chardiff=False) 87 | wer = calc_wer(ref_aligned, hyp_aligned) 88 | assert round(wer, 2) == 92.31, round(wer, 2) 89 | 90 | ref = StringVector('test sentence okay words ending now'.split()) 91 | hyp = StringVector("test a sentenc ok endin now".split()) 92 | ref_aligned, hyp_aligned, _ = texterrors.align_texts(ref, hyp, False, use_chardiff=True) 93 | wer = calc_wer(ref_aligned, hyp_aligned) 94 | assert round(wer, 2) == 83.33, round(wer, 2) # kaldi gets 66.67 ! but has worse alignment 95 | ref_aligned, hyp_aligned, _ = texterrors.align_texts(ref, hyp, False, use_chardiff=False) 96 | wer = calc_wer(ref_aligned, hyp_aligned) 97 | assert round(wer, 2) == 66.67, round(wer, 2) 98 | 99 | ref = StringVector('speedbird eight six two'.split()) 100 | hyp = StringVector('hello speedbird six two'.split()) 101 | ref_aligned, hyp_aligned, _ = texterrors.align_texts(ref, hyp, False, use_chardiff=True) 102 | assert ref_aligned[0] == '' 103 | wer = calc_wer(ref_aligned, hyp_aligned) 104 | assert round(wer, 2) == 50.0, round(wer, 2) # kaldi gets 66.67 ! but has worse alignment 105 | 106 | 107 | def test_oov_cer(): 108 | oov_set = {'airport'} 109 | ref_aligned = 'the missing word is airport okay'.split() 110 | hyp_aligned = 'the missing word is air port okay'.split() 111 | err, cnt = texterrors.get_oov_cer(ref_aligned, hyp_aligned, oov_set) 112 | assert round(err / cnt, 2) == 0.14, round(err / cnt, 2) 113 | 114 | ref_aligned = 'the missing word is airport okay'.split() 115 | hyp_aligned = 'the missing word is airport okay'.split() 116 | err, cnt = texterrors.get_oov_cer(ref_aligned, hyp_aligned, oov_set) 117 | assert err / cnt == 0., err / cnt 118 | 119 | 120 | def test_seq_distance(): 121 | a, b = 'a b', 'a b' 122 | d = texterrors.seq_distance(StringVector(a.split()), StringVector(b.split())) 123 | assert d == 0 124 | 125 | a, b = 'a b', 'a c' 126 | d = texterrors.seq_distance(StringVector(a.split()), StringVector(b.split())) 127 | assert d == 1 128 | 129 | a, b = 'a b c', 'a b d' 130 | d = texterrors.seq_distance(StringVector(a.split()), StringVector(b.split())) 131 | assert d == 1 132 | 133 | a, b = 'a b c', 'a b d e' 134 | d = texterrors.seq_distance(StringVector(a.split()), StringVector(b.split())) 135 | assert d == 2 136 | 137 | a, b = 'a b c', 'd e f g' 138 | d = texterrors.seq_distance(StringVector(a.split()), StringVector(b.split())) 139 | assert d == 4 140 | 141 | a, b = 'ça va très bien', 'ça ne va pas' 142 | d = texterrors.seq_distance(StringVector(a.split()), StringVector(b.split())) 143 | assert d == 3 144 | 145 | a, b = 'ça ne va pas', 'merci ça va' 146 | d = texterrors.seq_distance(StringVector(a.split()), StringVector(b.split())) 147 | assert d == 3 148 | 149 | 150 | @dataclass 151 | class Utt: 152 | uid: str 153 | words: list 154 | times: list = None 155 | durs: list = None 156 | 157 | 158 | def create_inp(lines): 159 | utts = {} 160 | for line in lines: 161 | i, line = line.split(maxsplit=1) 162 | utts[i] = Utt(i, StringVector(line.split())) 163 | return utts 164 | 165 | 166 | def test_process_output(): 167 | reflines = ['1 zum beispiel work shops wo wir anbieten'] 168 | hyplines = ['1 zum beispiel work shop sommer anbieten'] 169 | refs = create_inp(reflines) 170 | hyps = create_inp(hyplines) 171 | 172 | buffer = io.StringIO() 173 | texterrors.process_output(refs, hyps, buffer, ref_file='A', hyp_file='B', nocolor=True) 174 | output = buffer.getvalue() 175 | 176 | ref = """\"A\" is treated as 
reference, \"B\" as hypothesis. Errors are capitalized. 177 | Per utt details: 178 | 1 179 | zum beispiel work SHOPS WO WIR anbieten 180 | SHOP * SOMMER 181 | 182 | WER: 42.9 (ins 0, del 1, sub 2 / 7) 183 | SER: 100.0 184 | 185 | Insertions: 186 | 187 | Deletions (second number is word count total): 188 | wo\t1\t1 189 | 190 | Substitutions (reference>hypothesis, second number is reference word count total): 191 | shops>shop\t1\t1 192 | wir>sommer\t1\t1 193 | """ 194 | assert output == ref 195 | 196 | def test_process_output_multi(): 197 | reflines = ['0 telefonat mit frau spring klee vom siebenundzwanzigsten august einundzwanzig ich erkläre frau spring klee dass die bundes gerichtliche recht sprechung im zusammen hang mit dem unfall begriff beziehungsweise dem ungewöhnlichen äusseren faktor wie auch bezüglich der unfall ähnlichen körper schädigungen insbesondere die analogie zu meniskus rissen klar geregelt ist'] 198 | hypalines = ['0 telefonat mit frau sprinkler vom siebenundzwanzigsten august einundzwanzig ich erkläre frau sprinkle dass die bundes gerichtliche recht sprechung im zusammen hang mit dem unfall begriff beziehungsweise dem ungewöhnlichen äusseren faktoren wie auch bezüglich der unfall ähnlichen körper schädigungen insbesondere die analogie zum meniskus rissen klar geregelt ist\''] 199 | hypblines = ['0 telefonat mit frau sprinkle vom siebenundzwanzigsten august einundzwanzig ich erkläre frau sprinkle dass die bundes gerichtliche recht sprechung im zusammen hang mit dem unfall begriff beziehungsweise dem ungewöhnlichen äusseren faktors wie auch bezüglich der unfall ähnlichen körper schädigungen insbesondere die analogie zum meniskus riss en klar geregelt ist ok'] 200 | refs = create_inp(reflines) 201 | hypa = create_inp(hypalines) 202 | hypb = create_inp(hypblines) 203 | buffer = io.StringIO() 204 | texterrors.process_multiple_outputs(refs, hypa, hypb, buffer, 10, False, False, 'ref', 'hypa', 'hypb', terminal_width=180, usecolor=True) 205 | output = buffer.getvalue() 206 | ref = """Per utt details, order is "ref", "hypa", "hypb": 207 | 0 208 | telefonat mit frau \x1b[32mspring\x1b[0m \x1b[32mklee\x1b[0m vom siebenundzwanzigsten august einundzwanzig ich erkläre frau \x1b[32mspring\x1b[0m \x1b[32mklee\x1b[0m dass die bundes gerichtliche recht sprechung im zusammen hang mit 209 | telefonat mit frau \x1b[31msprinkler\x1b[0m vom siebenundzwanzigsten august einundzwanzig ich erkläre frau \x1b[31msprinkle\x1b[0m dass die bundes gerichtliche recht sprechung im zusammen hang mit 210 | telefonat mit frau \x1b[31msprinkle\x1b[0m vom siebenundzwanzigsten august einundzwanzig ich erkläre frau \x1b[31msprinkle\x1b[0m dass die bundes gerichtliche recht sprechung im zusammen hang mit 211 | dem unfall begriff beziehungsweise dem ungewöhnlichen äusseren \x1b[32mfaktor\x1b[0m wie auch bezüglich der unfall ähnlichen körper schädigungen insbesondere die analogie \x1b[32mzu\x1b[0m meniskus 212 | dem unfall begriff beziehungsweise dem ungewöhnlichen äusseren \x1b[31mfaktoren\x1b[0m wie auch bezüglich der unfall ähnlichen körper schädigungen insbesondere die analogie \x1b[31mzum\x1b[0m meniskus 213 | dem unfall begriff beziehungsweise dem ungewöhnlichen äusseren \x1b[31mfaktors\x1b[0m wie auch bezüglich der unfall ähnlichen körper schädigungen insbesondere die analogie \x1b[31mzum\x1b[0m meniskus \x1b[31mriss\x1b[0m 214 | \x1b[32mrissen\x1b[0m klar geregelt \x1b[32mist\x1b[0m 215 | rissen klar geregelt \x1b[31mist'\x1b[0m 216 | \x1b[31men\x1b[0m klar geregelt ist \x1b[31mok\x1b[0m 217 | 218 | Results 
with file hypa 219 | WER: 14.3 (ins 0, del 2, sub 5 / 49) 220 | SER: 100.0 221 | 222 | Insertions: 223 | 224 | Deletions (second number is word count total): 225 | spring\t2\t2 226 | 227 | Substitutions (reference>hypothesis, second number is reference word count total): 228 | klee>sprinkler\t1\t2 229 | klee>sprinkle\t1\t2 230 | faktor>faktoren\t1\t1 231 | zu>zum\t1\t1 232 | ist>ist'\t1\t1 233 | --- 234 | 235 | Results with file hypb 236 | WER: 18.4 (ins 2, del 2, sub 5 / 49) 237 | SER: 100.0 238 | 239 | Insertions: 240 | riss\t1 241 | ok\t1 242 | 243 | Deletions (second number is word count total): 244 | spring\t2\t2 245 | 246 | Substitutions (reference>hypothesis, second number is reference word count total): 247 | klee>sprinkle\t2\t2 248 | faktor>faktors\t1\t1 249 | zu>zum\t1\t1 250 | rissen>en\t1\t1 251 | --- 252 | 253 | Difference between outputs: 254 | 255 | Insertions: 256 | riss\t1 257 | ist\t1 258 | 259 | Deletions (second number is word count total): 260 | 261 | Substitutions (reference>hypothesis, second number is reference word count total): 262 | sprinkler>sprinkle\t1\t1 263 | faktoren>faktors\t1\t1 264 | rissen>en\t1\t1 265 | ist'>ok\t1\t1 266 | """ 267 | print(ref, file=open('ref', 'w')) 268 | print(output, file=open('output', 'w')) 269 | assert ref == output, show_diff(ref, output) 270 | 271 | 272 | def test_process_output_colored(): 273 | reflines = ['1 den asu flash würde es sonst auch in allen drei sch- in allen drei sprachen ist der verfügbar ähm jetzt für uns habe ich gedacht reicht es ja auf deutsch he'] 274 | hyplines = ['1 ah der anzug fleisch würde sonst auch in allen drei ist in allen drei sprachen verfügbar ähm jetzt für uns habe ich gedacht reicht sie auch auf deutsch he'] 275 | refs = create_inp(reflines) 276 | hyps = create_inp(hyplines) 277 | 278 | buffer = io.StringIO() 279 | texterrors.process_output(refs, hyps, buffer, ref_file='A', hyp_file='B', nocolor=False, terminal_width=80) 280 | output = buffer.getvalue() 281 | ref = """\"A\" is treated as reference (white and green), \"B\" as hypothesis (white and red). 
282 | Per utt details: 283 | 1 284 | \x1b[32mden\x1b[0m \x1b[32masu\x1b[0m \x1b[32mflash\x1b[0m würde \x1b[32mes\x1b[0m sonst auch in allen drei \x1b[32msch-\x1b[0m in allen drei 285 | \x1b[31mah\x1b[0m \x1b[31mder\x1b[0m \x1b[31manzug\x1b[0m \x1b[31mfleisch\x1b[0m \x1b[31mist\x1b[0m 286 | sprachen \x1b[32mist\x1b[0m \x1b[32mder\x1b[0m verfügbar ähm jetzt für uns habe ich gedacht reicht \x1b[32mes\x1b[0m \x1b[32mja\x1b[0m 287 | \x1b[31msie\x1b[0m \x1b[31mauch\x1b[0m 288 | auf deutsch he 289 | 290 | 291 | WER: 32.3 (ins 1, del 3, sub 6 / 31) 292 | SER: 100.0 293 | 294 | Insertions: 295 | ah\t1 296 | 297 | Deletions (second number is word count total): 298 | es\t1\t2 299 | ist\t1\t1 300 | der\t1\t1 301 | 302 | Substitutions (reference>hypothesis, second number is reference word count total): 303 | den>der\t1\t1 304 | asu>anzug\t1\t1 305 | flash>fleisch\t1\t1 306 | sch->ist\t1\t1 307 | es>sie\t1\t2 308 | ja>auch\t1\t1 309 | """ 310 | print(ref, file=open('ref', 'w')) 311 | print(output, file=open('output', 'w')) 312 | assert ref == output 313 | 314 | 315 | def test_cli_basic(): 316 | ref_f = 'testref' 317 | hyp_f = 'testhyp' 318 | with open(ref_f, 'w') as fh: 319 | fh.write('1 zum beispiel work shops wo wir anbieten') 320 | with open(hyp_f, 'w') as fh: 321 | fh.write('1 zum beispiel work shop sommer anbieten') 322 | outf = 'testout' 323 | 324 | texterrors.main(ref_f, hyp_f, outf, isark=True, usecolor=False) 325 | output = open(outf).read() 326 | os.remove(ref_f) 327 | os.remove(hyp_f) 328 | os.remove(outf) 329 | ref = f"""\"{ref_f}\" is treated as reference, \"{hyp_f}\" as hypothesis. Errors are capitalized. 330 | Per utt details: 331 | 1 332 | zum beispiel work SHOPS WO WIR anbieten 333 | * SHOP SOMMER 334 | 335 | WER: 42.9 (ins 0, del 1, sub 2 / 7) 336 | SER: 100.0 337 | 338 | Insertions: 339 | 340 | Deletions (second number is word count total): 341 | shops\t1\t1 342 | 343 | Substitutions (reference>hypothesis, second number is reference word count total): 344 | wo>shop\t1\t1 345 | wir>sommer\t1\t1 346 | """ 347 | assert output == ref 348 | 349 | 350 | def test_speed(): 351 | import time 352 | import sys 353 | logger.remove() 354 | logger.add(sys.stdout, level='INFO') 355 | ref = create_inp(open('tests/reftext').read().splitlines()) 356 | hyp = create_inp(open('tests/hyptext').read().splitlines()) 357 | # import cProfile 358 | # pr = cProfile.Profile() 359 | # pr.enable() 360 | 361 | buffer = io.StringIO() 362 | start_time = time.perf_counter() 363 | texterrors.process_output(ref, hyp, fh=buffer, ref_file='ref', hyp_file='hyp', 364 | skip_detailed=True, use_chardiff=True, debug=False) 365 | process_time = time.perf_counter() - start_time 366 | 367 | # pr.disable() 368 | # pr.dump_stats('speed.prof') 369 | 370 | logger.info(f'Processing time for speed test is {process_time}') 371 | assert process_time < 2. 
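

# A minimal extra sketch (not part of the original suite) of calling align_texts
# directly; it assumes the insertion token defaults to the empty string, matching
# the assertions in test_wer above.
def test_align_texts_simple_deletion():
    ref = StringVector('a b c'.split())
    hyp = StringVector('a c'.split())
    ref_aligned, hyp_aligned, _ = texterrors.align_texts(ref, hyp, False)
    assert ref_aligned == ['a', 'b', 'c']
    assert hyp_aligned == ['a', '', 'c']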
372 |
373 |
-------------------------------------------------------------------------------- /texterrors/__init__.py: --------------------------------------------------------------------------------
1 | from .texterrors import align_texts, process_lines, lev_distance, get_oov_cer, align_texts_ctm, seq_distance, \
2 | process_output, process_multiple_outputs, calc_edit_distance_fast
-------------------------------------------------------------------------------- /texterrors/texterrors.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import shutil
3 | import sys
4 | from collections import defaultdict
5 | from dataclasses import dataclass, field
6 | from itertools import chain
7 | from typing import List, Tuple, Dict
8 |
9 | import numpy as np
10 | import plac
11 | import regex as re
12 | import texterrors_align
13 | from texterrors_align import StringVector
14 | from loguru import logger
15 | from termcolor import colored
16 |
17 | OOV_SYM = ''
18 | CPP_WORDS_CONTAINER = True
19 |
20 |
21 | def convert_to_int(lst_a, lst_b, dct):
22 | def convert(lst, dct_syms):
23 | intlst = []
24 | for w in lst:
25 | if w not in dct_syms:
26 | i = max(v for v in dct_syms.values() if isinstance(v, int)) + 1
27 | dct_syms[w] = i
28 | dct_syms[i] = w
29 | intlst.append(dct_syms[w])
30 | return intlst
31 | int_a = convert(lst_a, dct)
32 | int_b = convert(lst_b, dct)
33 | return int_a, int_b
34 |
35 |
36 | def lev_distance(a, b):
37 | """ This function assumes that elements of a and b are fixed width. """
38 | if isinstance(a, str):
39 | return texterrors_align.lev_distance_str(a, b)
40 | else:
41 | return texterrors_align.lev_distance(a, b)
42 |
43 |
44 | def calc_edit_distance_fast(a, b):
45 | return texterrors_align.calc_edit_distance_fast_str(a, b)
46 |
47 |
48 | def seq_distance(a, b):
49 | """ This function is for when a and b have strings as elements (variable length). """
50 | assert isinstance(a, StringVector) and isinstance(b, StringVector), 'Input types should be of type StringVector!'
51 | len_a = len(a) 52 | len_b = len(b) 53 | summed_cost = np.zeros((len_a + 1, len_b + 1), dtype=np.float64, order="C") 54 | cost = texterrors_align.calc_sum_cost(summed_cost, a, b, False, True) 55 | return cost 56 | 57 | 58 | def _align_texts(words_a, words_b, use_chardiff, debug, insert_tok): 59 | summed_cost = np.zeros((len(words_a) + 1, len(words_b) + 1), dtype=np.float64, 60 | order="C") 61 | 62 | if debug: 63 | print(words_a) 64 | print(words_b) 65 | if CPP_WORDS_CONTAINER: 66 | cost = texterrors_align.calc_sum_cost(summed_cost, words_a, words_b, use_chardiff, True) 67 | else: 68 | cost = texterrors_align.calc_sum_cost_lists(summed_cost, words_a, words_b, use_chardiff, True) 69 | 70 | if debug: 71 | np.set_printoptions(linewidth=300) 72 | np.savetxt('summedcost', summed_cost, fmt='%.3f', delimiter='\t') 73 | 74 | if CPP_WORDS_CONTAINER: 75 | best_path_reversed = texterrors_align.get_best_path(summed_cost, 76 | words_a, words_b, use_chardiff, True) 77 | else: 78 | best_path_reversed = texterrors_align.get_best_path_lists(summed_cost, 79 | words_a, words_b, use_chardiff, True) 80 | 81 | aligned_a, aligned_b = [], [] 82 | for i, j in reversed(best_path_reversed): 83 | if i == -1: 84 | aligned_a.append(insert_tok) 85 | else: 86 | aligned_a.append(words_a[i]) 87 | if j == -1: 88 | aligned_b.append(insert_tok) 89 | else: 90 | aligned_b.append(words_b[j]) 91 | 92 | return aligned_a, aligned_b, cost 93 | 94 | 95 | def align_texts_ctm(text_a_str, text_b_str, times_a, times_b, durs_a, durs_b, debug, insert_tok): 96 | len_a = len(text_a_str) 97 | len_b = len(text_b_str) 98 | # doing dynamic time warp 99 | text_a_str = [insert_tok] + text_a_str 100 | text_b_str = [insert_tok] + text_b_str 101 | # +1 because of padded start token 102 | summed_cost = np.zeros((len_a + 1, len_b + 1), dtype=np.float64, order="C") 103 | cost = texterrors_align.calc_sum_cost_ctm(summed_cost, text_a_str, text_b_str, 104 | times_a, times_b, durs_a, durs_b) 105 | 106 | if debug: 107 | np.set_printoptions(linewidth=300) 108 | np.savetxt('summedcost', summed_cost, fmt='%.3f', delimiter='\t') 109 | best_path_lst = [] 110 | texterrors_align.get_best_path_ctm(summed_cost, best_path_lst, 111 | text_a_str, text_b_str, times_a, times_b, durs_a, durs_b) 112 | assert len(best_path_lst) % 2 == 0 113 | path = [] 114 | for n in range(0, len(best_path_lst), 2): 115 | i = best_path_lst[n] 116 | j = best_path_lst[n + 1] 117 | path.append((i, j)) 118 | 119 | # convert hook (up left or left up) transitions to diag, not important. 
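# Editor's note (illustration, not original code): a "hook" is a left step
# followed by an up step (or the reverse) in the reversed best path, e.g.
# [..., (5, 4), (5, 3), (4, 3), ...]. With lasttpl=(5, 4) and nexttpl=(4, 3)
# differing by exactly (1, 1), the middle point (5, 3) is skipped by the loop
# below, which is equivalent to taking a single diagonal (substitution) step.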
120 | # -1 because of padding tokens, i = 1 because first is given 121 | newpath = [path[0]] 122 | i = 1 123 | lasttpl = path[0] 124 | while i < len(path) - 1: 125 | tpl = path[i] 126 | nexttpl = path[i + 1] 127 | if ( 128 | lasttpl[0] - 1 == nexttpl[0] and lasttpl[1] - 1 == nexttpl[1] 129 | ): # minus because reversed 130 | pass 131 | else: 132 | newpath.append(tpl) 133 | i += 1 134 | lasttpl = tpl 135 | path = newpath 136 | 137 | aligned_a, aligned_b = [], [] 138 | lasti, lastj = -1, -1 139 | for i, j in list(reversed(path)): 140 | # print(text_a[i], text_b[i], file=sys.stderr) 141 | if i != lasti: 142 | aligned_a.append(text_a_str[i]) 143 | else: 144 | aligned_a.append(insert_tok) 145 | if j != lastj: 146 | aligned_b.append(text_b_str[j]) 147 | else: 148 | aligned_b.append(insert_tok) 149 | lasti, lastj = i, j 150 | 151 | return aligned_a, aligned_b, cost 152 | 153 | 154 | def align_texts(text_a, text_b, debug, insert_tok='', use_chardiff=True): 155 | 156 | assert isinstance(text_a, StringVector) and isinstance(text_b, StringVector), 'Input types should be of type StringVector!' 157 | 158 | aligned_a, aligned_b, cost = _align_texts(text_a, text_b, use_chardiff, 159 | debug=debug, insert_tok=insert_tok) 160 | 161 | if debug: 162 | print(aligned_a) 163 | print(aligned_b) 164 | return aligned_a, aligned_b, cost 165 | 166 | 167 | def get_overlap(refw, hypw): 168 | # 0 if match, -1 if hyp before, 1 if after 169 | if hypw[1] < refw[1]: 170 | neg_offset = refw[1] - hypw[1] 171 | if neg_offset < hypw[2] * 0.5: 172 | return 0 173 | else: 174 | return -1 175 | else: 176 | pos_offset = hypw[1] - refw[1] 177 | if pos_offset < hypw[2] * 0.5: 178 | return 0 179 | else: 180 | return 1 181 | 182 | 183 | def get_oov_cer(ref_aligned, hyp_aligned, oov_set): 184 | # https://arxiv.org/abs/2107.08091 185 | assert len(ref_aligned) == len(hyp_aligned) 186 | oov_count_denom = 0 187 | oov_count_error = 0 188 | for i, ref_w in enumerate(ref_aligned): 189 | if ref_w in oov_set: 190 | oov_count_denom += len(ref_w) 191 | startidx = i - 1 if i - 1 >= 0 else 0 192 | hyp_w = '' 193 | for idx in range(startidx, startidx + 2): 194 | if idx != i: 195 | if idx > len(ref_aligned) - 1 or ref_aligned[idx] != '': 196 | continue 197 | if idx < i: 198 | hyp_w += hyp_aligned[idx] + ' ' 199 | else: 200 | hyp_w += ' ' + hyp_aligned[idx] 201 | else: 202 | hyp_w += hyp_aligned[idx] 203 | hyp_w = hyp_w.strip() 204 | hyp_w = hyp_w.replace('', '') 205 | d = texterrors_align.lev_distance_str(ref_w, hyp_w) 206 | oov_count_error += d 207 | return oov_count_error, oov_count_denom 208 | 209 | 210 | @dataclass 211 | class Utt: 212 | uid: str 213 | words: StringVector 214 | times: list = None 215 | durs: list = None 216 | 217 | def __len__(self): 218 | return len(self.words) 219 | 220 | 221 | def read_ref_file(ref_f, isark): 222 | ref_utts = {} 223 | with open(ref_f) as fh: 224 | for i, line in enumerate(fh): 225 | if isark: 226 | utt, *words = line.split() 227 | assert utt not in ref_utts, 'There are repeated utterances in reference file! 
Exiting'
228 | if CPP_WORDS_CONTAINER:
229 | words = StringVector(words)
230 | ref_utts[utt] = Utt(utt, words)
231 | else:
232 | words = line.split()
233 | i = str(i)
234 | if CPP_WORDS_CONTAINER:
235 | words = StringVector(words)
236 | ref_utts[i] = Utt(i, words)
237 | return ref_utts
238 |
239 |
240 | def read_hyp_file(hyp_f, isark, oracle_wer):
241 | hyp_utts = {} if not oracle_wer else defaultdict(list)
242 | with open(hyp_f) as fh:
243 | for i, line in enumerate(fh):
244 | if isark:
245 | utt, *words = line.split()
246 | words = [w for w in words if w != OOV_SYM]
247 | if CPP_WORDS_CONTAINER:
248 | words = StringVector(words)
249 | if not oracle_wer:
250 | hyp_utts[utt] = Utt(utt, words)
251 | else:
252 | hyp_utts[utt].append(Utt(utt, words))
253 | else:
254 | words = line.split()
255 | i = str(i)
256 | words = [w for w in words if w != OOV_SYM]
257 | if CPP_WORDS_CONTAINER:
258 | words = StringVector(words)
259 | hyp_utts[i] = Utt(i, words)
260 | return hyp_utts
261 |
262 |
263 | def read_ctm_file(f):
264 | """ Assumes the first field is the utterance ID and the last three fields are time, duration, word (in that order) """
265 | utt_to_wordtimes = defaultdict(list)
266 | with open(f) as fh:
267 | for line in fh:
268 | utt, *_, time, dur, word = line.split()
269 | time = float(time)
270 | dur = float(dur)
271 | utt_to_wordtimes[utt].append((word, time, dur,))
272 | utts = {}
273 | for utt, wordtimes in utt_to_wordtimes.items():
274 | words = []
275 | times = []
276 | durs = []
277 | for e in wordtimes:
278 | words.append(e[0]), times.append(e[1]), durs.append(e[2])
279 | utts[utt] = Utt(utt, StringVector(words), times, durs)
280 | return utts
281 |
282 |
283 | @dataclass
284 | class LineElement:
285 | words: Tuple[str, ...]
286 |
287 |
288 | class MultiLine:
289 | def __init__(self, terminal_width, num_lines):
290 | self.line_elements = []
291 | self.terminal_width = terminal_width
292 | self.num_lines = num_lines
293 |
294 | def add_lineelement(self, *words):
295 | self.line_elements.append(words)
296 |
297 | def __len__(self):
298 | return len(self.line_elements)
299 |
300 | def __getitem__(self, item):
301 | return self.line_elements[item]
302 |
303 | @staticmethod
304 | def construct(*lines):
305 | joined_lines = []
306 | for line in lines:
307 | joined_lines.append(' '.join(line))
308 | return joined_lines
309 |
310 | def __repr__(self):
311 | elems = []
312 | for le in self.line_elements:
313 | elems.append('|'.join(w for w in le))
314 | return '\t'.join(elems)
315 |
316 | def iter_construct(self):
317 | index = 0
318 | lines = [[] for _ in range(self.num_lines)]
319 | written_len = 0
320 | while index < len(self.line_elements):
321 | le = self.line_elements[index]
322 | lengths = [len(_remove_color(w)) for w in le]
323 | padded_len = max(*lengths)
324 | if written_len + padded_len > self.terminal_width:
325 | joined_lines = self.construct(*lines)
326 | lines = [[] for _ in range(self.num_lines)]
327 | yield joined_lines
328 | written_len = 0
329 | written_len += padded_len + 1 # +1 because space will be added
330 | words = le
331 | for i, line in enumerate(lines):
332 | word = words[i]
333 | wordlen = padded_len
334 | wordlen += get_color_lengthoffset(word)
335 | line.append(f'{word:^{wordlen}}')
336 |
337 | index += 1
338 | joined_lines = self.construct(*lines)
339 | yield joined_lines
340 |
341 |
342 | @dataclass
343 | class ErrorStats:
344 | total_cost: int = 0
345 | total_count: int = 0
346 | utts: List[str] = field(default_factory=list)
347 | utt_wrong: int = 0
348 | ins: Dict[str, int] =
field(default_factory=lambda: defaultdict(int)) 349 | dels: Dict[str, int] = field(default_factory=lambda: defaultdict(int)) 350 | subs: Dict[str, int] = field(default_factory=lambda: defaultdict(int)) 351 | char_error_count: int = 0 352 | char_count: int = 0 353 | oov_count_error: int = 0 354 | oov_count_denom: int = 0 355 | oov_word_error: int = 0 356 | oov_word_count: int = 0 357 | keywords_predicted: int = 0 358 | keywords_output: int = 0 359 | keywords_count: int = 0 360 | word_counts: Dict[str, int] = field(default_factory=lambda: defaultdict(int)) 361 | 362 | 363 | def read_files(ref_f, hyp_f, isark, isctm, keywords_f, utt_group_map_f, oracle_wer): 364 | if not isctm: 365 | ref_utts = read_ref_file(ref_f, isark) 366 | hyp_utts = read_hyp_file(hyp_f, isark, oracle_wer) 367 | else: 368 | ref_utts = read_ctm_file(ref_f) 369 | hyp_utts = read_ctm_file(hyp_f) 370 | 371 | keywords = set() 372 | if keywords_f: 373 | for line in open(keywords_f): 374 | assert len(line.split()) == 1, 'A keyword must be a single word!' 375 | keywords.add(line.strip()) 376 | 377 | utt_group_map = {} 378 | if utt_group_map_f: 379 | for line in open(utt_group_map_f): 380 | uttid, group = line.split(maxsplit=1) 381 | group = group.strip() 382 | utt_group_map[uttid] = group 383 | 384 | return ref_utts, hyp_utts, keywords, utt_group_map 385 | 386 | 387 | def print_detailed_stats(fh, ins, dels, subs, num_top_errors, freq_sort, word_counts): 388 | fh.write(f'\nInsertions:\n') 389 | for v, c in sorted(ins.items(), key=lambda x: x[1], reverse=True)[:num_top_errors]: 390 | fh.write(f'{v}\t{c}\n') 391 | fh.write('\n') 392 | fh.write(f'Deletions (second number is word count total):\n') 393 | for v, c in sorted(dels.items(), key=lambda x: (x[1] if not freq_sort else x[1] / word_counts[x[0]]), 394 | reverse=True)[:num_top_errors]: 395 | fh.write(f'{v}\t{c}\t{word_counts[v]}\n') 396 | fh.write('\n') 397 | fh.write(f'Substitutions (reference>hypothesis, second number is reference word count total):\n') 398 | for v, c in sorted(subs.items(), 399 | key=lambda x: (x[1] if not freq_sort else (x[1] / word_counts[x[0].split('>')[0].strip()], x[1],)), 400 | reverse=True)[:num_top_errors]: 401 | ref_w = v.split('>')[0].strip() 402 | fh.write(f'{v}\t{c}\t{word_counts[ref_w]}\n') 403 | 404 | 405 | def process_lines(ref_utts, hyp_utts, debug, use_chardiff, isctm, skip_detailed, 406 | terminal_width, oracle_wer, keywords, oov_set, cer, utt_group_map, 407 | group_stats, nocolor, insert_tok, fullprint=False, suppress_warnings=False): 408 | 409 | error_stats = ErrorStats() 410 | dct_char = {insert_tok: 0, 0: insert_tok} 411 | multilines = [] 412 | for utt in ref_utts.keys(): 413 | logger.debug('%s' % utt) 414 | ref = ref_utts[utt] 415 | 416 | is_empty_reference = not len(ref.words) 417 | 418 | if oracle_wer: 419 | hyps = hyp_utts[utt] 420 | costs = [] 421 | for hyp in hyps: 422 | _, _, cost = align_texts(ref.words, hyp.words, debug, use_chardiff=use_chardiff) 423 | costs.append(cost) 424 | error_stats.total_cost += min(costs) 425 | error_stats.total_count += len(ref) 426 | continue 427 | 428 | hyp = hyp_utts.get(utt) 429 | if hyp is None: 430 | logger.warning(f'Missing hypothesis for utterance: {utt}') 431 | continue 432 | error_stats.utts.append(utt) 433 | logger.debug('ref: %s' % ref.words) 434 | logger.debug('hyp: %s' % hyp.words) 435 | 436 | if not isctm: 437 | ref_aligned, hyp_aligned, cost = align_texts(ref.words, hyp.words, debug, use_chardiff=use_chardiff) 438 | else: 439 | ref_aligned, hyp_aligned, cost = align_texts_ctm(ref.words, 
hyp.words, ref.times, 440 | hyp.times, ref.durs, hyp.durs, debug, insert_tok) 441 | error_stats.total_cost += cost 442 | 443 | # Counting errors 444 | error_count = 0 445 | ref_word_count = 0 446 | 447 | double_line = MultiLine(terminal_width, 2) 448 | for i, (ref_w, hyp_w,) in enumerate(zip(ref_aligned, hyp_aligned)): 449 | if ref_w in keywords: 450 | error_stats.keywords_count += 1 451 | if hyp_w in keywords: 452 | error_stats.keywords_output += 1 453 | if ref_w in oov_set: 454 | error_stats.oov_word_count += 1 455 | 456 | if ref_w == hyp_w: 457 | if hyp_w in keywords: 458 | error_stats.keywords_predicted += 1 459 | if not fullprint: 460 | double_line.add_lineelement(ref_w, '') 461 | else: 462 | double_line.add_lineelement(ref_w, ref_w) 463 | error_stats.word_counts[ref_w] += 1 464 | ref_word_count += 1 465 | else: 466 | error_count += 1 467 | if ref_w in oov_set: 468 | error_stats.oov_word_error += 1 469 | if ref_w == '': 470 | if fullprint: 471 | double_line.add_lineelement('', hyp_w) 472 | elif not nocolor: 473 | double_line.add_lineelement('', colored(hyp_w, 'red', force_color=True)) 474 | else: 475 | hyp_w_upper = hyp_w.upper() 476 | double_line.add_lineelement('*', hyp_w_upper) 477 | error_stats.ins[hyp_w] += 1 478 | elif hyp_w == '': 479 | if fullprint: 480 | double_line.add_lineelement(ref_w, '') 481 | elif not nocolor: 482 | double_line.add_lineelement(colored(ref_w, 'green', force_color=True), '') 483 | else: 484 | ref_w_upper = ref_w.upper() 485 | double_line.add_lineelement(ref_w_upper, '*') 486 | ref_word_count += 1 487 | error_stats.dels[ref_w] += 1 488 | error_stats.word_counts[ref_w] += 1 489 | else: 490 | ref_word_count += 1 491 | key = f'{ref_w}>{hyp_w}' 492 | if fullprint: 493 | double_line.add_lineelement(ref_w, hyp_w,) 494 | elif not nocolor: 495 | double_line.add_lineelement(colored(ref_w, 'green', force_color=True), colored(hyp_w, 'red', force_color=True)) 496 | else: 497 | ref_w_upper = ref_w.upper() 498 | hyp_w_upper = hyp_w.upper() 499 | double_line.add_lineelement(ref_w_upper, hyp_w_upper) 500 | error_stats.subs[key] += 1 501 | error_stats.word_counts[ref_w] += 1 502 | #breakpoint() 503 | error_stats.total_count += ref_word_count 504 | #breakpoint() 505 | if not skip_detailed: 506 | multilines.append(double_line) 507 | 508 | if utt_group_map: 509 | group = utt_group_map[utt] 510 | group_stats[group]['count'] += ref_word_count 511 | group_stats[group]['errors'] += error_count 512 | 513 | if error_count: 514 | error_stats.utt_wrong += 1 515 | 516 | if cer: # Calculate CER 517 | def convert_to_char_list(lst): 518 | new = [] 519 | for i, word in enumerate(lst): 520 | new.extend(list(word)) 521 | if i != len(lst) - 1: 522 | new.append(' ') 523 | return new 524 | 525 | char_ref = convert_to_char_list(ref.words) 526 | char_hyp = convert_to_char_list(hyp.words) 527 | 528 | ref_int, hyp_int = convert_to_int(char_ref, char_hyp, dct_char) 529 | error_stats.char_error_count += texterrors_align.lev_distance(ref_int, hyp_int) 530 | error_stats.char_count += len(ref_int) 531 | 532 | if is_empty_reference: 533 | continue 534 | 535 | if oov_set: # Get OOV CER 536 | err, cnt = get_oov_cer(ref_aligned, hyp_aligned, oov_set) 537 | error_stats.oov_count_error += err 538 | error_stats.oov_count_denom += cnt 539 | # if not skip_detailed: 540 | # assert len(multilines) == len(error_stats.utts) 541 | return multilines, error_stats 542 | 543 | 544 | 545 | def _remove_color(word): 546 | return re.sub(br'\x1b\[[0-9]{2}m([\p{L}\p{P}]+)\x1b\[0m', br'\1', word.encode()).decode() 547 | 548 | 
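# Editor's note (illustration, not original code): termcolor wraps a word as
# '\x1b[31m' + word + '\x1b[0m', i.e. two ESC characters and 9 invisible
# characters in total (5 for the colour code, 4 for the reset); a word that
# has been coloured twice carries four ESC characters and 18 invisible
# characters. The offsets returned below compensate for this when padding
# columns to equal display width.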
549 | def get_color_lengthoffset(word):
550 | if word.count('\x1b') == 2:
551 | return 9
552 | if word.count('\x1b') == 4:
553 | return 18
554 | return 0
555 |
556 |
557 | def _merge_multilines(multilines_a, multilines_b, terminal_width, usecolor):
558 | multilines = []
559 | # print(multilines_a)
560 | # print()
561 | # print(multilines_b)
562 | # print()
563 | for multiline_a, multiline_b in zip(multilines_a, multilines_b):
564 | multiline = MultiLine(terminal_width, 3)
565 | idx_a, idx_b = 0, 0
566 | while idx_a < len(multiline_a) and idx_b < len(multiline_b):
567 | le_a = multiline_a[idx_a]
568 | le_b = multiline_b[idx_b]
569 | hyp_worda = le_a[1]
570 | hyp_wordb = le_b[1]
571 |
572 | if le_a[0] == le_b[0]: # ref words match
573 | refword = le_a[0]
574 | if refword == hyp_worda and refword == hyp_wordb: # everything correct
575 | multiline.add_lineelement(refword, hyp_worda, hyp_wordb)
576 | elif not refword: # double insertion
577 | if usecolor:
578 | hyp_worda = colored(hyp_worda, 'red', force_color=True)
579 | hyp_wordb = colored(hyp_wordb, 'red', force_color=True)
580 |
581 | multiline.add_lineelement('', hyp_worda, hyp_wordb)
582 | else: # hyp1 and/or hyp2 are wrong
583 | if usecolor:
584 | if refword != hyp_worda and hyp_worda:
585 | hyp_worda = colored(hyp_worda, 'red', force_color=True)
586 |
587 | if refword != hyp_wordb and hyp_wordb:
588 | hyp_wordb = colored(hyp_wordb, 'red', force_color=True)
589 | refword = colored(refword, 'green', force_color=True)
590 |
591 | multiline.add_lineelement(refword, hyp_worda, hyp_wordb)
592 | idx_a += 1
593 | idx_b += 1
594 | elif not le_a[0]: # ins
595 | if usecolor:
596 | hyp_worda = colored(hyp_worda, 'red', force_color=True)
597 | multiline.add_lineelement('', hyp_worda, '')
598 | idx_a += 1
599 | elif not le_b[0]: # ins
600 | if usecolor:
601 | hyp_wordb = colored(hyp_wordb, 'red', force_color=True)
602 | multiline.add_lineelement('', '', hyp_wordb)
603 | idx_b += 1
604 |
605 | else:
606 | logger.warning('Weird case found, please report!')
607 | refword = le_a[0] + '|' + le_b[0]
608 | if usecolor:
609 | refword = colored(refword, 'green', force_color=True)
610 | if le_a[0] != hyp_worda:
611 | hyp_worda = colored(hyp_worda, 'red', force_color=True)
612 | if le_b[0] != hyp_wordb:
613 | hyp_wordb = colored(hyp_wordb, 'red', force_color=True)
614 | multiline.add_lineelement(refword, hyp_worda, hyp_wordb)
615 | idx_a += 1
616 | idx_b += 1
617 |
618 | while idx_a < len(multiline_a):
619 | assert idx_b == len(multiline_b)
620 | le_a = multiline_a[idx_a]
621 | multiline.add_lineelement(le_a[0], colored(le_a[1], 'red', force_color=True) if usecolor else le_a[1], '')
622 | idx_a += 1
623 | while idx_b < len(multiline_b):
624 | assert idx_a == len(multiline_a)
625 | le_b = multiline_b[idx_b]
626 | multiline.add_lineelement(le_b[0], '', colored(le_b[1], 'red', force_color=True) if usecolor else le_b[1])
627 | idx_b += 1
628 | #print(multiline)
629 | multilines.append(multiline)
630 | return multilines
631 |
632 |
633 | def process_multiple_outputs(ref_utts, hypa_utts, hypb_utts, fh, num_top_errors,
634 | use_chardiff, freq_sort, ref_file, file_a, file_b, terminal_width=None, usecolor=False):
635 | if terminal_width is None:
636 | terminal_width, _ = shutil.get_terminal_size()
637 | terminal_width = 120 if terminal_width >= 120 else terminal_width
638 |
639 | multilines_ref_hypa, error_stats_ref_hypa = process_lines(ref_utts, hypa_utts, False, use_chardiff, False,
640 | False, terminal_width, False, [], [], False,
641 | None, None, nocolor=False, insert_tok='', fullprint=True)
642 | multilines_ref_hypb, error_stats_ref_hypb = process_lines(ref_utts, hypb_utts, False, use_chardiff, False,
643 | False, terminal_width, False, [], [], False,
644 | None, None, nocolor=False, insert_tok='', fullprint=True)
645 | _, error_stats_hypa_hypb = process_lines(hypa_utts, hypb_utts, False, use_chardiff, False,
646 | True, terminal_width, False, [], [], False,
647 | None, None, nocolor=True, insert_tok='')
648 |
649 | merged_multiline = _merge_multilines(multilines_ref_hypa, multilines_ref_hypb,
650 | terminal_width, usecolor)
651 | fh.write(f'Per utt details, order is \"{ref_file}\", \"{file_a}\", \"{file_b}\":\n')
652 | for utt, multiline in zip(error_stats_ref_hypa.utts, merged_multiline):
653 | fh.write(f'{utt}\n')
654 | for lines in multiline.iter_construct():
655 | for line in lines:
656 | fh.write(f'{line}\n')
657 |
658 | # Outputting metrics from gathered statistics.
659 | ins_count = sum(error_stats_ref_hypa.ins.values())
660 | del_count = sum(error_stats_ref_hypa.dels.values())
661 | sub_count = sum(error_stats_ref_hypa.subs.values())
662 | wer = (ins_count + del_count + sub_count) / float(error_stats_ref_hypa.total_count)
663 | fh.write(f'\nResults with file {file_a}'
664 | f'\nWER: {100. * wer:.1f} (ins {ins_count}, del {del_count}, sub {sub_count} / {error_stats_ref_hypa.total_count})'
665 | f'\nSER: {100. * error_stats_ref_hypa.utt_wrong / len(error_stats_ref_hypa.utts):.1f}\n')
666 |
667 | print_detailed_stats(fh, error_stats_ref_hypa.ins, error_stats_ref_hypa.dels,
668 | error_stats_ref_hypa.subs, num_top_errors, freq_sort,
669 | error_stats_ref_hypa.word_counts)
670 | fh.write(f'---\n')
671 |
672 | ins_count = sum(error_stats_ref_hypb.ins.values())
673 | del_count = sum(error_stats_ref_hypb.dels.values())
674 | sub_count = sum(error_stats_ref_hypb.subs.values())
675 | wer = (ins_count + del_count + sub_count) / float(error_stats_ref_hypb.total_count)
676 | fh.write(f'\nResults with file {file_b}'
677 | f'\nWER: {100.
* wer:.1f} (ins {ins_count}, del {del_count}, sub {sub_count} / {error_stats_ref_hypb.total_count})' 678 | f'\nSER: {100. * error_stats_ref_hypb.utt_wrong / len(error_stats_ref_hypb.utts):.1f}\n') 679 | 680 | print_detailed_stats(fh, error_stats_ref_hypb.ins, error_stats_ref_hypb.dels, 681 | error_stats_ref_hypb.subs, num_top_errors, freq_sort, 682 | error_stats_ref_hypb.word_counts) 683 | fh.write(f'---\n') 684 | 685 | fh.write(f'\nDifference between outputs:\n') 686 | print_detailed_stats(fh, error_stats_hypa_hypb.ins, error_stats_hypa_hypb.dels, 687 | error_stats_hypa_hypb.subs, num_top_errors, freq_sort, 688 | error_stats_hypa_hypb.word_counts) 689 | 690 | 691 | def process_output(ref_utts, hyp_utts, fh, ref_file, hyp_file, cer=False, num_top_errors=10, oov_set=None, debug=False, 692 | use_chardiff=True, isctm=False, skip_detailed=False, 693 | keywords=None, utt_group_map=None, oracle_wer=False, 694 | freq_sort=False, nocolor=False, insert_tok='', terminal_width=None): 695 | 696 | if terminal_width is None: 697 | terminal_width, _ = shutil.get_terminal_size() 698 | terminal_width = 120 if terminal_width >= 120 else terminal_width 699 | 700 | if oov_set is None: 701 | oov_set = set() 702 | if keywords is None: 703 | keywords = set() 704 | if utt_group_map is None: 705 | utt_group_map = {} 706 | 707 | group_stats = {} 708 | groups = set(utt_group_map.values()) 709 | for group in groups: 710 | group_stats[group] = {} 711 | group_stats[group]['count'] = 0 712 | group_stats[group]['errors'] = 0 713 | 714 | multilines, error_stats = process_lines(ref_utts, hyp_utts, debug, use_chardiff, isctm, skip_detailed, 715 | terminal_width, oracle_wer, keywords, oov_set, cer, 716 | utt_group_map, group_stats, nocolor, insert_tok) 717 | 718 | if not skip_detailed and not oracle_wer: 719 | if nocolor: 720 | fh.write(f'\"{ref_file}\" is treated as reference, \"{hyp_file}\" as hypothesis. Errors are capitalized.\n') 721 | else: 722 | fh.write(f'\"{ref_file}\" is treated as reference (white and green), \"{hyp_file}\" as hypothesis (white and red).\n') 723 | fh.write(f'Per utt details:\n') 724 | for utt, multiline in zip(error_stats.utts, multilines): 725 | fh.write(f'{utt}\n') 726 | for upper_line, lower_line in multiline.iter_construct(): 727 | fh.write(f'{upper_line}\n') 728 | fh.write(f'{lower_line}\n') 729 | 730 | if not use_chardiff and not oracle_wer: 731 | s = sum(v for v in chain(error_stats.ins.values(), error_stats.dels.values(), error_stats.subs.values())) 732 | assert s == error_stats.total_cost, f'{s} {error_stats.total_cost}' 733 | if oracle_wer: 734 | fh.write(f'Oracle WER: {error_stats.total_cost / error_stats.total_count}\n') 735 | return 736 | 737 | # Outputting metrics from gathered statistics. 
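# Editor's note (worked example, numbers taken from the expected test output
# earlier in this repo): with ins=0, del=2, sub=5 and a 49-word reference,
# WER = (0 + 2 + 5) / 49 = 0.143, reported below as 14.3; SER is the fraction
# of utterances containing at least one error.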
738 | ins_count = sum(error_stats.ins.values()) 739 | del_count = sum(error_stats.dels.values()) 740 | sub_count = sum(error_stats.subs.values()) 741 | wer = (ins_count + del_count + sub_count) / float(error_stats.total_count) 742 | if not skip_detailed: 743 | fh.write('\n') 744 | fh.write(f'WER: {100.*wer:.1f} (ins {ins_count}, del {del_count}, sub {sub_count} / {error_stats.total_count})' 745 | f'\nSER: {100.*error_stats.utt_wrong / len(error_stats.utts):.1f}\n') 746 | 747 | if cer: 748 | cer = error_stats.char_error_count / float(error_stats.char_count) 749 | fh.write(f'CER: {100.*cer:.1f} ({error_stats.char_error_count} / {error_stats.char_count})\n') 750 | if oov_set: 751 | if error_stats.oov_word_count: 752 | fh.write(f'OOV CER: {100.*error_stats.oov_count_error / error_stats.oov_count_denom:.1f}\n') 753 | fh.write(f'OOV WER: {100.*error_stats.oov_word_error / error_stats.oov_word_count:.1f}\n') 754 | else: 755 | logger.error('None of the words in the OOV list file were found in the reference!') 756 | if keywords: 757 | fh.write(f'Keyword results - recall {error_stats.keywords_predicted / error_stats.keywords_count if error_stats.keywords_count else -1:.2f} ' 758 | f'- precision {error_stats.keywords_predicted / error_stats.keywords_output if error_stats.keywords_output else -1:.2f}\n') 759 | if utt_group_map: 760 | fh.write('Group WERs:\n') 761 | for group, stats in group_stats.items(): 762 | wer = 100. * (stats['errors'] / float(stats['count'])) 763 | fh.write(f'{group}\t{wer:.1f}\n') 764 | fh.write('\n') 765 | 766 | if not skip_detailed: 767 | print_detailed_stats(fh, error_stats.ins, error_stats.dels, error_stats.subs, num_top_errors, freq_sort, 768 | error_stats.word_counts) 769 | 770 | 771 | def main( 772 | ref_file: 'Reference text', 773 | hyp_file: 'Hypothesis text', 774 | outf: 'Optional output file' = '', 775 | oov_list_f: ('List of OOVs', 'option', None) = '', 776 | isark: ('Text files start with utterance ID.', 'flag')=False, 777 | isctm: ('Text files start with utterance ID and end with word, time, duration', 'flag')=False, 778 | use_chardiff: ('Use character lev distance for better alignment in exchange for slightly higher WER.', 'flag') = False, 779 | cer: ('Calculate CER', 'flag')=False, 780 | debug: ('Print debug messages, will write cost matrix to summedcost.', 'flag', 'd')=False, 781 | skip_detailed: ('No per utterance output', 'flag', 's') = False, 782 | keywords_f: ('Will filter out non keyword reference words.', 'option', None) = '', 783 | freq_sort: ('Turn on sorting del/sub errors by frequency (default is by count).', 'flag', None) = False, 784 | oracle_wer: ('Hyp file should have multiple hypothesis per utterance, lowest edit distance will be used for WER.', 'flag', None) = False, 785 | utt_group_map_f: ('Should be a file which maps uttids to group, WER will be output per group.', 'option', '') = '', 786 | usecolor: ('Show detailed output with color (use less -R). 
Green/white is reference, Red/white model output.', 'flag', 'c')=False,
787 | num_top_errors: ('Number of errors to show per type in detailed output.', 'option')=10,
788 | second_hyp_f: ('Will compare outputs between two hypothesis files.', 'option')=''
789 | ):
790 |
791 | logger.remove()
792 | if debug:
793 | logger.add(sys.stderr, level="DEBUG")
794 | else:
795 | logger.add(sys.stderr, level="INFO")
796 |
797 | if outf:
798 | fh = open(outf, 'w')
799 | else:
800 | fh = sys.stdout
801 | if not second_hyp_f:
802 | if oracle_wer:
803 | assert isark and not isctm
804 | skip_detailed = True
805 | if use_chardiff:
806 | logger.warning('You would probably prefer running without `-use_chardiff`: the WER will be slightly better at the cost of a worse alignment')
807 |
808 | oov_set = set()
809 | if oov_list_f:
810 | if not use_chardiff:
811 | logger.warning('Because you are using standard alignment (not `-use_chardiff`) the alignments could be suboptimal\n'
812 | ' which will lead to the OOV-CER being slightly wrong. Use `-use_chardiff` for better alignment, ctm based for the best.')
813 | with open(oov_list_f) as fh_oov:
814 | for line in fh_oov:
815 | oov_set.add(line.split()[0]) # splitting in case the line contains another entry (for example a count)
816 |
817 | ref_utts, hyp_utts, keywords, utt_group_map = read_files(ref_file,
818 | hyp_file, isark, isctm, keywords_f, utt_group_map_f, oracle_wer)
819 |
820 | process_output(ref_utts, hyp_utts, fh, cer=cer, debug=debug, oov_set=oov_set,
821 | ref_file=ref_file, hyp_file=hyp_file, use_chardiff=use_chardiff, skip_detailed=skip_detailed,
822 | keywords=keywords, utt_group_map=utt_group_map, freq_sort=freq_sort,
823 | isctm=isctm, oracle_wer=oracle_wer, nocolor=not usecolor, num_top_errors=num_top_errors)
824 | else:
825 | ref_utts = read_ref_file(ref_file, isark)
826 | hyp_uttsa = read_hyp_file(hyp_file, isark, False)
827 | hyp_uttsb = read_hyp_file(second_hyp_f, isark, False)
828 |
829 | process_multiple_outputs(ref_utts, hyp_uttsa, hyp_uttsb, fh, num_top_errors,
830 | use_chardiff, freq_sort, ref_file, hyp_file, second_hyp_f, usecolor=usecolor)
831 |
832 | if outf: fh.close()  # only close the handle we opened, not sys.stdout
833 |
834 |
835 | def cli(): # entrypoint used in setup.py
836 | plac.call(main)
837 |
838 |
839 | if __name__ == "__main__":
840 | plac.call(main)
841 |
--------------------------------------------------------------------------------
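# Editor's addition: a usage sketch (assumes the package has been installed).
# The plac-wrapped entry point used by the `texterrors` console script (see
# cli() above) can also be driven programmatically, mirroring test_cli_basic
# in tests/test_functions.py:
#
#   import texterrors
#   texterrors.main('ref.txt', 'hyp.txt', 'out.txt', isark=True, usecolor=False)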