├── .gitignore ├── LICENSE ├── README.md ├── pyproject.toml ├── suber ├── __init__.py ├── __main__.py ├── concat_input_files.py ├── constants.py ├── data_types.py ├── file_readers │ ├── __init__.py │ ├── file_reader_base.py │ ├── plain_file_reader.py │ └── srt_file_reader.py ├── hyp_to_ref_alignment │ ├── __init__.py │ ├── levenshtein_alignment.py │ └── time_alignment.py ├── lib_levenshtein.py ├── metrics │ ├── __init__.py │ ├── cer.py │ ├── jiwer_interface.py │ ├── length_ratio.py │ ├── lib_ter.py │ ├── sacrebleu_interface.py │ ├── suber.py │ └── suber_statistics.py ├── sentence_segmentation.py ├── tools │ ├── __init__.py │ ├── align_hyp_to_ref.py │ └── srt_to_plain.py └── utilities.py └── tests ├── __init__.py ├── fuzz_levenshtein.py ├── test_cer.py ├── test_file_readers.py ├── test_hyp_to_ref_alignment.py ├── test_jiwer_interface.py ├── test_length_ratio.py ├── test_main.py ├── test_sacrebleu_interface.py ├── test_sentence_segmentation.py ├── test_suber_metric.py ├── test_suber_statistics.py ├── test_tools.py └── utilities.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__ 3 | .history* 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022 AppTek 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SubER - Subtitle Edit Rate 2 | 3 | SubER is an automatic, reference-based, segmentation- and timing-aware edit distance metric to measure quality of subtitle files. 4 | For a detailed description of the metric and a human post-editing evaluation we refer to our [IWSLT 2022 paper](https://aclanthology.org/2022.iwslt-1.1.pdf). 
5 | In addition to the SubER metric, this scoring tool calculates a wide range of established speech recognition and machine translation metrics (WER, BLEU, TER, chrF) directly on subtitle files.
6 |
7 | ## Installation
8 | ```console
9 | pip install subtitle-edit-rate
10 | ```
11 | will install the `suber` command line tool.
12 | Alternatively, check out this git repository and run the contained `suber` module with `python -m suber`.
13 |
14 | ## Basic Usage
15 | Currently, we expect subtitle files to come in [SubRip text (SRT)](https://en.wikipedia.org/wiki/SubRip) format. Given a human reference subtitle file `reference.srt` and a hypothesis file `hypothesis.srt` (typically the output of an automatic subtitling system), the SubER score can be calculated by running:
16 |
17 | ```console
18 | $ suber -H hypothesis.srt -R reference.srt
19 | {
20 |     "SubER": 19.048
21 | }
22 | ```
23 | The SubER score is printed to stdout in JSON format. As SubER is an edit rate, lower scores are better. As a rough rule of thumb from our experience, a score lower than 20(%) indicates very good quality, while a score above 40 to 50(%) is bad.
24 |
25 | Make sure that there is no constant time offset between the timestamps in hypothesis and reference, as this will lead to incorrect scores.
26 | Also, note that `<i>`, `<b>` and `<u>` formatting tags are ignored if present in the files. All other formatting must be removed from the files before scoring for accurate results.
27 |
28 | #### Punctuation and Case-Sensitivity
29 | The main SubER metric is computed on normalized text, i.e. case-insensitive and without taking punctuation into account, as we observe higher correlation with human judgements and post-edit effort in this setting. We provide an implementation of a case-sensitive variant which also uses a tokenizer to take punctuation into account as separate tokens. You can use it "at your own risk" or to reassess our findings. For this, add `--metrics SubER-cased` to the command above. Please do not report results using this variant as "SubER" unless you explicitly mention the punctuation-/case-sensitivity.
30 |
31 | ## Other Metrics
32 | The SubER tool supports computing the following other metrics directly on subtitle files:
33 |
34 | - word error rate (WER)
35 | - bilingual evaluation understudy (BLEU)
36 | - translation edit rate (TER)
37 | - character n-gram F score (chrF)
38 | - character error rate (CER)
39 |
40 | BLEU, TER and chrF calculations are done using [SacreBLEU](https://github.com/mjpost/sacrebleu) with default settings. WER is computed with [JiWER](https://github.com/jitsi/jiwer) on normalized text (lower-cased, punctuation removed).
41 |
42 | __Assuming__ `hypothesis.srt` __and__ `reference.srt` __are parallel__, i.e. they contain the same number of subtitles and the contents of the _n_-th subtitle in both files correspond to each other, the above-mentioned metrics can be computed by running:
43 | ```console
44 | $ suber -H hypothesis.srt -R reference.srt --metrics WER BLEU TER chrF CER
45 | {
46 |     "WER": 23.529,
47 |     "BLEU": 39.774,
48 |     "TER": 23.529,
49 |     "chrF": 68.402,
50 |     "CER": 17.857
51 | }
52 | ```
53 | In this mode, the text from each parallel subtitle pair is considered to be a sentence pair.
54 |
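The same computation can also be scripted. The following is a minimal sketch, not the tool's exact code path (the CLI routes through `suber.metrics.sacrebleu_interface`, which may apply additional metric-specific handling), built on the package's actual `read_input_file` and `segment_to_string` helpers:

```python
import sacrebleu  # BLEU with SacreBLEU default settings

from suber.file_readers import read_input_file
from suber.utilities import segment_to_string

# One string per subtitle; assumes the two SRT files are parallel.
hypothesis = [segment_to_string(s) for s in read_input_file("hypothesis.srt", file_format="SRT")]
reference = [segment_to_string(s) for s in read_input_file("reference.srt", file_format="SRT")]

assert len(hypothesis) == len(reference), "this mode requires parallel files"
print(round(sacrebleu.corpus_bleu(hypothesis, [reference]).score, 3))
```

Expect small deviations from the CLI output if the interface applies extra normalization.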
55 | ### Scoring Non-Parallel Subtitle Files
56 | In the general case, subtitle files for the same video can have different numbers of subtitles with different time stamps. All metrics - except SubER - usually have to be calculated on parallel segments. To apply these metrics to general subtitle files, the hypothesis file has to be re-segmented to correspond to the reference subtitles. The SubER tool implements two options:
57 |
58 | - alignment by minimizing Levenshtein distance ([Matusov et al.](https://aclanthology.org/2005.iwslt-1.19.pdf))
59 | - the time alignment method from [Cherry et al.](https://www.isca-archive.org/interspeech_2021/cherry21_interspeech.pdf)
60 |
61 | See our [paper](https://aclanthology.org/2022.iwslt-1.1.pdf) for further details.
62 |
63 | To use the Levenshtein method add an `AS-` prefix to the metric name, e.g.:
64 | ```console
65 | suber -H hypothesis.srt -R reference.srt --metrics AS-BLEU
66 | ```
67 | The `AS-` prefix terminology is taken from [Matusov et al.](https://aclanthology.org/2005.iwslt-1.19.pdf) and stands for "automatic segmentation".
68 | To use the time-alignment method instead, add a `t-` prefix. This works for all metrics (except for SubER itself, which does not require re-segmentation). In particular, we implement `t-BLEU` from [Cherry et al.](https://www.isca-archive.org/interspeech_2021/cherry21_interspeech.pdf). We encode the segmentation method (or lack thereof) in the metric name to explicitly distinguish the different resulting metric scores!
69 |
70 | To inspect the re-segmentation applied to the hypothesis you can use the `align_hyp_to_ref.py` tool (run `python -m suber.tools.align_hyp_to_ref -h` for help).
71 |
72 | In the case of Levenshtein alignment, there is also the option to give a plain file as the reference. This can be used to provide sentences instead of subtitles as reference segments (each line will be considered a segment):
73 |
74 | ```console
75 | suber -H hypothesis.srt -R reference.txt --reference-format plain --metrics AS-TER
76 | ```
77 |
78 | We provide a simple tool to extract sentences from SRT files based on punctuation:
79 |
80 | ```console
81 | python -m suber.tools.srt_to_plain -i reference.srt -o reference.txt --sentence-segmentation
82 | ```
83 |
84 | It can be used to create the plain sentence-level reference `reference.txt` for the scoring command above.
85 |
86 | ### Scoring Line Breaks as Tokens
87 | The line breaks present in the subtitle files can be included in the text segments to be scored, as `<eol>` (end of line) and `<eob>` (end of block) tokens. For example:
88 |
89 | ```
90 | 636
91 | 00:50:52,200 --> 00:50:57,120
92 | Ladies and gentlemen,
93 | the dance is about to begin.
94 | ```
95 | would be represented as
96 | ```
97 | Ladies and gentlemen, <eol> the dance is about to begin. <eob>
98 | ```
99 | To do so, add a `-seg` ("segmentation-aware") postfix to the metric name, e.g. `BLEU-seg`, `AS-TER-seg` or `t-WER-seg`. Character-level metrics (chrF and CER) do not support this, as it is not obvious how to count character edits for `<eol>` and `<eob>` tokens.
100 |
101 | ### TER-br
102 | As a special case, we implement TER-br from [Karakanta et al.](https://aclanthology.org/2020.iwslt-1.26.pdf). It is similar to `TER-seg`, but all (real) words are replaced by a mask token. This would convert the sentence from the example above to:
103 | ```
104 | <mask> <mask> <mask> <eol> <mask> <mask> <mask> <mask> <mask> <mask> <eob>
105 | ```
106 | Note that TER-br also has variants computed on existing parallel segments (`TER-br`) or on re-aligned segments (`AS-TER-br`/`t-TER-br`). Re-segmentation happens before masking.
107 |
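For illustration, the following sketch shows roughly how a segment is rendered into the token strings above. It uses the package's actual constants and data types, but it is not the code path the metrics themselves run through:

```python
from suber.constants import END_OF_LINE_SYMBOL, END_OF_BLOCK_SYMBOL, MASK_SYMBOL
from suber.data_types import LineBreak, Segment


def to_token_string(segment: Segment, mask_words: bool = False) -> str:
    """Render a segment with <eol>/<eob> tokens; mask_words=True mimics the TER-br input."""
    tokens = []
    for word in segment.word_list:
        tokens.append(MASK_SYMBOL if mask_words else word.string)
        if word.line_break == LineBreak.END_OF_LINE:
            tokens.append(END_OF_LINE_SYMBOL)
        elif word.line_break == LineBreak.END_OF_BLOCK:
            tokens.append(END_OF_BLOCK_SYMBOL)
    return " ".join(tokens)
```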
108 | ## Contributing
109 | If you run into an issue, have a feature request or have questions about the usage or the implementation of SubER, please do not hesitate to open an issue or a thread under "discussions". Pull requests are welcome too, of course!
110 |
111 | Things I'm already considering adding in future versions:
112 | - support for subtitling formats other than SRT
113 | - a verbose output that explains the SubER score (list of edit operations)
114 |
115 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=42"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "subtitle-edit-rate"
7 | version = "0.3.0"
8 | dependencies = [
9 |     "sacrebleu==2.5.1",
10 |     "jiwer==4.0.0",
11 |     "numpy",
12 |     "dataclasses;python_version<'3.7'",
13 | ]
14 | requires-python = ">= 3.6"
15 | authors = [
16 |     {name = "Patrick Wilken", email = "pwilken@apptek.com"},
17 | ]
18 | maintainers = [
19 |     {name = "Patrick Wilken", email = "pwilken@apptek.com"},
20 | ]
21 | description = "SubER: a metric for automatic evaluation of subtitle quality"
22 | readme = "README.md"
23 | license = "Apache-2.0"
24 | license-files = ["LICENSE"]
25 | keywords = ["subtitling", "subtitles", "captions", "metric", "evaluation"]
26 | classifiers = [
27 |     "Development Status :: 5 - Production/Stable",
28 |     "Intended Audience :: Science/Research",
29 |     "Programming Language :: Python :: 3",
30 |     "Topic :: Scientific/Engineering",
31 |     "Topic :: Scientific/Engineering :: Artificial Intelligence",
32 | ]
33 |
34 | [project.urls]
35 | Homepage = "https://github.com/apptek/SubER"
36 | Issues = "https://github.com/apptek/SubER/issues"
37 | Source = "https://github.com/apptek/SubER"
38 |
39 | [tool.setuptools.packages.find]
40 | include = ["suber*"]
41 |
42 | [project.scripts]
43 | suber = "suber.__main__:main"
--------------------------------------------------------------------------------
/suber/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Subtitle Edit Rate
3 |
4 | A metric for automatic evaluation of subtitle quality.
5 | """
--------------------------------------------------------------------------------
/suber/__main__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import argparse
4 | import json
5 |
6 | from collections import OrderedDict
7 |
8 | from suber.file_readers import read_input_file
9 | from suber.concat_input_files import create_concatenated_segments
10 | from suber.hyp_to_ref_alignment import levenshtein_align_hypothesis_to_reference
11 | from suber.hyp_to_ref_alignment import time_align_hypothesis_to_reference
12 | from suber.metrics.suber import calculate_SubER
13 | from suber.metrics.suber_statistics import SubERStatisticsCollector
14 | from suber.metrics.sacrebleu_interface import calculate_sacrebleu_metric
15 | from suber.metrics.jiwer_interface import calculate_word_error_rate
16 | from suber.metrics.cer import calculate_character_error_rate
17 | from suber.metrics.length_ratio import calculate_length_ratio
18 |
19 |
20 | def parse_arguments():
21 |     parser = argparse.ArgumentParser(
22 |         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
23 |         description="SubER - Subtitle Edit Rate. An automatic, reference-based, segmentation- and timing-aware "
24 |                     "edit distance metric to measure quality of subtitle files. 
Basic usage: " 25 | "'python -m suber -H hypothesis.srt -R reference.srt'") 26 | parser.add_argument("-H", "--hypothesis", required=True, nargs="+", 27 | help="The input files to score. Usually just one file, but we support test sets consisting of " 28 | "multiple files.") 29 | parser.add_argument("-R", "--reference", required=True, nargs="+", 30 | help="The reference files. Usually just one file, but we support test sets consisting of " 31 | "multiple files.") 32 | parser.add_argument("-m", "--metrics", nargs="+", default=["SubER"], help="The metrics to compute.") 33 | parser.add_argument("-f", "--hypothesis-format", default="SRT", help="Hypothesis file format, 'SRT' or 'plain'.") 34 | parser.add_argument("-F", "--reference-format", default="SRT", help="Reference file format, 'SRT' or 'plain'.") 35 | parser.add_argument("--suber-statistics", action="store_true", 36 | help="If set, will create an '#info' field in the output containing statistics about the " 37 | "different edit operations used to calculate the SubER score.") 38 | 39 | return parser.parse_args() 40 | 41 | 42 | def main(): 43 | args = parse_arguments() 44 | 45 | check_metrics(args.metrics) 46 | check_file_formats(args.hypothesis_format, args.reference_format, args.metrics) 47 | 48 | # A "segment" is a subtitle in case of SRT file input, or a line of text in case of plain input. 49 | if len(args.hypothesis) == 1 and len(args.reference) == 1: 50 | hypothesis_segments = read_input_file(args.hypothesis[0], file_format=args.hypothesis_format) 51 | reference_segments = read_input_file(args.reference[0], file_format=args.reference_format) 52 | else: 53 | hypothesis_segments, reference_segments = create_concatenated_segments( 54 | args.hypothesis, args.reference, args.hypothesis_format, args.reference_format) 55 | 56 | # Aligned hypotheses, either by Levenshtein distance or timing, are only needed by some metrics so we create them 57 | # lazily here. 58 | levenshtein_aligned_hypothesis_segments = None 59 | time_aligned_hypothesis_segments = None 60 | 61 | results = OrderedDict() 62 | additional_outputs = OrderedDict() 63 | 64 | for metric in args.metrics: 65 | if metric in results: 66 | continue # specified multiple times by the user 67 | 68 | if metric == "length_ratio": 69 | results[metric] = calculate_length_ratio(hypothesis=hypothesis_segments, reference=reference_segments) 70 | continue 71 | 72 | # When using existing parallel segments there will always be a word match in the end, don't count it. 73 | # On the other hand, if hypothesis gets aligned to reference a match is not guaranteed, so count it. 74 | score_break_at_segment_end = False 75 | 76 | full_metric_name = metric 77 | hypothesis_segments_to_use = hypothesis_segments 78 | 79 | if metric.startswith("AS-"): 80 | # "AS" stands for automatic segmentation, in particular re-segmentation of the hypothesis using 81 | # the Levenshtein alignment to the reference. 82 | # AS-WER and AS-BLEU were introduced by Matusov et al. https://aclanthology.org/2005.iwslt-1.19.pdf 83 | if levenshtein_aligned_hypothesis_segments is None: 84 | levenshtein_aligned_hypothesis_segments = levenshtein_align_hypothesis_to_reference( 85 | hypothesis=hypothesis_segments, reference=reference_segments) 86 | 87 | hypothesis_segments_to_use = levenshtein_aligned_hypothesis_segments 88 | metric = metric[len("AS-"):] 89 | score_break_at_segment_end = True 90 | 91 | elif metric.startswith("t-"): 92 | # "t" stands for timed. 
Subtitle timings will be used to re-segment the hypothesis to match the reference
93 |             # segments. t-BLEU was introduced by Cherry et al.
94 |             # https://www.isca-archive.org/interspeech_2021/cherry21_interspeech.pdf
95 |             if time_aligned_hypothesis_segments is None:
96 |                 time_aligned_hypothesis_segments = time_align_hypothesis_to_reference(
97 |                     hypothesis=hypothesis_segments, reference=reference_segments)
98 |
99 |             hypothesis_segments_to_use = time_aligned_hypothesis_segments
100 |             metric = metric[len("t-"):]
101 |             score_break_at_segment_end = True
102 |
103 |         elif not metric.startswith("SubER") and len(hypothesis_segments_to_use) != len(reference_segments):
104 |             raise ValueError(f"Metric '{metric}' assumes same number of segments in hypothesis and reference, but got "
105 |                              f"{len(hypothesis_segments)} hypothesis and {len(reference_segments)} "
106 |                              f"reference segments.")
107 |
108 |         if metric.startswith("SubER"):
109 |             statistics_collector = SubERStatisticsCollector() if args.suber_statistics else None
110 |
111 |             metric_score = calculate_SubER(
112 |                 hypothesis=hypothesis_segments_to_use, reference=reference_segments, metric=metric,
113 |                 statistics_collector=statistics_collector)
114 |
115 |             if statistics_collector:
116 |                 additional_outputs[full_metric_name] = statistics_collector.get_statistics()
117 |
118 |         elif metric.startswith("WER"):
119 |             metric_score = calculate_word_error_rate(
120 |                 hypothesis=hypothesis_segments_to_use, reference=reference_segments, metric=metric,
121 |                 score_break_at_segment_end=score_break_at_segment_end)
122 |
123 |         elif metric.startswith("CER"):
124 |             metric_score = calculate_character_error_rate(
125 |                 hypothesis=hypothesis_segments_to_use, reference=reference_segments, metric=metric)
126 |
127 |         else:
128 |             metric_score = calculate_sacrebleu_metric(
129 |                 hypothesis=hypothesis_segments_to_use, reference=reference_segments, metric=metric,
130 |                 score_break_at_segment_end=score_break_at_segment_end)
131 |
132 |         results[full_metric_name] = metric_score
133 |
134 |     if additional_outputs:
135 |         results["#info"] = additional_outputs
136 |
137 |     json_results = json.dumps(results, indent=4)
138 |     print(json_results)
139 |
140 |
141 | def check_metrics(metrics):
142 |     allowed_metrics = {
143 |         # Our proposed metric:
144 |         "SubER", "SubER-cased",
145 |         # Established ASR and MT metrics, requiring aligned hypothesis-reference segments:
146 |         "WER", "CER", "BLEU", "TER", "chrF",
147 |         # Cased and punctuated variants of the above:
148 |         "WER-cased", "CER-cased",
149 |         # Segmentation-aware variants of the above that include line breaks as tokens:
150 |         "WER-seg", "BLEU-seg", "TER-seg",
151 |         # Same as "TER-seg" but all words replaced by a mask token,
152 |         # proposed by Karakanta et al. https://aclanthology.org/2020.iwslt-1.26.pdf
153 |         "TER-br",
154 |         # With an "AS-" prefix, the metric is computed after Levenshtein alignment of hypothesis and reference:
155 |         "AS-WER", "AS-CER", "AS-BLEU", "AS-TER", "AS-chrF", "AS-WER-cased", "AS-CER-cased", "AS-WER-seg",
156 |         "AS-BLEU-seg", "AS-TER-seg", "AS-TER-br",
157 |         # With a "t-" prefix, the metric is computed after time alignment of hypothesis and reference:
158 |         "t-WER", "t-CER", "t-BLEU", "t-TER", "t-chrF", "t-WER-cased", "t-CER-cased", "t-WER-seg", "t-BLEU-seg",
159 |         "t-TER-seg", "t-TER-br",
160 |         # Hypothesis to reference length ratio in terms of number of tokens.
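        # (Computed as the number of hypothesis tokens divided by the number of reference tokens,
        # see metrics/length_ratio.py.)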
161 | "length_ratio"} 162 | 163 | invalid_metrics = list(sorted(set(metrics) - allowed_metrics)) 164 | if invalid_metrics: 165 | raise ValueError(f"Invalid metric(s): {' '.join(invalid_metrics)}") 166 | 167 | 168 | def check_file_formats(hypothesis_format, reference_format, metrics): 169 | is_plain_input = (hypothesis_format == "plain" or reference_format == "plain") 170 | for metric in metrics: 171 | if ((metric == "SubER" or metric.startswith("t-")) and is_plain_input): 172 | raise ValueError(f"Metric '{metric}' requires timing information and can only be computed on SRT " 173 | f"files (both hypothesis and reference).") 174 | 175 | 176 | if __name__ == "__main__": 177 | main() 178 | -------------------------------------------------------------------------------- /suber/concat_input_files.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | from suber.file_readers import read_input_file 4 | from suber.data_types import Segment, Subtitle 5 | 6 | 7 | def create_concatenated_segments(hypothesis_files: List[str], reference_files: List[str], hypothesis_format="SRT", 8 | reference_format="SRT") -> Tuple[List[Segment], List[Segment]]: 9 | """ 10 | Reads all pairs of hypothesis and reference files and creates two concatenated lists containing all hypothesis 11 | segments and all reference segments, respectively. This can be used to score test corpora available in form of many 12 | smaller audio / video files without concatenating the files manually. 13 | In case of SRT input the segments are subtitles with timing information. We adjust the subtitle timings such that 14 | all files are placed one after the other on the time axis, which corresponds to concatenating the corresponding 15 | audio / video files. 16 | """ 17 | if len(hypothesis_files) != len(reference_files): 18 | raise ValueError("Number of hypothesis and reference files must match.") 19 | 20 | all_hypothesis_segments = [] 21 | all_reference_segments = [] 22 | 23 | seconds_to_shift = 0.0 24 | total_hypothesis_duration = 0.0 25 | total_reference_duration = 0.0 26 | 27 | for hypothesis_file, reference_file in zip(hypothesis_files, reference_files): 28 | hypothesis_segments = read_input_file(hypothesis_file, file_format=hypothesis_format) 29 | reference_segments = read_input_file(reference_file, file_format=reference_format) 30 | 31 | if hypothesis_segments and isinstance(hypothesis_segments[0], Subtitle): 32 | total_hypothesis_duration = _shift_subtitles_in_time(hypothesis_segments, seconds_to_shift) 33 | if reference_segments and isinstance(reference_segments[0], Subtitle): 34 | total_reference_duration = _shift_subtitles_in_time(reference_segments, seconds_to_shift) 35 | 36 | seconds_to_shift = max(total_hypothesis_duration, total_reference_duration) 37 | 38 | all_hypothesis_segments += hypothesis_segments 39 | all_reference_segments += reference_segments 40 | 41 | return all_hypothesis_segments, all_reference_segments 42 | 43 | 44 | def _shift_subtitles_in_time(subtitles: List[Subtitle], seconds) -> float: 45 | """ 46 | Returns new total duration after shift. 47 | """ 48 | 49 | for subtitle in subtitles: 50 | _shift_subtitle_in_time(subtitle, seconds) 51 | 52 | # There might be audio / video left after the last subtitle end time, but taking this into account is not necessary 53 | # for metric calculation. 54 | # We add an epsilon to make sure that subtitles from different files are not counted as overlapping. 
55 |     return subtitles[-1].end_time + 1e-8
56 |
57 |
58 | def _shift_subtitle_in_time(subtitle: Subtitle, seconds):
59 |     subtitle.start_time += seconds
60 |     subtitle.end_time += seconds
61 |     for word in subtitle.word_list:
62 |         word.approximate_word_time += seconds
--------------------------------------------------------------------------------
/suber/constants.py:
--------------------------------------------------------------------------------
1 | # As used in the MuST-Cinema corpus: https://ict.fbk.eu/must-cinema/
2 | END_OF_LINE_SYMBOL = "<eol>"
3 | END_OF_BLOCK_SYMBOL = "<eob>"
4 |
5 | MASK_SYMBOL = "<mask>"
--------------------------------------------------------------------------------
/suber/data_types.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from dataclasses import dataclass
3 | from typing import List
4 |
5 |
6 | class LineBreak(Enum):
7 |     NONE = 0
8 |     END_OF_LINE = 1  # represented as '<eol>' in plain text files
9 |     END_OF_BLOCK = 2  # represented as '<eob>' in plain text files
10 |
11 |
12 | @dataclass
13 | class Word:
14 |     string: str
15 |     line_break: LineBreak = LineBreak.NONE  # the line break after the word, if any
16 |
17 |
18 | @dataclass(unsafe_hash=True)
19 | # Needs to be hashable for cached edit distance in lib_ter.py, but 'approximate_word_time' is currently set after
20 | # creation within SRTFileReader, so we cannot set frozen=True.
21 | # TODO: find a clean solution
22 | class TimedWord(Word):
23 |     subtitle_start_time: float = None
24 |     subtitle_end_time: float = None
25 |     approximate_word_time: float = None  # usually interpolated from subtitle start and end time; for t-BLEU calculation
26 |
27 |
28 | @dataclass
29 | class Segment:
30 |     word_list: List[Word]
31 |
32 |
33 | @dataclass
34 | class Subtitle(Segment):
35 |     index: int
36 |     start_time: float
37 |     end_time: float
--------------------------------------------------------------------------------
/suber/file_readers/__init__.py:
--------------------------------------------------------------------------------
1 | from .file_reader_base import read_input_file
2 | from .plain_file_reader import PlainFileReader
3 | from .srt_file_reader import SRTFileReader
--------------------------------------------------------------------------------
/suber/file_readers/file_reader_base.py:
--------------------------------------------------------------------------------
1 | import gzip
2 |
3 | from typing import List
4 | from io import TextIOWrapper
5 |
6 | from suber.data_types import Segment
7 |
8 |
9 | class FileReaderBase:
10 |     """
11 |     Derived classes must implement self._parse_lines().
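    Example of a minimal subclass (a hedged sketch, not one of the shipped readers; Word and
    Segment come from suber.data_types):

        class UpperCaseFileReader(FileReaderBase):
            def _parse_lines(self, file_object):
                for line in file_object:
                    yield Segment(word_list=[Word(string=token.upper()) for token in line.split()])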
12 | """ 13 | def __init__(self, file_name): 14 | self._file_name = file_name 15 | 16 | def read(self) -> List[Segment]: 17 | with self._open_file() as file_object: 18 | return list(self._parse_lines(file_object)) 19 | 20 | def _parse_lines(self, file_object: TextIOWrapper) -> List[Segment]: 21 | raise NotImplementedError 22 | 23 | def _open_file(self): 24 | if self._file_name.endswith(".gz"): 25 | return gzip.open(self._file_name, "rt", encoding="utf-8") 26 | else: 27 | return open(self._file_name, "r", encoding="utf-8") 28 | 29 | 30 | def read_input_file(file_name, file_format) -> List[Segment]: 31 | from suber.file_readers import PlainFileReader, SRTFileReader # here to avoid circular import 32 | 33 | if file_format == "SRT": 34 | file_reader = SRTFileReader(file_name) 35 | elif file_format == "plain": 36 | file_reader = PlainFileReader(file_name) 37 | else: 38 | raise ValueError(f"Unknown file format: {file_format}") 39 | 40 | try: 41 | segments = file_reader.read() 42 | except Exception as e: 43 | raise Exception(f"Error reading file '{file_name}'") from e 44 | 45 | return segments 46 | -------------------------------------------------------------------------------- /suber/file_readers/plain_file_reader.py: -------------------------------------------------------------------------------- 1 | from suber.file_readers.file_reader_base import FileReaderBase 2 | from suber.data_types import LineBreak, Word, Segment 3 | from suber.constants import END_OF_LINE_SYMBOL, END_OF_BLOCK_SYMBOL 4 | 5 | 6 | class PlainFileReader(FileReaderBase): 7 | def _parse_lines(self, file_object): 8 | segments = [] 9 | 10 | is_first_line = True 11 | for line in file_object: 12 | if is_first_line: 13 | if line.startswith('\ufeff'): 14 | line = line[len('\ufeff'):] # remove byte order mark (BOM) 15 | is_first_line = False 16 | words = line.split() 17 | 18 | word_list = [] 19 | for word in words: 20 | if word in (END_OF_LINE_SYMBOL, END_OF_BLOCK_SYMBOL): 21 | if not word_list: 22 | continue # ignore line break symbol at the start of the line 23 | else: 24 | word_list[-1].line_break = ( 25 | LineBreak.END_OF_BLOCK if word == END_OF_BLOCK_SYMBOL else LineBreak.END_OF_LINE) 26 | else: 27 | word_list.append(Word(string=word)) 28 | 29 | segments.append(Segment(word_list=word_list)) 30 | 31 | return segments 32 | -------------------------------------------------------------------------------- /suber/file_readers/srt_file_reader.py: -------------------------------------------------------------------------------- 1 | import re 2 | import datetime 3 | import numpy 4 | 5 | 6 | from suber.file_readers.file_reader_base import FileReaderBase 7 | from suber.data_types import LineBreak, TimedWord, Subtitle 8 | 9 | 10 | class SRTFormatError(Exception): 11 | pass 12 | 13 | 14 | class SRTFileReader(FileReaderBase): 15 | allowed_time_formats = { 16 | "iso": r"\d+:\d+:\d+\.\d+", 17 | "iso_with_comma": r"\d+:\d+:\d+,\d+", 18 | "seconds": r"^\d+(\.\d+)?$", 19 | } 20 | 21 | def _parse_lines(self, file_object): 22 | subtitles = [] 23 | 24 | subtitle_index = None 25 | start_time, end_time = None, None 26 | word_list = None 27 | 28 | for line in file_object: 29 | line = line.strip() 30 | 31 | if subtitle_index is None: 32 | # We expect this line to be the subtitle index. Additional empty line is okay too. 
33 |                 if line:
34 |                     try:
35 |                         subtitle_index = int(line.replace("\ufeff", ""))
36 |                     except ValueError as e:
37 |                         raise SRTFormatError(f"Tried to read subtitle index from '{line}' but failed.") from e
38 |
39 |             elif start_time is None:
40 |                 # We expect this line to be the time string. An additional empty line is okay too.
41 |                 if line:
42 |                     start_time, end_time = self._parse_time_stamp(line)
43 |                     if end_time < start_time:
44 |                         raise SRTFormatError(f"End time {end_time} is before start time {start_time}.")
45 |
46 |                     if subtitles and subtitles[-1].end_time > start_time:
47 |                         start_time_string = line.split()[0]
48 |                         if start_time < subtitles[-1].start_time:
49 |                             raise SRTFormatError("Subtitles must appear ordered according to their start time, "
50 |                                                  f"violated by subtitle at '{start_time_string}'.")
51 |
52 |                     assert word_list is None
53 |                     word_list = []  # start collecting words
54 |
55 |             elif line:
56 |                 # We expect this line to contain subtitle text.
57 |                 assert subtitle_index is not None
58 |                 assert start_time is not None
59 |                 assert end_time is not None
60 |
61 |                 # We don't consider formatting tags <i>, <b>, <u>, etc. in the evaluation.
62 |                 # TODO: maybe we want this regex to cover more cases
63 |                 line = re.sub('</?[^>]>', '', line)
64 |
65 |                 word_list.extend([
66 |                     TimedWord(
67 |                         string=word,
68 |                         subtitle_start_time=start_time,
69 |                         subtitle_end_time=end_time)
70 |                     for word in line.split()])
71 |
72 |                 if word_list:
73 |                     word_list[-1].line_break = LineBreak.END_OF_LINE
74 |
75 |             else:
76 |                 # This is an empty line after lines of subtitle text, which ends the current subtitle.
77 |                 assert word_list is not None
78 |                 assert start_time is not None
79 |                 assert end_time is not None
80 |
81 |                 if word_list:  # might be an empty subtitle
82 |                     word_list[-1].line_break = LineBreak.END_OF_BLOCK
83 |
84 |                 self._set_approximate_word_times(word_list, start_time, end_time)
85 |
86 |                 subtitles.append(
87 |                     Subtitle(word_list=word_list, index=subtitle_index, start_time=start_time, end_time=end_time))
88 |
89 |                 subtitle_index = None
90 |                 start_time, end_time = None, None
91 |                 word_list = None
92 |
93 |         if word_list is not None:
94 |             # handle the last subtitle
95 |             assert subtitle_index is not None
96 |             assert start_time is not None and end_time is not None
97 |
98 |             if word_list:  # might be an empty subtitle
99 |                 word_list[-1].line_break = LineBreak.END_OF_BLOCK
100 |
101 |             self._set_approximate_word_times(word_list, start_time, end_time)
102 |
103 |             subtitles.append(
104 |                 Subtitle(word_list=word_list, index=subtitle_index, start_time=start_time, end_time=end_time))
105 |
106 |         return subtitles
107 |
108 |     @classmethod
109 |     def _set_approximate_word_times(cls, word_list, start_time, end_time):
110 |         """
111 |         Linearly interpolates word times from the subtitle start and end time, as described in
112 |         https://www.isca-archive.org/interspeech_2021/cherry21_interspeech.pdf
113 |         """
114 |         # Remove a small margin to guarantee the first and last word will always be counted as within the subtitle.
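        # (time_alignment.py assigns words to reference subtitles via strict comparisons against the
        # subtitle boundaries, so word times lying exactly on a boundary could otherwise be lost.)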
115 | epsilon = 1e-8 116 | start_time = start_time + epsilon 117 | end_time = end_time - epsilon 118 | 119 | num_words = len(word_list) 120 | duration = end_time - start_time 121 | assert duration >= 0 122 | 123 | approximate_word_times = numpy.linspace(start=start_time, stop=end_time, num=num_words) 124 | for word_time, word in zip(approximate_word_times, word_list): 125 | word.approximate_word_time = word_time 126 | 127 | @classmethod 128 | def _parse_time_stamp(cls, time_stamp): 129 | time_stamp_tokens = time_stamp.split() 130 | if len(time_stamp_tokens) != 3 or not time_stamp_tokens[1].endswith("->"): 131 | raise SRTFormatError(f"Could not parse subtitle times from '{time_stamp}'.") 132 | 133 | start_time = cls._seconds_from_time_code(time_stamp_tokens[0]) 134 | end_time = cls._seconds_from_time_code(time_stamp_tokens[2]) 135 | 136 | return start_time, end_time 137 | 138 | @classmethod 139 | def _seconds_from_time_code(cls, time_code): 140 | detected_time_format = None 141 | for time_format_name, time_format_regex in cls.allowed_time_formats.items(): 142 | if re.match(time_format_regex, time_code): 143 | detected_time_format = time_format_name 144 | break 145 | 146 | if not detected_time_format: 147 | raise SRTFormatError(f"Could not detect format of time code '{time_code}'.") 148 | 149 | try: 150 | if detected_time_format == "seconds": 151 | seconds = float(time_code) 152 | else: 153 | assert detected_time_format in ["iso", "iso_with_comma"] 154 | datetime_format_string = "%H:%M:%S,%f" if detected_time_format == "iso_with_comma" else "%H:%M:%S.%f" 155 | 156 | datetime_object = datetime.datetime.strptime(time_code, datetime_format_string) 157 | # Set year to 1 (default is 1900) to set this time in relation to datetime.datetime.min. 158 | datetime_object = datetime_object.replace(year=1) 159 | 160 | seconds = (datetime_object - datetime.datetime.min).total_seconds() 161 | except Exception as e: 162 | raise SRTFormatError(f"Could not convert '{time_code}' to seconds. " 163 | f"Tried to read it as format '{detected_time_format}'.") from e 164 | 165 | return seconds 166 | -------------------------------------------------------------------------------- /suber/hyp_to_ref_alignment/__init__.py: -------------------------------------------------------------------------------- 1 | from .time_alignment import time_align_hypothesis_to_reference 2 | from .levenshtein_alignment import levenshtein_align_hypothesis_to_reference 3 | -------------------------------------------------------------------------------- /suber/hyp_to_ref_alignment/levenshtein_alignment.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import string 3 | from itertools import zip_longest 4 | from typing import List, Tuple 5 | 6 | from suber import lib_levenshtein 7 | from suber.data_types import Segment 8 | 9 | 10 | def levenshtein_align_hypothesis_to_reference(hypothesis: List[Segment], reference: List[Segment]) -> List[Segment]: 11 | """ 12 | Runs the Levenshtein algorithm to get the minimal set of edit operations to convert the full list of hypothesis 13 | words into the full list of reference words. The edit operations implicitly define an alignment between hypothesis 14 | and reference words. Using this alignment, the hypotheses are re-segmented to match the reference segmentation. 
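    A hedged illustration: with reference segments ["how are", "you"] and a single hypothesis segment
    ["how are you"], all words align as 'equal' operations and the function returns the two segments
    ["how are"] and ["you"].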
15 | """ 16 | 17 | remove_punctuation_table = str.maketrans('', '', string.punctuation) 18 | 19 | def normalize_word(word): 20 | """ 21 | Lower-cases and removes punctuation as this increases the alignment accuracy. 22 | """ 23 | word = word.lower() 24 | word_without_punctuation = word.translate(remove_punctuation_table) 25 | 26 | if not word_without_punctuation: 27 | return word # keep tokens that are purely punctuation 28 | 29 | return word_without_punctuation 30 | 31 | all_reference_word_strings = [normalize_word(word.string) for segment in reference for word in segment.word_list] 32 | all_hypothesis_word_strings = [normalize_word(word.string) for segment in hypothesis for word in segment.word_list] 33 | 34 | all_hypothesis_words = [word for segment in hypothesis for word in segment.word_list] 35 | 36 | reference_string, hypothesis_string = _map_words_to_characters( 37 | all_reference_word_strings, all_hypothesis_word_strings) 38 | 39 | opcodes = lib_levenshtein.opcodes(reference_string, hypothesis_string) 40 | 41 | reference_segment_lengths = [len(segment.word_list) for segment in reference] 42 | reference_segment_boundary_indices = numpy.cumsum(reference_segment_lengths) 43 | current_segment_index = 0 44 | aligned_hypothesis_word_lists = [[] for _ in reference] 45 | 46 | for opcode_tuple in opcodes: 47 | edit_operation = opcode_tuple[0] 48 | hypothesis_position_range = range(opcode_tuple[3], opcode_tuple[4]) 49 | reference_position_range = range(opcode_tuple[1], opcode_tuple[2]) 50 | 51 | if edit_operation in ("equal", "replace"): 52 | assert len(hypothesis_position_range) == len(reference_position_range) 53 | elif edit_operation == "insert": 54 | assert len(reference_position_range) == 0 55 | elif edit_operation == "delete": 56 | assert len(hypothesis_position_range) == 0 57 | else: 58 | assert False, f"Invalid edit operation '{edit_operation}'." 59 | 60 | # 'zip_longest' is a "clever" way to unify the different cases: for 'equal' and 'replace' we indeed have to 61 | # iterate through hypothesis and reference position in parallel, for 'insert' and 'delete' either 62 | # 'hypothesis_position' or 'reference_position' will be None in the loop. 63 | for hypothesis_position, reference_position in zip_longest(hypothesis_position_range, reference_position_range): 64 | 65 | # Update current segment index depending on current reference position. 66 | if (reference_position is not None 67 | and reference_position >= reference_segment_boundary_indices[current_segment_index]): 68 | 69 | assert reference_position == reference_segment_boundary_indices[current_segment_index], ( 70 | "Bug: missing reference position in edit operations.") 71 | current_segment_index += 1 72 | 73 | # If there are empty segments in the reference, we get double entries in 74 | # 'reference_segment_boundary_indices' (because the empty segment ends at the same word index as the 75 | # previous segment). Skip these empty segments, we don't want to assign any hypothesis words to them. 76 | while (current_segment_index < len(reference_segment_boundary_indices) 77 | and reference_segment_boundary_indices[current_segment_index] 78 | == reference_segment_boundary_indices[current_segment_index - 1]): 79 | current_segment_index += 1 80 | 81 | # Add hypothesis word to current segment in case of 'equal', 'replace' or 'insert' operation. 
82 |             if hypothesis_position is not None:
83 |                 word = all_hypothesis_words[hypothesis_position]
84 |                 aligned_hypothesis_word_lists[current_segment_index].append(word)
85 |
86 |     aligned_hypothesis = [Segment(word_list=word_list) for word_list in aligned_hypothesis_word_lists]
87 |
88 |     return aligned_hypothesis
89 |
90 |
91 | def _map_words_to_characters(reference_words: List[str], hypothesis_words: List[str]) -> Tuple[str, str]:
92 |     """
93 |     The Levenshtein module operates on strings, not on lists of strings. Therefore we map words to characters here.
94 |     Inspired by https://github.com/jitsi/jiwer/blob/master/jiwer/measures.py.
95 |     """
96 |     unique_words = set(reference_words + hypothesis_words)
97 |     vocabulary = dict(zip(unique_words, range(len(unique_words))))
98 |
99 |     reference_string = "".join(chr(vocabulary[word] + 32) for word in reference_words)
100 |     hypothesis_string = "".join(chr(vocabulary[word] + 32) for word in hypothesis_words)
101 |
102 |     return reference_string, hypothesis_string
--------------------------------------------------------------------------------
/suber/hyp_to_ref_alignment/time_alignment.py:
--------------------------------------------------------------------------------
1 | import numpy
2 |
3 | from typing import List
4 | from suber.data_types import Segment, Subtitle
5 |
6 |
7 | def time_align_hypothesis_to_reference(hypothesis: List[Segment], reference: List[Subtitle]) -> List[Subtitle]:
8 |     """
9 |     Re-segments the hypothesis segments according to the reference subtitle timings. The output hypothesis subtitles
10 |     will have the same time stamps as the reference, and each will contain the words whose approximate times fall into
11 |     these intervals, i.e. reference_subtitle.start_time < word.approximate_word_time < reference_subtitle.end_time.
12 |     Hypothesis words that do not fall into any subtitle will be dropped.
13 |     """
14 |     aligned_hypothesis_word_lists = [[] for _ in reference]
15 |
16 |     reference_start_times = numpy.array([subtitle.start_time for subtitle in reference])
17 |     reference_end_times = numpy.array([subtitle.end_time for subtitle in reference])
18 |
19 |     for segment in hypothesis:
20 |         for word in segment.word_list:
21 |             assert word.approximate_word_time is not None, "Should have been set by SRTFileReader. Is a plain file used?"
22 |             reference_subtitle_index = numpy.searchsorted(reference_start_times, word.approximate_word_time) - 1
23 |
24 |             if reference_subtitle_index < 0:
25 |                 # Word is before the first subtitle, drop it.
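                # ('searchsorted - 1' selects the last subtitle whose start time precedes the word time,
                # so an index of -1 means the word occurs before all reference subtitles.)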
26 | continue 27 | 28 | if word.approximate_word_time < reference_end_times[reference_subtitle_index]: 29 | aligned_hypothesis_word_lists[reference_subtitle_index].append(word) 30 | 31 | aligned_hypothesis = [] 32 | 33 | for index, word_list in enumerate(aligned_hypothesis_word_lists): 34 | reference_subtitle = reference[index] 35 | subtitle = Subtitle( 36 | word_list=word_list, 37 | index=reference_subtitle.index, 38 | start_time=reference_subtitle.start_time, 39 | end_time=reference_subtitle.end_time) 40 | 41 | aligned_hypothesis.append(subtitle) 42 | 43 | return aligned_hypothesis 44 | -------------------------------------------------------------------------------- /suber/lib_levenshtein.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: MIT 2 | # Copyright (C) 2022 Max Bachmann 3 | 4 | # This code was taken from https://github.com/rapidfuzz/RapidFuzz/blob/main/src/rapidfuzz/distance/Levenshtein_py.py 5 | # https://github.com/rapidfuzz/RapidFuzz/blob/main/src/rapidfuzz/_common_py.py 6 | # and https://github.com/rapidfuzz/RapidFuzz/blob/main/src/rapidfuzz/distance/_initialize_py.py 7 | # and altered to recover the exact behavior of python-Levenshtein v0.12.0, which was our original Levenshtein 8 | # dependency but does not support Python > 3.10. In general, there are several possible alignments resulting in minimal 9 | # Levenshtein distance, and the choice of a particular alignment changed when (python-)Levenshtein started using the 10 | # rapidfuzz implementation in v0.18.0. Upgrading python-Levenshtein would therefore result in slightly different scores 11 | # for the "AS-" metrics on our end. For now, we want perfect backwards compatibility and therefore integrate our own 12 | # version of the Levenshtein code here. 
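#
# Hedged usage illustration (this mirrors the python-Levenshtein v0.12 API the module reproduces;
# the expected outputs below were worked out by hand, not taken from the original docs):
#
#     distance("kitten", "sitting")  # -> 3
#     opcodes("ab", "b")             # -> [('delete', 0, 1, 0, 0), ('equal', 1, 2, 0, 1)]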
13 | 14 | 15 | def _matrix(s1, s2): 16 | if not s1: 17 | return (len(s2), [], []) 18 | 19 | VP = (1 << len(s1)) - 1 20 | VN = 0 21 | currDist = len(s1) 22 | mask = 1 << (len(s1) - 1) 23 | 24 | block = {} 25 | block_get = block.get 26 | x = 1 27 | for ch1 in s1: 28 | block[ch1] = block_get(ch1, 0) | x 29 | x <<= 1 30 | 31 | matrix_VP = [] 32 | matrix_VN = [] 33 | for ch2 in s2: 34 | # Step 1: Computing D0 35 | PM_j = block_get(ch2, 0) 36 | X = PM_j 37 | D0 = (((X & VP) + VP) ^ VP) | X | VN 38 | # Step 2: Computing HP and HN 39 | HP = VN | ~(D0 | VP) 40 | HN = D0 & VP 41 | # Step 3: Computing the value D[m,j] 42 | currDist += (HP & mask) != 0 43 | currDist -= (HN & mask) != 0 44 | # Step 4: Computing Vp and VN 45 | HP = (HP << 1) | 1 46 | HN = HN << 1 47 | VP = HN | ~(D0 | HP) 48 | VN = HP & D0 49 | 50 | matrix_VP.append(VP) 51 | matrix_VN.append(VN) 52 | 53 | return (currDist, matrix_VP, matrix_VN) 54 | 55 | 56 | def distance(s1, s2): 57 | prefix_len, suffix_len = common_affix(s1, s2) 58 | s1 = s1[prefix_len : len(s1) - suffix_len] 59 | s2 = s2[prefix_len : len(s2) - suffix_len] 60 | dist, _, _ = _matrix(s1, s2) 61 | return dist 62 | 63 | 64 | def editops(s1, s2): 65 | """ 66 | Creates editops from the output of the bit-parallel rapidfuzz implementation above (edit distance matrix expressed 67 | as delta vectors), but makes the exact choices in case of ties as the original python-Levenshtein code: 68 | https://github.com/rapidfuzz/Levenshtein/blob/v0.17.0/src/Levenshtein-c/_levenshtein.c#L3205 69 | The rapidfuzz implementation prefers "insert", as the decisions can be made efficiently using the delta vectors 70 | in that case, see 71 | https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=339dec563792acb5bb2feffc53628b62bdc36329 72 | To prefer "replace" (among other differences) we need to re-calculate the actual elements of the distance matrix 73 | from the delta vectors, which kind of defeats the purpose as it makes the algorithm less efficient. But here we care 74 | more about backwards compatibility than efficiency. 
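    A hedged illustration: editops("abc", "abd") returns [('replace', 2, 2)], i.e. the shared prefix "ab"
    is stripped via common_affix() first, and the single edit is then reported at absolute position 2 in
    both strings.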
75 | """ 76 | prefix_len, suffix_len = common_affix(s1, s2) 77 | s1 = s1[prefix_len : len(s1) - suffix_len] 78 | s2 = s2[prefix_len : len(s2) - suffix_len] 79 | dist, VP, VN = _matrix(s1, s2) 80 | 81 | if dist == 0: 82 | return [] 83 | 84 | editop_list = [None] * dist 85 | col = len(s1) 86 | row = len(s2) 87 | direction = 0 88 | 89 | while row != 0 and col != 0: 90 | masked_VP = VP[row - 1] & ((1 << col) - 1) 91 | masked_VN = VN[row - 1] & ((1 << col) - 1) 92 | current_distance = masked_VP.bit_count() - masked_VN.bit_count() + row 93 | 94 | masked_VP = VP[row - 1] & ((1 << (col - 1)) - 1) 95 | masked_VN = VN[row - 1] & ((1 << (col - 1)) - 1) 96 | deletion_distance = masked_VP.bit_count() - masked_VN.bit_count() + row 97 | 98 | if row > 1: 99 | masked_VP = VP[row - 2] & ((1 << (col - 1)) - 1) 100 | masked_VN = VN[row - 2] & ((1 << (col - 1)) - 1) 101 | replace_distance = masked_VP.bit_count() - masked_VN.bit_count() + row - 1 102 | 103 | masked_VP = VP[row - 2] & ((1 << col) - 1) 104 | masked_VN = VN[row - 2] & ((1 << col) - 1) 105 | insertion_distance = masked_VP.bit_count() - masked_VN.bit_count() + row - 1 106 | 107 | else: 108 | replace_distance = col - 1 109 | insertion_distance = col 110 | 111 | if direction == -1 and current_distance == insertion_distance + 1: 112 | dist -= 1 113 | row -= 1 114 | direction = -1 115 | editop_list[dist] = ("insert", col + prefix_len, row + prefix_len) 116 | 117 | elif direction == 1 and current_distance == deletion_distance + 1: 118 | dist -= 1 119 | col -= 1 120 | direction = 1 121 | editop_list[dist] = ("delete", col + prefix_len, row + prefix_len) 122 | 123 | elif current_distance == replace_distance and s1[col - 1] == s2[row - 1]: 124 | col -= 1 125 | row -= 1 126 | direction = 0 127 | 128 | elif current_distance == replace_distance + 1: 129 | col -= 1 130 | row -= 1 131 | dist -= 1 132 | direction = 0 133 | editop_list[dist] = ("replace", col + prefix_len, row + prefix_len) 134 | 135 | elif direction == 0 and current_distance == insertion_distance + 1: 136 | dist -= 1 137 | row -= 1 138 | direction = -1 139 | editop_list[dist] = ("insert", col + prefix_len, row + prefix_len) 140 | 141 | elif direction == 0 and current_distance == deletion_distance + 1: 142 | dist -= 1 143 | col -= 1 144 | direction = 1 145 | editop_list[dist] = ("delete", col + prefix_len, row + prefix_len) 146 | 147 | else: 148 | assert False, "Bug while back-tracing cost matrix." 149 | 150 | assert dist >= 0, "Bug: distance differs from number of edit ops computed during back-tracing." 151 | 152 | while col != 0: 153 | dist -= 1 154 | col -= 1 155 | editop_list[dist] = ("delete", col + prefix_len, row + prefix_len) 156 | 157 | while row != 0: 158 | dist -= 1 159 | row -= 1 160 | editop_list[dist] = ("insert", col + prefix_len, row + prefix_len) 161 | 162 | assert dist == 0, "Bug: distance differs from number of edit ops computed during back-tracing." 
163 | return editop_list 164 | 165 | 166 | def opcodes(s1, s2): 167 | editops_ = editops(s1, s2) 168 | 169 | src_len = len(s1) 170 | dest_len = len(s2) 171 | 172 | blocks = [] 173 | src_pos = 0 174 | dest_pos = 0 175 | i = 0 176 | while i < len(editops_): 177 | if src_pos < editops_[i][1] or dest_pos < editops_[i][2]: 178 | blocks.append( 179 | ( 180 | "equal", 181 | src_pos, 182 | editops_[i][1], 183 | dest_pos, 184 | editops_[i][2], 185 | ) 186 | ) 187 | src_pos = editops_[i][1] 188 | dest_pos = editops_[i][2] 189 | 190 | src_begin = src_pos 191 | dest_begin = dest_pos 192 | tag = editops_[i][0] 193 | while ( 194 | i < len(editops_) 195 | and editops_[i][0] == tag 196 | and src_pos == editops_[i][1] 197 | and dest_pos == editops_[i][2] 198 | ): 199 | if tag == "replace": 200 | src_pos += 1 201 | dest_pos += 1 202 | elif tag == "insert": 203 | dest_pos += 1 204 | elif tag == "delete": 205 | src_pos += 1 206 | 207 | i += 1 208 | 209 | blocks.append((tag, src_begin, src_pos, dest_begin, dest_pos)) 210 | 211 | if src_pos < src_len or dest_pos < dest_len: 212 | blocks.append(("equal", src_pos, src_len, dest_pos, dest_len)) 213 | 214 | return blocks 215 | 216 | 217 | def common_prefix(s1: str, s2: str) -> int: 218 | prefix_len = 0 219 | for ch1, ch2 in zip(s1, s2): 220 | if ch1 != ch2: 221 | break 222 | 223 | prefix_len += 1 224 | 225 | return prefix_len 226 | 227 | 228 | def common_suffix(s1: str, s2: str) -> int: 229 | suffix_len = 0 230 | for ch1, ch2 in zip(reversed(s1), reversed(s2)): 231 | if ch1 != ch2: 232 | break 233 | 234 | suffix_len += 1 235 | 236 | return suffix_len 237 | 238 | 239 | def common_affix(s1: str, s2: str) -> tuple[int, int]: 240 | prefix_len = common_prefix(s1, s2) 241 | suffix_len = common_suffix(s1[prefix_len:], s2[prefix_len:]) 242 | return (prefix_len, suffix_len) -------------------------------------------------------------------------------- /suber/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apptek/SubER/d174f422ab29edff23af58954520480ad56138c7/suber/metrics/__init__.py -------------------------------------------------------------------------------- /suber/metrics/cer.py: -------------------------------------------------------------------------------- 1 | import string 2 | from typing import List 3 | 4 | from suber import lib_levenshtein 5 | from suber.data_types import Segment 6 | from suber.utilities import segment_to_string 7 | 8 | 9 | def calculate_character_error_rate(hypothesis: List[Segment], reference: List[Segment], metric="CER") -> float: 10 | assert len(hypothesis) == len(reference), ( 11 | "Number of hypothesis segments does not match reference, alignment step missing?") 12 | 13 | hypothesis_strings = [segment_to_string(segment) for segment in hypothesis] 14 | reference_strings = [segment_to_string(segment) for segment in reference] 15 | 16 | if metric != "CER-cased": 17 | remove_punctuation_table = str.maketrans('', '', string.punctuation) 18 | 19 | def normalize_string(string): 20 | string = string.translate(remove_punctuation_table) 21 | # Ellipsis is a common character in subtitles which is not included in string.punctuation. 
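            # Illustrative example of the full normalization below (not from the original code):
            #     "Don't stop…"  ->  "dont stop"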
22 | string = string.replace('…', '') 23 | string = string.lower() 24 | return string 25 | 26 | hypothesis_strings = [normalize_string(string) for string in hypothesis_strings] 27 | reference_strings = [normalize_string(string) for string in reference_strings] 28 | 29 | num_edits = 0 30 | num_reference_characters = 0 31 | for hypothesis_string, reference_string, in zip(hypothesis_strings, reference_strings): 32 | num_edits += lib_levenshtein.distance(hypothesis_string, reference_string) 33 | num_reference_characters += len(reference_string) 34 | 35 | if num_reference_characters: 36 | cer_score = num_edits / num_reference_characters 37 | else: 38 | cer_score = 1.0 if num_edits else 0.0 39 | 40 | return round(cer_score * 100, 3) 41 | -------------------------------------------------------------------------------- /suber/metrics/jiwer_interface.py: -------------------------------------------------------------------------------- 1 | import jiwer 2 | import functools 3 | from typing import List 4 | 5 | from sacrebleu.tokenizers.tokenizer_ter import TercomTokenizer 6 | 7 | from suber.data_types import Segment 8 | from suber.utilities import segment_to_string, get_segment_to_string_opts_from_metric 9 | 10 | 11 | def calculate_word_error_rate(hypothesis: List[Segment], reference: List[Segment], metric="WER", 12 | score_break_at_segment_end=True) -> float: 13 | 14 | assert len(hypothesis) == len(reference), ( 15 | "Number of hypothesis segments does not match reference, alignment step missing?") 16 | 17 | if metric == "WER-cased": 18 | transformations = jiwer.Compose([ 19 | # Note: the original release used no tokenization here. We find this change to have a minor positive effect 20 | # on correlation with post-edit effort (-0.657 vs. -0.650 in Table 1, row 2, "Combined" in our paper.) 21 | TercomTokenize(), 22 | jiwer.ReduceToListOfListOfWords(), 23 | ]) 24 | metric = "WER" 25 | 26 | else: 27 | transformations = jiwer.Compose([ 28 | jiwer.ToLowerCase(), 29 | jiwer.RemovePunctuation(), 30 | # Ellipsis is a common character in subtitles that older jiwer versions would not remove by default. 
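            # Illustrative example (not from the original code): with these transformations,
            # "Hello World" and "hello, world!" both reduce to the token list ["hello", "world"].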
31 | jiwer.RemoveSpecificWords(['…']), 32 | jiwer.ReduceToListOfListOfWords(), 33 | ]) 34 | 35 | include_breaks, mask_words, metric = get_segment_to_string_opts_from_metric(metric) 36 | assert metric == "WER" 37 | 38 | segment_to_string_ = functools.partial( 39 | segment_to_string, include_line_breaks=include_breaks, mask_all_words=mask_words, 40 | include_last_break=score_break_at_segment_end) 41 | 42 | hypothesis_strings = [segment_to_string_(segment) for segment in hypothesis] 43 | reference_strings = [segment_to_string_(segment) for segment in reference] 44 | 45 | wer_score = jiwer.wer( 46 | reference_strings, 47 | hypothesis_strings, 48 | reference_transform=transformations, 49 | hypothesis_transform=transformations) 50 | 51 | return round(wer_score * 100, 3) 52 | 53 | 54 | class TercomTokenize(jiwer.AbstractTransform): 55 | def __init__(self): 56 | self.tokenizer = TercomTokenizer(normalized=True, no_punct=False, case_sensitive=True) 57 | 58 | def process_string(self, s: str): 59 | return self.tokenizer(s) 60 | -------------------------------------------------------------------------------- /suber/metrics/length_ratio.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from suber.data_types import Segment 3 | 4 | from sacrebleu.tokenizers.tokenizer_13a import Tokenizer13a 5 | 6 | 7 | def calculate_length_ratio(hypothesis: List[Segment], reference: List[Segment]) -> float: 8 | all_hypothesis_words = [word.string for segment in hypothesis for word in segment.word_list] 9 | all_reference_words = [word.string for segment in reference for word in segment.word_list] 10 | 11 | full_hypothesis_string = " ".join(all_hypothesis_words) 12 | full_reference_string = " ".join(all_reference_words) 13 | 14 | # Default tokenizer used for BLEU calculation in SacreBLEU, so length ratio we calculate here should correspond 15 | # to the "ratio" printed by SacreBLEU. 16 | tokenizer = Tokenizer13a() 17 | 18 | num_tokens_hypothesis = len(tokenizer(full_hypothesis_string).split()) 19 | num_tokens_reference = len(tokenizer(full_reference_string).split()) 20 | 21 | length_ratio = num_tokens_hypothesis / num_tokens_reference if num_tokens_reference else 0.0 22 | 23 | return round(length_ratio * 100, 3) 24 | -------------------------------------------------------------------------------- /suber/metrics/lib_ter.py: -------------------------------------------------------------------------------- 1 | """This module implements various utility functions for the TER metric.""" 2 | 3 | # This file was copied from sacrebleu and modified: 4 | # https://github.com/mjpost/sacrebleu/blob/078c440168c6adc89ba75fe6d63f0d922d42bcfe/sacrebleu/metrics/lib_ter.py 5 | 6 | # Copyright 2020 Memsource 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 
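# A rough usage sketch (hypothetical helper and timings, not part of the original module): the entry
# point is translation_edit_rate() below, which greedily shifts hypothesis word blocks as long as a
# shift reduces the edit distance, then adds the remaining edit distance:
#
#     from suber.data_types import TimedWord, LineBreak
#     from suber.metrics.lib_ter import translation_edit_rate
#
#     def words(text):  # helper assumed for this sketch
#         return [TimedWord(string=s, line_break=LineBreak.NONE, subtitle_start_time=0.0,
#                           subtitle_end_time=2.0, approximate_word_time=1.0) for s in text.split()]
#
#     translation_edit_rate(words("b a"), words("a b"))  # one shift, no further edits -> (1, 2)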
19 | 20 | 21 | import math 22 | from typing import List, Tuple, Dict 23 | 24 | from suber.data_types import TimedWord 25 | from suber.constants import END_OF_LINE_SYMBOL, END_OF_BLOCK_SYMBOL 26 | from suber.metrics.suber_statistics import SubERStatisticsCollector 27 | 28 | 29 | _COST_INS = 1 30 | _COST_DEL = 1 31 | _COST_SUB = 1 32 | 33 | # Tercom-inspired limits 34 | _MAX_SHIFT_SIZE = 10 35 | _MAX_SHIFT_DIST = 50 36 | _BEAM_WIDTH = 100 37 | 38 | # Our own limits 39 | _MAX_CACHE_SIZE = 10000 40 | _MAX_SHIFT_CANDIDATES = 1000 41 | _INT_INFINITY = int(1e16) 42 | 43 | _OP_INS = 'i' 44 | _OP_DEL = 'd' 45 | _OP_NOP = ' ' 46 | _OP_SUB = 's' 47 | _OP_UNDEF = 'x' 48 | 49 | _FLIP_OPS = str.maketrans(_OP_INS + _OP_DEL, _OP_DEL + _OP_INS) 50 | 51 | 52 | def translation_edit_rate(words_hyp: List[TimedWord], words_ref: List[TimedWord], 53 | statistics_collector: SubERStatisticsCollector = None) -> Tuple[int, int]: 54 | """Calculate the translation edit rate. 55 | 56 | :param words_hyp: Tokenized translation hypothesis. 57 | :param words_ref: Tokenized reference translation. 58 | :return: tuple (number of edits, length) 59 | """ 60 | n_words_ref = len(words_ref) 61 | n_words_hyp = len(words_hyp) 62 | if n_words_ref == 0: 63 | trace = _flip_trace(_OP_DEL * n_words_hyp) # Switch to reference to hypothesis direction, see comment below. 64 | 65 | if statistics_collector: 66 | statistics_collector.add_data(trace=trace, words_ref=words_ref, words_hyp_shifted=words_hyp, num_shifts=0) 67 | 68 | # special treatment of empty refs 69 | return n_words_hyp, 0 70 | 71 | cached_ed = BeamEditDistance(words_ref) 72 | shifts = 0 73 | 74 | input_words = words_hyp 75 | checked_candidates = 0 76 | while True: 77 | # do shifts until they stop reducing the edit distance 78 | delta, new_input_words, checked_candidates = _shift( 79 | input_words, words_ref, cached_ed, checked_candidates) 80 | 81 | if checked_candidates >= _MAX_SHIFT_CANDIDATES: 82 | break 83 | 84 | if delta <= 0: 85 | break 86 | shifts += 1 87 | input_words = new_input_words 88 | 89 | edit_distance, trace = cached_ed(input_words) 90 | total_edits = shifts + edit_distance 91 | 92 | if statistics_collector: 93 | statistics_collector.add_data( 94 | # In the SubER code we always use the reference to hypothesis direction, i.e. we call an additional word 95 | # in the hypothesis an insertion, a missing word in the hypothesis a deletion. 96 | trace=_flip_trace(trace), 97 | words_ref=words_ref, 98 | words_hyp_shifted=input_words, 99 | num_shifts=shifts, 100 | ) 101 | 102 | return total_edits, n_words_ref 103 | 104 | 105 | def _is_allowed_word_alignment(word1: TimedWord, word2: TimedWord) -> bool: 106 | """ 107 | Returns whether SubER definition allows to align the two words. This is the case when they are part of subtitles 108 | that overlap in time. In addition, break tokens must not be aligned with real words. 109 | """ 110 | if ((word1.string in [END_OF_LINE_SYMBOL, END_OF_BLOCK_SYMBOL]) 111 | != (word2.string in [END_OF_LINE_SYMBOL, END_OF_BLOCK_SYMBOL])): 112 | return False 113 | 114 | return ((word1.subtitle_start_time < word2.subtitle_end_time) 115 | == (word2.subtitle_start_time < word1.subtitle_end_time)) 116 | 117 | 118 | def _is_word_match(word1: TimedWord, word2: TimedWord) -> bool: 119 | """ 120 | Returns whether SubER counts the two words as a match, meaning no edit operation needed. 
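
    For example (illustrative): two identical strings from subtitles spanning 1.0-2.0s and 1.5-2.5s
    count as a match, while identical strings from subtitles that do not overlap in time do not.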
121 | """ 122 | if word1.string != word2.string: 123 | return False 124 | 125 | return _is_allowed_word_alignment(word1, word2) 126 | 127 | 128 | def _shift(words_h: List[TimedWord], words_r: List[TimedWord], cached_ed, 129 | checked_candidates: int) -> Tuple[int, List[TimedWord], int]: 130 | """Attempt to shift words in hypothesis to match reference. 131 | 132 | Returns the shift that reduces the edit distance the most. 133 | 134 | Note that the filtering of possible shifts and shift selection are heavily 135 | based on somewhat arbitrary heuristics. The code here follows as closely 136 | as possible the logic in Tercom, not always justifying the particular design 137 | choices. 138 | 139 | :param words_h: Hypothesis. 140 | :param words_r: Reference. 141 | :param cached_ed: Cached edit distance. 142 | :param checked_candidates: Number of shift candidates that were already 143 | evaluated. 144 | :return: (score, shifted_words, checked_candidates). Best shift and updated 145 | number of evaluated shift candidates. 146 | """ 147 | pre_score, inv_trace = cached_ed(words_h) 148 | 149 | # to get alignment, we pretend we are rewriting reference into hypothesis, 150 | # so we need to flip the trace of edit operations 151 | trace = _flip_trace(inv_trace) 152 | align, ref_err, hyp_err = trace_to_alignment(trace) 153 | 154 | best = None 155 | 156 | for start_h, start_r, length in _find_shifted_pairs(words_h, words_r): 157 | # don't do the shift unless both the hypothesis was wrong and the 158 | # reference doesn't match hypothesis at the target position 159 | if sum(hyp_err[start_h: start_h + length]) == 0: 160 | continue 161 | 162 | if sum(ref_err[start_r: start_r + length]) == 0: 163 | continue 164 | 165 | # don't try to shift within the subsequence 166 | if start_h <= align[start_r] < start_h + length: 167 | continue 168 | 169 | prev_idx = -1 170 | for offset in range(-1, length): 171 | if start_r + offset == -1: 172 | idx = 0 # insert before the beginning 173 | elif start_r + offset in align: 174 | # Unlike Tercom which inserts *after* the index, we insert 175 | # *before* the index. 176 | idx = align[start_r + offset] + 1 177 | else: 178 | break # offset is out of bounds => aims past reference 179 | 180 | if idx == prev_idx: 181 | continue # skip idx if already tried 182 | 183 | prev_idx = idx 184 | 185 | shifted_words = _perform_shift(words_h, start_h, length, idx) 186 | assert(len(shifted_words) == len(words_h)) 187 | 188 | # Elements of the tuple are designed to replicate Tercom ranking 189 | # of shifts: 190 | candidate = ( 191 | pre_score - cached_ed(shifted_words)[0], # highest score first 192 | length, # then, longest match first 193 | -start_h, # then, earliest match first 194 | -idx, # then, earliest target position first 195 | shifted_words, 196 | ) 197 | 198 | checked_candidates += 1 199 | 200 | if not best or candidate > best: 201 | best = candidate 202 | 203 | if checked_candidates >= _MAX_SHIFT_CANDIDATES: 204 | break 205 | 206 | if not best: 207 | return 0, words_h, checked_candidates 208 | else: 209 | best_score, _, _, _, shifted_words = best 210 | return best_score, shifted_words, checked_candidates 211 | 212 | 213 | def _perform_shift(words: List[TimedWord], start: int, length: int, target: int) -> List[TimedWord]: 214 | """Perform a shift in `words` from `start` to `target`. 215 | 216 | :param words: Words to shift. 217 | :param start: Where from. 218 | :param length: How many words. 219 | :param target: Where to. 220 | :return: Shifted words. 
221 | """ 222 | if target < start: 223 | # shift before previous position 224 | return words[:target] + words[start: start + length] \ 225 | + words[target: start] + words[start + length:] 226 | elif target > start + length: 227 | # shift after previous position 228 | return words[:start] + words[start + length: target] \ 229 | + words[start: start + length] + words[target:] 230 | else: 231 | # shift within the shifted string 232 | return words[:start] + words[start + length: length + target] \ 233 | + words[start: start + length] + words[length + target:] 234 | 235 | 236 | def _find_shifted_pairs(words_h: List[TimedWord], words_r: List[TimedWord]): 237 | """Find matching word sub-sequences in two lists of words. 238 | 239 | Ignores sub-sequences starting at the same position. 240 | 241 | :param words_h: First word list. 242 | :param words_r: Second word list. 243 | :return: Yields tuples of (h_start, r_start, length) such that: 244 | words_h[h_start:h_start+length] = words_r[r_start:r_start+length] 245 | """ 246 | n_words_h = len(words_h) 247 | n_words_r = len(words_r) 248 | for start_h in range(n_words_h): 249 | for start_r in range(n_words_r): 250 | # this is slightly different from what tercom does but this should 251 | # really only kick in in degenerate cases 252 | if abs(start_r - start_h) > _MAX_SHIFT_DIST: 253 | continue 254 | 255 | length = 0 256 | while _is_word_match(words_h[start_h + length], words_r[start_r + length]) and length < _MAX_SHIFT_SIZE: 257 | length += 1 258 | 259 | yield start_h, start_r, length 260 | 261 | # If one sequence is consumed, stop processing 262 | if n_words_h == start_h + length or n_words_r == start_r + length: 263 | break 264 | 265 | 266 | def _flip_trace(trace): 267 | """Flip the trace of edit operations. 268 | 269 | Instead of rewriting a->b, get a recipe for rewriting b->a. 270 | 271 | Simply flips insertions and deletions. 272 | """ 273 | return trace.translate(_FLIP_OPS) 274 | 275 | 276 | def trace_to_alignment(trace: str) -> Tuple[Dict, List, List]: 277 | """Transform trace of edit operations into an alignment of the sequences. 278 | 279 | :param trace: Trace of edit operations (' '=no change or 's'/'i'/'d'). 280 | :return: Alignment, error positions in reference, error positions in hypothesis. 281 | """ 282 | pos_hyp = -1 283 | pos_ref = -1 284 | hyp_err = [] 285 | ref_err = [] 286 | align = {} 287 | 288 | # we are rewriting a into b 289 | for op in trace: 290 | if op == _OP_NOP: 291 | pos_hyp += 1 292 | pos_ref += 1 293 | align[pos_ref] = pos_hyp 294 | hyp_err.append(0) 295 | ref_err.append(0) 296 | elif op == _OP_SUB: 297 | pos_hyp += 1 298 | pos_ref += 1 299 | align[pos_ref] = pos_hyp 300 | hyp_err.append(1) 301 | ref_err.append(1) 302 | elif op == _OP_INS: 303 | pos_hyp += 1 304 | hyp_err.append(1) 305 | elif op == _OP_DEL: 306 | pos_ref += 1 307 | align[pos_ref] = pos_hyp 308 | ref_err.append(1) 309 | else: 310 | raise Exception(f"unknown operation {op!r}") 311 | 312 | return align, ref_err, hyp_err 313 | 314 | 315 | class BeamEditDistance: 316 | """Edit distance with several features required for TER calculation. 317 | 318 | * internal cache 319 | * "beam" search 320 | * tracking of edit operations 321 | 322 | The internal self._cache works like this: 323 | 324 | Keys are words of the hypothesis. 
Values are tuples (next_node, row) where: 325 | 326 | * next_node is the cache for the next word in the sequence 327 | * row is the stored row of the edit distance matrix 328 | 329 | Effectively, caching allows to skip several rows in the edit distance 330 | matrix calculation and instead, to initialize the computation with the last 331 | matching matrix row. 332 | 333 | Beam search, as implemented here, only explores a fixed-size sub-row of 334 | candidates around the matrix diagonal (more precisely, it's a 335 | "pseudo"-diagonal since we take the ratio of sequence lengths into account). 336 | 337 | Tracking allows to reconstruct the optimal sequence of edit operations. 338 | 339 | :param words_ref: A list of reference tokens. 340 | """ 341 | def __init__(self, words_ref: List[TimedWord]): 342 | """`BeamEditDistance` initializer.""" 343 | self._words_ref = words_ref 344 | self._n_words_ref = len(self._words_ref) 345 | 346 | # first row corresponds to insertion operations of the reference, 347 | # so we do 1 edit operation per reference word 348 | self._initial_row = [(i * _COST_INS, _OP_INS) 349 | for i in range(self._n_words_ref + 1)] 350 | 351 | self._cache = {} # type: Dict[str, Tuple] 352 | self._cache_size = 0 353 | 354 | # Precomputed empty matrix row. Contains infinities so that beam search 355 | # avoids using the uninitialized cells. 356 | self._empty_row = [(_INT_INFINITY, _OP_UNDEF)] * (self._n_words_ref + 1) 357 | 358 | def __call__(self, words_hyp: List[TimedWord]) -> Tuple[int, str]: 359 | """Calculate edit distance between self._words_ref and the hypothesis. 360 | 361 | Uses cache to skip some of the computation. 362 | 363 | :param words_hyp: Words in translation hypothesis. 364 | :return: Edit distance score. 365 | """ 366 | 367 | # skip initial words in the hypothesis for which we already know the 368 | # edit distance 369 | start_position, dist = self._find_cache(words_hyp) 370 | 371 | # calculate the rest of the edit distance matrix 372 | edit_distance, newly_created_matrix, trace = self._edit_distance( 373 | words_hyp, start_position, dist) 374 | 375 | # update our cache with the newly calculated rows 376 | self._add_cache(words_hyp, newly_created_matrix) 377 | 378 | return edit_distance, trace 379 | 380 | def _edit_distance(self, words_h: List[TimedWord], start_h: int, 381 | cache: List[List[Tuple[int, TimedWord]]]) -> Tuple[int, List, TimedWord]: 382 | """Actual edit distance calculation. 383 | 384 | Can be initialized with the last cached row and a start position in 385 | the hypothesis that it corresponds to. 386 | 387 | :param words_h: Words in translation hypothesis. 388 | :param start_h: Position from which to start the calculation. 389 | (This is zero if no cache match was found.) 390 | :param cache: Precomputed rows corresponding to edit distance matrix 391 | before `start_h`. 392 | :return: Edit distance value, newly computed rows to update the 393 | cache, trace. 
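
        For example (illustrative): if the first two hypothesis words were found in the cache,
        `start_h` is 2 and `cache` holds three rows (the initial row plus one per cached word).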
394 | """ 395 | 396 | n_words_h = len(words_h) 397 | 398 | # initialize the rest of the matrix with infinite edit distances 399 | rest_empty = [list(self._empty_row) 400 | for _ in range(n_words_h - start_h)] 401 | 402 | dist = cache + rest_empty 403 | 404 | assert len(dist) == n_words_h + 1 405 | 406 | length_ratio = self._n_words_ref / n_words_h if words_h else 1 407 | 408 | # in some crazy sentences, the difference in length is so large that 409 | # we may end up with zero overlap with previous row 410 | if _BEAM_WIDTH < length_ratio / 2: 411 | beam_width = math.ceil(length_ratio / 2 + _BEAM_WIDTH) 412 | else: 413 | beam_width = _BEAM_WIDTH 414 | 415 | # calculate the Levenshtein distance 416 | for i in range(start_h + 1, n_words_h + 1): 417 | pseudo_diag = math.floor(i * length_ratio) 418 | min_j = max(0, pseudo_diag - beam_width) 419 | max_j = min(self._n_words_ref + 1, pseudo_diag + beam_width) 420 | 421 | if i == n_words_h: 422 | max_j = self._n_words_ref + 1 423 | 424 | for j in range(min_j, max_j): 425 | if j == 0: 426 | dist[i][j] = (dist[i - 1][j][0] + _COST_DEL, _OP_DEL) 427 | else: 428 | if _is_word_match(words_h[i - 1], self._words_ref[j - 1]): 429 | cost_sub = 0 430 | op_sub = _OP_NOP 431 | else: 432 | if _is_allowed_word_alignment(words_h[i - 1], self._words_ref[j - 1]): 433 | cost_sub = _COST_SUB 434 | else: 435 | # No substitution allowed if words are not time-aligned. 436 | cost_sub = _INT_INFINITY 437 | op_sub = _OP_SUB 438 | 439 | # Tercom prefers no-op/sub, then insertion, then deletion. 440 | # But since we flip the trace and compute the alignment from 441 | # the inverse, we need to swap order of insertion and 442 | # deletion in the preference. 443 | ops = ( 444 | (dist[i - 1][j - 1][0] + cost_sub, op_sub), 445 | (dist[i - 1][j][0] + _COST_DEL, _OP_DEL), 446 | (dist[i][j - 1][0] + _COST_INS, _OP_INS), 447 | ) 448 | 449 | for op_cost, op_name in ops: 450 | if dist[i][j][0] > op_cost: 451 | dist[i][j] = op_cost, op_name 452 | 453 | # get the trace 454 | trace = "" 455 | i = n_words_h 456 | j = self._n_words_ref 457 | 458 | while i > 0 or j > 0: 459 | op = dist[i][j][1] 460 | trace = op + trace 461 | if op in (_OP_SUB, _OP_NOP): 462 | i -= 1 463 | j -= 1 464 | elif op == _OP_INS: 465 | j -= 1 466 | elif op == _OP_DEL: 467 | i -= 1 468 | else: 469 | raise Exception(f"unknown operation {op!r}") 470 | 471 | return dist[-1][-1][0], dist[len(cache):], trace 472 | 473 | def _add_cache(self, words_hyp: List[TimedWord], mat: List[List[Tuple]]): 474 | """Add newly computed rows to cache. 475 | 476 | Since edit distance is only calculated on the hypothesis suffix that 477 | was not in cache, the number of rows in `mat` may be shorter than 478 | hypothesis length. In that case, we skip over these initial words. 479 | 480 | :param words_hyp: Hypothesis words. 481 | :param mat: Edit distance matrix rows for each position. 
482 | """ 483 | if self._cache_size >= _MAX_CACHE_SIZE: 484 | return 485 | 486 | node = self._cache 487 | 488 | n_mat = len(mat) 489 | 490 | # how many initial words to skip 491 | skip_num = len(words_hyp) - n_mat 492 | 493 | # jump through the cache to the current position 494 | for i in range(skip_num): 495 | node = node[words_hyp[i]][0] 496 | 497 | assert len(words_hyp[skip_num:]) == n_mat 498 | 499 | # update cache with newly computed rows 500 | for word, row in zip(words_hyp[skip_num:], mat): 501 | if word not in node: 502 | node[word] = ({}, tuple(row)) 503 | self._cache_size += 1 504 | value = node[word] 505 | node = value[0] 506 | 507 | def _find_cache(self, words_hyp: List[TimedWord]) -> Tuple[int, List[List]]: 508 | """Find the already computed rows of the edit distance matrix in cache. 509 | 510 | Returns a partially computed edit distance matrix. 511 | 512 | :param words_hyp: Translation hypothesis. 513 | :return: Tuple (start position, dist). 514 | """ 515 | node = self._cache 516 | start_position = 0 517 | dist = [self._initial_row] 518 | for word in words_hyp: 519 | if word in node: 520 | start_position += 1 521 | node, row = node[word] 522 | dist.append(row) 523 | else: 524 | break 525 | 526 | return start_position, dist 527 | -------------------------------------------------------------------------------- /suber/metrics/sacrebleu_interface.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from typing import List 3 | 4 | from sacrebleu.metrics import BLEU, TER, CHRF 5 | 6 | from suber.data_types import Segment 7 | from suber.utilities import segment_to_string, get_segment_to_string_opts_from_metric 8 | 9 | 10 | def calculate_sacrebleu_metric(hypothesis: List[Segment], reference: List[Segment], 11 | metric="BLEU", score_break_at_segment_end=True) -> float: 12 | 13 | assert len(hypothesis) == len(reference), ( 14 | "Number of hypothesis segments does not match reference, alignment step missing?") 15 | 16 | include_breaks, mask_words, metric = get_segment_to_string_opts_from_metric(metric) 17 | 18 | if metric == "BLEU": 19 | sacrebleu_metric = BLEU() 20 | elif metric == "TER": 21 | sacrebleu_metric = TER() 22 | elif metric == "chrF": 23 | sacrebleu_metric = CHRF() 24 | else: 25 | raise ValueError(f"Unsupported sacrebleu metric '{metric}'.") 26 | 27 | # Sacrebleu currently does not allow empty references, just skip empty reference segments as a workaround. 28 | if not all(segment.word_list for segment in reference): 29 | empty_reference_indices = [index for index, segment in enumerate(reference) if not segment.word_list] 30 | reference = [segment for index, segment in enumerate(reference) if index not in empty_reference_indices] 31 | hypothesis = [segment for index, segment in enumerate(hypothesis) if index not in empty_reference_indices] 32 | 33 | segment_to_string_ = functools.partial( 34 | segment_to_string, include_line_breaks=include_breaks, mask_all_words=mask_words, 35 | include_last_break=score_break_at_segment_end) 36 | 37 | hypothesis_strings = [segment_to_string_(segment) for segment in hypothesis] 38 | reference_strings = [[segment_to_string_(segment) for segment in reference]] # sacrebleu expects nested list 39 | 40 | if include_breaks: 41 | # BLEU tokenizer would split "" into "< eol >". 
42 |         hypothesis_strings = [string.replace("<eol>", "eol").replace("<eob>", "eob") for string in hypothesis_strings]
43 |         reference_strings[0] = [
44 |             string.replace("<eol>", "eol").replace("<eob>", "eob") for string in reference_strings[0]]
45 | 
46 |     sacrebleu_score = sacrebleu_metric.corpus_score(hypotheses=hypothesis_strings, references=reference_strings)
47 | 
48 |     return round(sacrebleu_score.score, 3)
49 | 
--------------------------------------------------------------------------------
/suber/metrics/suber.py:
--------------------------------------------------------------------------------
1 | import string
2 | from typing import List
3 | 
4 | from suber.data_types import Subtitle, TimedWord, LineBreak
5 | from suber.constants import END_OF_BLOCK_SYMBOL, END_OF_LINE_SYMBOL
6 | from suber.metrics import lib_ter
7 | from suber.metrics.suber_statistics import SubERStatisticsCollector
8 | 
9 | from sacrebleu.tokenizers.tokenizer_ter import TercomTokenizer  # only used for "SubER-cased"
10 | 
11 | 
12 | def calculate_SubER(hypothesis: List[Subtitle], reference: List[Subtitle], metric="SubER",
13 |                     statistics_collector: SubERStatisticsCollector = None) -> float:
14 |     """
15 |     Main function to calculate the SubER score. It is computed on normalized text, which means case-insensitive and
16 |     without taking punctuation into account, as we observed higher correlation with human judgements and post-edit
17 |     effort in this setting. You can set the 'metric' parameter to "SubER-cased" to calculate a score on cased and
18 |     punctuated text nevertheless. In this case punctuation will be treated as separate words by using a tokenizer.
19 |     We use a modified version of 'lib_ter.py' from sacrebleu for the underlying TER implementation. We altered the
20 |     algorithm by adding a time-overlap condition for word alignments and by disallowing word alignments between real
21 |     words and break tokens.
22 |     """
23 |     assert metric in ["SubER", "SubER-cased"]
24 |     normalize = (metric == "SubER")
25 | 
26 |     total_num_edits = 0
27 |     total_reference_length = 0
28 | 
29 |     for part in _get_independent_parts(hypothesis, reference):
30 |         hypothesis_part, reference_part = part
31 | 
32 |         num_edits, reference_length = _calculate_num_edits_for_part(
33 |             hypothesis_part, reference_part, normalize=normalize, statistics_collector=statistics_collector)
34 | 
35 |         total_num_edits += num_edits
36 |         total_reference_length += reference_length
37 | 
38 |     if total_reference_length:
39 |         SubER_score = (total_num_edits / total_reference_length) * 100
40 | 
41 |     elif not total_num_edits:
42 |         SubER_score = 0.0
43 |     else:
44 |         SubER_score = 100.0
45 | 
46 |     return round(SubER_score, 3)
47 | 
48 | 
49 | def _calculate_num_edits_for_part(hypothesis_part: List[Subtitle], reference_part: List[Subtitle], normalize=True,
50 |                                   statistics_collector: SubERStatisticsCollector = None):
51 |     """
52 |     Returns number of edits (word or break edits and shifts) and the total number of reference tokens (words + breaks)
53 |     for the current part.
54 |     """
55 |     all_hypothesis_words = [word for segment in hypothesis_part for word in segment.word_list]
56 |     all_reference_words = [word for segment in reference_part for word in segment.word_list]
57 | 
58 |     if normalize:
59 |         # Although casing and punctuation are important aspects of subtitle quality, we observe higher correlation with
60 |         # human post-edit effort when normalizing the words.
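        # Illustrative example (not from the original code): "Hello," -> "hello"; a token that is
        # pure punctuation, e.g. "-", survives normalization unchanged (see _normalize_words below).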
61 | all_hypothesis_words = _normalize_words(all_hypothesis_words) 62 | all_reference_words = _normalize_words(all_reference_words) 63 | else: 64 | # When not normalizing punctuation symbols are kept. We treat them as separate tokens by splitting them off 65 | # the words using sacrebleu's TercomTokenizer. 66 | all_hypothesis_words = _tokenize_words(all_hypothesis_words) 67 | all_reference_words = _tokenize_words(all_reference_words) 68 | 69 | all_hypothesis_words = _add_breaks_as_words(all_hypothesis_words) 70 | all_reference_words = _add_breaks_as_words(all_reference_words) 71 | 72 | num_edits, reference_length = lib_ter.translation_edit_rate( 73 | all_hypothesis_words, all_reference_words, statistics_collector) 74 | 75 | assert reference_length == len(all_reference_words) 76 | 77 | return num_edits, reference_length 78 | 79 | 80 | def _add_breaks_as_words(words: List[TimedWord]) -> List[TimedWord]: 81 | """ 82 | Converts breaks from being an attribute of the previous Word to being a separate Word in the list. Needed because 83 | TER algorithm should handle breaks as normal tokens. 84 | """ 85 | output_words = [] 86 | for word in words: 87 | output_words.append( 88 | TimedWord( 89 | string=word.string, 90 | line_break=LineBreak.NONE, 91 | subtitle_start_time=word.subtitle_start_time, 92 | subtitle_end_time=word.subtitle_end_time, 93 | approximate_word_time=word.approximate_word_time)) 94 | 95 | if word.line_break is not LineBreak.NONE: 96 | output_words.append( 97 | TimedWord( 98 | string=END_OF_LINE_SYMBOL if word.line_break is LineBreak.END_OF_LINE else END_OF_BLOCK_SYMBOL, 99 | line_break=LineBreak.NONE, 100 | subtitle_start_time=word.subtitle_start_time, 101 | subtitle_end_time=word.subtitle_end_time, 102 | approximate_word_time=word.approximate_word_time)) 103 | 104 | return output_words 105 | 106 | 107 | remove_punctuation_table = str.maketrans('', '', string.punctuation) 108 | 109 | 110 | def _normalize_words(words: List[TimedWord]) -> List[TimedWord]: 111 | """ 112 | Lower-cases Words and removes punctuation. 113 | """ 114 | output_words = [] 115 | for word in words: 116 | normalized_string = word.string.lower() 117 | normalized_string_without_punctuation = normalized_string.translate(remove_punctuation_table) 118 | normalized_string_without_punctuation = normalized_string_without_punctuation.replace('…', '') 119 | 120 | if normalized_string_without_punctuation: # keep tokens that are purely punctuation 121 | normalized_string = normalized_string_without_punctuation 122 | 123 | output_words.append( 124 | TimedWord( 125 | string=normalized_string, 126 | line_break=word.line_break, 127 | subtitle_start_time=word.subtitle_start_time, 128 | subtitle_end_time=word.subtitle_end_time, 129 | approximate_word_time=word.approximate_word_time)) 130 | 131 | return output_words 132 | 133 | 134 | _tokenizer = None # created if needed in _tokenize_words(), has to be cached... 135 | 136 | 137 | def _tokenize_words(words: List[TimedWord]) -> List[TimedWord]: 138 | """ 139 | Not used for the main SubER metric, only for the "SubER-cased" variant. Applies sacrebleu's TercomTokenizer to all 140 | words in the input, which will create a new list of words containing punctuation symbols as separate elements. 
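
    Example (illustrative): a TimedWord "Hello," becomes two TimedWords "Hello" and ",", and a line
    break attached to the original word is kept on the last resulting token.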
141 | """ 142 | global _tokenizer 143 | if not _tokenizer: 144 | _tokenizer = TercomTokenizer(normalized=True, no_punct=False, case_sensitive=True) 145 | 146 | output_words = [] 147 | for word in words: 148 | tokenized_word_string = _tokenizer(word.string) 149 | tokens = tokenized_word_string.split() 150 | 151 | if len(tokens) == 1: 152 | assert tokenized_word_string == word.string 153 | output_words.append(word) 154 | continue 155 | 156 | for token_index, token in enumerate(tokens): 157 | output_words.append( 158 | TimedWord( 159 | string=token, 160 | # Keep line break after the original token, no line breaks within the original token. 161 | line_break=word.line_break if token_index == len(tokens) - 1 else LineBreak.NONE, 162 | subtitle_start_time=word.subtitle_start_time, 163 | subtitle_end_time=word.subtitle_end_time, 164 | approximate_word_time=word.approximate_word_time)) 165 | 166 | return output_words 167 | 168 | 169 | def _get_independent_parts(hypothesis: List[Subtitle], reference: List[Subtitle]): 170 | """ 171 | SubER by definition does not require parallel hypothesis-reference segments. We nevertheless split the subtitle file 172 | content into parts at positions in time where there is no subtitle in both hypothesis and reference. This makes 173 | calculation more efficient as Levenshtein distances are computed on shorter sequences, while not changing the 174 | metric score. 175 | 176 | Note, that in the worst case there are no such split points. In practice, this is unrealistic and subtitle files 177 | are usually limited to a few hours of speech, such that the current SubER calculation should be efficient enough. 178 | 179 | This function yields Tuple[List[Subtitle],List[Subtitle]] containing the hypothesis and reference subtitles for each 180 | part. 181 | """ 182 | hypothesis_part = [] 183 | reference_part = [] 184 | 185 | # We sweep the time axis from low to high and handle hypothesis and reference subtitles as soon as we reach them. 186 | hypothesis_subtitle_index = 0 # index of hypothesis subtitle to handle next 187 | reference_subtitle_index = 0 # index of reference subtitle to handle next 188 | latest_observed_time = - float('inf') # highest time observed so far (end time of a previously handled subtitle) 189 | 190 | while hypothesis_subtitle_index < len(hypothesis) or reference_subtitle_index < len(reference): 191 | if (hypothesis_subtitle_index < len(hypothesis) and ( 192 | reference_subtitle_index == len(reference) or 193 | hypothesis[hypothesis_subtitle_index].start_time < reference[reference_subtitle_index].start_time)): 194 | # We found the next subtitle on the time axis, it is from the hypothesis. 195 | 196 | if ((hypothesis_part or reference_part) 197 | and hypothesis[hypothesis_subtitle_index].start_time >= latest_observed_time): 198 | # The subtitle starts after the latest observed time, meaning there is a gap where no subtitle exists. 199 | # This concludes the current part, yield it. 200 | yield (hypothesis_part, reference_part) 201 | hypothesis_part, reference_part = [], [] 202 | 203 | hypothesis_part.append(hypothesis[hypothesis_subtitle_index]) 204 | latest_observed_time = max(latest_observed_time, hypothesis[hypothesis_subtitle_index].end_time) 205 | hypothesis_subtitle_index += 1 206 | 207 | else: # Next subtitle to handle is from the reference. 
208 | if ((hypothesis_part or reference_part) 209 | and reference[reference_subtitle_index].start_time >= latest_observed_time): 210 | # The subtitle starts after the latest observed time, meaning there is a gap where no subtitle exists. 211 | # This concludes the current part, yield it. 212 | yield (hypothesis_part, reference_part) 213 | hypothesis_part, reference_part = [], [] 214 | 215 | reference_part.append(reference[reference_subtitle_index]) 216 | latest_observed_time = max(latest_observed_time, reference[reference_subtitle_index].end_time) 217 | reference_subtitle_index += 1 218 | 219 | assert hypothesis_subtitle_index == len(hypothesis) and reference_subtitle_index == len(reference) 220 | if hypothesis_part or reference_part: 221 | yield (hypothesis_part, reference_part) 222 | -------------------------------------------------------------------------------- /suber/metrics/suber_statistics.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Any 2 | from collections import OrderedDict 3 | 4 | from suber.data_types import Word 5 | from suber.constants import END_OF_LINE_SYMBOL, END_OF_BLOCK_SYMBOL 6 | 7 | 8 | class SubERStatisticsCollector: 9 | """ 10 | Collects number of different SubER edit operations necessary to turn the reference into the hypothesis. 11 | (The TER code and paper uses the hypothesis to reference direction, but calling a word occurring only in the 12 | hypothesis an insertion and a word missing in the hypothesis a deletion seems to be far more common.) 13 | """ 14 | 15 | def __init__(self): 16 | self._num_reference_words = 0 17 | self._num_reference_breaks = 0 18 | self._num_shifts = 0 19 | self._num_word_deletions = 0 20 | self._num_break_deletions = 0 21 | self._num_word_insertions = 0 22 | self._num_break_insertions = 0 23 | self._num_word_substitutions = 0 24 | self._num_break_substitutions = 0 25 | 26 | def add_data(self, trace: str, words_ref: List[Word], words_hyp_shifted: List[Word], num_shifts: int): 27 | """ 28 | Called inside lib_ter.translation_edit_rate(). 'trace' contains characters 'i', 'd', 's' and ' ' representing 29 | different edit operations. 
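
        Example (illustrative): the trace " sdi" means the first reference word is kept, the second
        is substituted, the third is missing from the hypothesis ('d'), and the hypothesis contains
        one additional word ('i').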
30 | """ 31 | reference_position = -1 32 | hypothesis_position = -1 33 | 34 | for edit_operation in trace: 35 | if edit_operation != "i": 36 | reference_position += 1 37 | 38 | if edit_operation != "d": 39 | hypothesis_position += 1 40 | 41 | if edit_operation == "i": 42 | if words_hyp_shifted[hypothesis_position].string in [END_OF_LINE_SYMBOL, END_OF_BLOCK_SYMBOL]: 43 | self._num_break_insertions += 1 44 | else: 45 | self._num_word_insertions += 1 46 | else: 47 | is_break_edit = words_ref[reference_position].string in [END_OF_LINE_SYMBOL, END_OF_BLOCK_SYMBOL] 48 | 49 | if is_break_edit: 50 | self._num_reference_breaks += 1 51 | else: 52 | self._num_reference_words += 1 53 | 54 | if edit_operation == "d": 55 | if is_break_edit: 56 | self._num_break_deletions += 1 57 | else: 58 | self._num_word_deletions += 1 59 | 60 | elif edit_operation == "s": 61 | if is_break_edit: 62 | self._num_break_substitutions += 1 63 | else: 64 | self._num_word_substitutions += 1 65 | 66 | assert reference_position == len(words_ref) - 1 67 | assert hypothesis_position == len(words_hyp_shifted) - 1 68 | 69 | self._num_shifts += num_shifts 70 | 71 | def get_statistics(self) -> Dict[str, Any]: 72 | return OrderedDict( 73 | num_reference_words=self._num_reference_words, 74 | num_reference_breaks=self._num_reference_breaks, 75 | num_shifts=self._num_shifts, 76 | num_word_deletions=self._num_word_deletions, 77 | num_break_deletions=self._num_break_deletions, 78 | num_word_insertions=self._num_word_insertions, 79 | num_break_insertions=self._num_break_insertions, 80 | num_word_substitutions=self._num_word_substitutions, 81 | num_break_substitutions=self._num_break_substitutions, 82 | ) 83 | -------------------------------------------------------------------------------- /suber/sentence_segmentation.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from suber.data_types import Segment 3 | 4 | 5 | sentence_final_punctuation = ['.', '?', '!', '!', '?', '。', "…"] 6 | quotation_marks = ["'", '"'] 7 | # all combinations of punctuation and quotes, i.e. '."', "?'" etc. 8 | quoted_sentence_final_punctuation = [punct + quote for punct in sentence_final_punctuation for quote in quotation_marks] 9 | ellipses = ["...", "…"] 10 | 11 | 12 | def resegment_based_on_punctuation(segments: List[Segment]) -> List[Segment]: 13 | resegmented_segments = [] 14 | 15 | all_words = [word for segment in segments for word in segment.word_list] 16 | 17 | word_list = [] 18 | previous_word = None 19 | for word in all_words: 20 | if not previous_word or _is_sentence_end(previous_word.string, word.string): 21 | if word_list: 22 | resegmented_segments.append(Segment(word_list=word_list)) 23 | word_list = [word] 24 | else: 25 | word_list.append(word) 26 | previous_word = word 27 | 28 | assert word_list 29 | resegmented_segments.append(Segment(word_list=word_list)) 30 | 31 | return resegmented_segments 32 | 33 | 34 | def _is_sentence_end(current_word: str, next_word: str = None): 35 | if not next_word: 36 | # No next word, force sentence end. 37 | return True 38 | 39 | assert current_word, "'current_word' must not be empty." 
40 | 41 | is_sentence_final_punctuation_at_end = ( 42 | current_word[-1] in sentence_final_punctuation 43 | or (len(current_word) > 1 and current_word[-2:] in quoted_sentence_final_punctuation)) 44 | 45 | is_ellipsis_at_end = any(current_word.endswith(ellipsis) for ellipsis in ellipses) 46 | 47 | next_word_is_lower_cased = next_word[0].islower() 48 | next_word_is_lower_or_digit = next_word_is_lower_cased or next_word[0].isdigit() 49 | 50 | is_sentence_end = (is_sentence_final_punctuation_at_end and not next_word_is_lower_cased 51 | and not (is_ellipsis_at_end and next_word_is_lower_or_digit)) 52 | 53 | return is_sentence_end 54 | -------------------------------------------------------------------------------- /suber/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apptek/SubER/d174f422ab29edff23af58954520480ad56138c7/suber/tools/__init__.py -------------------------------------------------------------------------------- /suber/tools/align_hyp_to_ref.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | 5 | from suber.file_readers import read_input_file 6 | from suber.hyp_to_ref_alignment import levenshtein_align_hypothesis_to_reference 7 | from suber.hyp_to_ref_alignment import time_align_hypothesis_to_reference 8 | from suber.utilities import segment_to_string 9 | 10 | 11 | def parse_arguments(): 12 | parser = argparse.ArgumentParser(description="Re-segments the hypothesis file to match the reference. This can " 13 | "either be done via Levenshtein alignment, or using the subtitle " 14 | "timings, if available.") 15 | parser.add_argument("-H", "--hypothesis", required=True, help="The input file.") 16 | parser.add_argument("-R", "--reference", required=True, help="The reference file.") 17 | parser.add_argument("-o", "--aligned-hypothesis", required=True, 18 | help="The aligned hypothesis output file in plain format.") 19 | parser.add_argument("-f", "--hypothesis-format", default="SRT", help="Hypothesis file format, 'SRT' or 'plain'.") 20 | parser.add_argument("-F", "--reference-format", default="SRT", help="Reference file format, 'SRT' or 'plain'.") 21 | parser.add_argument("-m", "--method", default="levenshtein", 22 | help="The alignment method, either 'levenshtein' or 'time'. See the " 23 | "'suber.hyp_to_ref_alignment' module. 
'time' only supported if both hypothesis and "
24 |                              "reference are given in SRT format.")
25 | 
26 |     return parser.parse_args()
27 | 
28 | 
29 | def main():
30 |     args = parse_arguments()
31 | 
32 |     if args.method == "time" and not (args.hypothesis_format == "SRT" and args.reference_format == "SRT"):
33 |         raise ValueError("For time alignment, both hypothesis and reference have to be given in SRT format.")
34 | 
35 |     hypothesis_segments = read_input_file(args.hypothesis, file_format=args.hypothesis_format)
36 |     reference_segments = read_input_file(args.reference, file_format=args.reference_format)
37 | 
38 |     if args.method == "levenshtein":
39 |         aligned_hypothesis_segments = levenshtein_align_hypothesis_to_reference(
40 |             hypothesis=hypothesis_segments, reference=reference_segments)
41 |     elif args.method == "time":
42 |         aligned_hypothesis_segments = time_align_hypothesis_to_reference(
43 |             hypothesis=hypothesis_segments, reference=reference_segments)
44 |     else:
45 |         raise ValueError(f"Unknown alignment method '{args.method}', must be 'levenshtein' or 'time'.")
46 | 
47 |     with open(args.aligned_hypothesis, "w", encoding="utf-8") as output_file_object:
48 |         for segment in aligned_hypothesis_segments:
49 |             output_file_object.write(segment_to_string(segment) + '\n')
50 | 
51 | 
52 | if __name__ == "__main__":
53 |     main()
54 | 
--------------------------------------------------------------------------------
/suber/tools/srt_to_plain.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | 
3 | import argparse
4 | 
5 | from suber.file_readers import SRTFileReader
6 | from suber.sentence_segmentation import resegment_based_on_punctuation
7 | from suber.utilities import segment_to_string
8 | 
9 | 
10 | def parse_arguments():
11 |     parser = argparse.ArgumentParser(description="Extracts plain text from SRT files.")
12 |     parser.add_argument("-i", "--input-file", required=True, help="The input SRT file.")
13 |     parser.add_argument("-o", "--output-file", required=True, help="The plain output file.")
14 |     parser.add_argument("-s", "--sentence-segmentation", action="store_true",
15 |                         help="If enabled, output sentences instead of subtitle segments.")
16 | 
17 |     return parser.parse_args()
18 | 
19 | 
20 | def main():
21 |     args = parse_arguments()
22 | 
23 |     segments = SRTFileReader(args.input_file).read()
24 | 
25 |     if args.sentence_segmentation:
26 |         segments = resegment_based_on_punctuation(segments)
27 | 
28 |     with open(args.output_file, "w", encoding="utf-8") as output_file_object:
29 |         for segment in segments:
30 |             output_file_object.write(segment_to_string(segment, include_line_breaks=True) + '\n')
31 | 
32 | 
33 | if __name__ == "__main__":
34 |     main()
35 | 
--------------------------------------------------------------------------------
/suber/utilities.py:
--------------------------------------------------------------------------------
1 | from suber.data_types import LineBreak, Segment
2 | from suber.constants import END_OF_LINE_SYMBOL, END_OF_BLOCK_SYMBOL, MASK_SYMBOL
3 | 
4 | 
5 | def segment_to_string(segment: Segment, include_line_breaks=False, include_last_break=True,
6 |                       mask_all_words=False) -> str:
7 |     if not include_line_breaks:
8 |         assert not mask_all_words, (
9 |             "Refusing to mask all words when not printing breaks, output would contain only mask symbols.")
10 |         return " ".join(word.string for word in segment.word_list)
11 | 
12 |     word_list_with_breaks = []
13 |     for word in segment.word_list:
14 |         word_list_with_breaks.append(MASK_SYMBOL if mask_all_words else word.string)
15 | 
16 |         if word.line_break == LineBreak.END_OF_LINE:
17 |             word_list_with_breaks.append(END_OF_LINE_SYMBOL)
18 |         elif word.line_break == LineBreak.END_OF_BLOCK:
19 |             word_list_with_breaks.append(END_OF_BLOCK_SYMBOL)
20 | 
21 |     if not include_last_break and word_list_with_breaks and word_list_with_breaks[-1] == END_OF_BLOCK_SYMBOL:
22 |         word_list_with_breaks.pop()
23 | 
24 |     return " ".join(word_list_with_breaks)
25 | 
26 | 
27 | def get_segment_to_string_opts_from_metric(metric: str):
28 |     include_breaks = False
29 |     mask_words = False
30 |     if metric.endswith("-br"):
31 |         include_breaks = True
32 |         mask_words = True
33 |         metric = metric[:-len("-br")]
34 |     elif metric.endswith("-seg"):
35 |         include_breaks = True
36 |         metric = metric[:-len("-seg")]
37 | 
38 |     return include_breaks, mask_words, metric
39 | 
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apptek/SubER/d174f422ab29edff23af58954520480ad56138c7/tests/__init__.py
--------------------------------------------------------------------------------
/tests/fuzz_levenshtein.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Checks that our altered Levenshtein implementation, which preserves the behavior of Levenshtein v0.18.0 in terms of
4 | edit ops, returns the same edit distance as the rapidfuzz version installed in your environment. The distances must
5 | match: ambiguity exists only in the alignment / edit ops, not in their number.
6 | """
7 | 
8 | import random
9 | import string
10 | 
11 | from rapidfuzz.distance import Levenshtein
12 | 
13 | from suber import lib_levenshtein
14 | 
15 | 
16 | for i in range(100000):
17 |     N1 = random.randint(0, 40)
18 |     N2 = random.randint(0, 40)
19 | 
20 |     s1 = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(N1))
21 |     s2 = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(N2))
22 | 
23 |     distance_rapidfuzz = Levenshtein.distance(s1, s2)
24 |     distance = lib_levenshtein.distance(s1, s2)
25 |     assert distance == distance_rapidfuzz, (s1, s2, distance, distance_rapidfuzz, i)
26 |     num_editops = len(lib_levenshtein.editops(s1, s2))
27 |     assert distance == num_editops, (s1, s2, distance, num_editops, i)
--------------------------------------------------------------------------------
/tests/test_cer.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | 
3 | from suber.metrics.cer import calculate_character_error_rate
4 | from .utilities import create_temporary_file_and_read_it
5 | 
6 | 
7 | class CERTest(unittest.TestCase):
8 | 
9 |     def test_cer(self):
10 |         reference_file_content = """
11 |             1
12 |             00:00:00,000 --> 00:00:01,000
13 |             This is a simple first frame.
14 | 15 | 2 16 | 00:00:01,000 --> 00:00:02,000 17 | This is another frame 18 | having two lines.""" 19 | 20 | hypothesis_file_content = """ 21 | 1 22 | 00:00:00,000 --> 00:00:01,000 23 | This is a simple first frame, 24 | 25 | 2 26 | 00:00:01,000 --> 00:00:02,000 27 | this is another 28 | frame having two lines.""" 29 | 30 | reference_subtitles = create_temporary_file_and_read_it(reference_file_content) 31 | hypothesis_subtitles = create_temporary_file_and_read_it(hypothesis_file_content) 32 | 33 | cer_score = calculate_character_error_rate( 34 | hypothesis=hypothesis_subtitles, reference=reference_subtitles, metric="CER") 35 | 36 | # Lower-case and without punctuation by default, so no edits. 37 | self.assertAlmostEqual(cer_score, 0.0) 38 | 39 | cer_cased_score = calculate_character_error_rate( 40 | hypothesis=hypothesis_subtitles, reference=reference_subtitles, metric="CER-cased") 41 | 42 | # 2 edits / 68 characters 43 | self.assertAlmostEqual(cer_cased_score, 2.941) 44 | 45 | 46 | if __name__ == '__main__': 47 | unittest.main() 48 | -------------------------------------------------------------------------------- /tests/test_file_readers.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from suber.data_types import LineBreak 4 | from suber.file_readers.srt_file_reader import SRTFormatError 5 | from .utilities import create_temporary_file_and_read_it 6 | 7 | 8 | class PlainFileReaderTests(unittest.TestCase): 9 | def test_empty_file(self): 10 | segments = create_temporary_file_and_read_it("", file_format="plain") 11 | self.assertFalse(segments) 12 | 13 | def test_simple_file(self): 14 | file_content = """This is a line. 15 | These are two subtitle lines. """ 16 | 17 | segments = create_temporary_file_and_read_it(file_content, file_format="plain") 18 | 19 | self.assertEqual(len(segments), 2) 20 | 21 | first_segment_text = " ".join(word.string for word in segments[0].word_list) 22 | self.assertEqual(first_segment_text, "This is a line.") 23 | self.assertTrue(all(word.line_break == LineBreak.NONE for word in segments[0].word_list[:-1])) 24 | self.assertEqual(segments[0].word_list[-1].line_break, LineBreak.END_OF_BLOCK) 25 | 26 | second_segment_text = " ".join(word.string for word in segments[1].word_list) 27 | self.assertEqual(second_segment_text, "These are two subtitle lines.") 28 | self.assertEqual(segments[1].word_list[1].line_break, LineBreak.END_OF_LINE) 29 | self.assertEqual(segments[1].word_list[-1].line_break, LineBreak.END_OF_BLOCK) 30 | 31 | 32 | class SRTFileReaderTests(unittest.TestCase): 33 | def test_empty_file(self): 34 | subtitles = create_temporary_file_and_read_it("") 35 | self.assertFalse(subtitles) 36 | 37 | def test_simple_file(self): 38 | file_content = """ 39 | 1 40 | 00:00:00,000 --> 00:00:01,000 41 | This is a simple first frame. 
42 | 
43 | 2
44 | 00:00:01,000 --> 00:00:02,000
45 | This is another frame
46 | having two lines."""
47 | 
48 |         subtitles = create_temporary_file_and_read_it(file_content)
49 | 
50 |         self.assertEqual(len(subtitles), 2)
51 | 
52 |         self.assertEqual(subtitles[0].index, 1)
53 |         self.assertEqual(subtitles[1].index, 2)
54 | 
55 |         self.assertAlmostEqual(subtitles[0].start_time, 0.0)
56 |         self.assertAlmostEqual(subtitles[0].end_time, 1.0)
57 |         self.assertTrue(all(word.line_break == LineBreak.NONE for word in subtitles[0].word_list[:-1]))
58 |         self.assertEqual(subtitles[0].word_list[-1].line_break, LineBreak.END_OF_BLOCK)
59 | 
60 |         self.assertAlmostEqual(subtitles[1].start_time, 1.0)
61 |         self.assertAlmostEqual(subtitles[1].end_time, 2.0)
62 |         self.assertEqual(subtitles[1].word_list[3].line_break, LineBreak.END_OF_LINE)
63 |         self.assertEqual(subtitles[1].word_list[-1].line_break, LineBreak.END_OF_BLOCK)
64 | 
65 |         first_subtitle_text = " ".join(word.string for word in subtitles[0].word_list)
66 |         self.assertEqual(first_subtitle_text, "This is a simple first frame.")
67 | 
68 |         second_subtitle_text = " ".join(word.string for word in subtitles[1].word_list)
69 |         self.assertEqual(second_subtitle_text, "This is another frame having two lines.")
70 | 
71 |     def test_overlap_in_time(self):
72 |         file_content = """
73 | 1
74 | 00:00:01,000 --> 00:00:02,000
75 | This is a simple first frame.
76 | 
77 | 2
78 | 00:00:00,000 --> 00:00:01,000
79 | This one is before the first one in time."""
80 | 
81 |         with self.assertRaises(SRTFormatError):
82 |             create_temporary_file_and_read_it(file_content)
83 | 
84 | 
85 | if __name__ == '__main__':
86 |     unittest.main()
87 | 
--------------------------------------------------------------------------------
/tests/test_hyp_to_ref_alignment.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | 
3 | from suber.hyp_to_ref_alignment import time_align_hypothesis_to_reference
4 | from suber.hyp_to_ref_alignment import levenshtein_align_hypothesis_to_reference
5 | from .utilities import create_temporary_file_and_read_it
6 | 
7 | 
8 | class TimeAlignmentTests(unittest.TestCase):
9 | 
10 |     def test_full_overlap(self):
11 |         reference_file_content = """
12 | 1
13 | 00:00:00,000 --> 00:00:01,000
14 | This is a simple first frame.
15 | 
16 | 2
17 | 00:00:01,000 --> 00:00:02,000
18 | This is another frame
19 | having two lines."""
20 | 
21 |         hypothesis_file_content = """
22 | 1
23 | 00:00:00,000 --> 00:00:01,000
24 | This is a simple first frame. 
25 | 
26 | 2
27 | 00:00:01,000 --> 00:00:01,500
28 | This is another frame
29 | 
30 | 3
31 | 00:00:01,500 --> 00:00:02,000
32 | having two lines."""
33 | 
34 |         reference_subtitles = create_temporary_file_and_read_it(reference_file_content)
35 |         hypothesis_subtitles = create_temporary_file_and_read_it(hypothesis_file_content)
36 | 
37 |         hypothesis_subtitles = time_align_hypothesis_to_reference(hypothesis_subtitles, reference_subtitles)
38 | 
39 |         self.assertEqual(len(hypothesis_subtitles), 2)
40 | 
41 |         self.assertEqual(hypothesis_subtitles[0].index, 1)
42 |         self.assertEqual(hypothesis_subtitles[1].index, 2)
43 | 
44 |         self.assertAlmostEqual(hypothesis_subtitles[0].start_time, 0.0)
45 |         self.assertAlmostEqual(hypothesis_subtitles[0].end_time, 1.0)
46 |         self.assertAlmostEqual(hypothesis_subtitles[1].start_time, 1.0)
47 |         self.assertAlmostEqual(hypothesis_subtitles[1].end_time, 2.0)
48 | 
49 |         first_subtitle_text = " ".join(word.string for word in hypothesis_subtitles[0].word_list)
50 |         self.assertEqual(first_subtitle_text, "This is a simple first frame.")
51 | 
52 |         second_subtitle_text = " ".join(word.string for word in hypothesis_subtitles[1].word_list)
53 |         self.assertEqual(second_subtitle_text, "This is another frame having two lines.")
54 | 
55 |     def test_dropped_words(self):
56 |         reference_file_content = """
57 | 1
58 | 00:00:01,000 --> 00:00:02,000
59 | This is a simple first frame.
60 | 
61 | 2
62 | 00:00:03,000 --> 00:00:04,000
63 | This is another frame
64 | having two lines."""
65 | 
66 |         hypothesis_file_content = """
67 | 1
68 | 00:00:00,000 --> 00:00:01,000
69 | Should be dropped.
70 | 
71 | 2
72 | 00:00:01,000 --> 00:00:02,000
73 | This is a simple first frame.
74 | 
75 | 3
76 | 00:00:02,000 --> 00:00:03,000
77 | Should be dropped.
78 | 
79 | 4
80 | 00:00:03,000 --> 00:00:04,000
81 | This is another frame
82 | having two lines.
83 | 
84 | 5
85 | 00:00:04,000 --> 00:00:05,000
86 | Should be dropped.
87 | """
88 | 
89 |         reference_subtitles = create_temporary_file_and_read_it(reference_file_content)
90 |         hypothesis_subtitles = create_temporary_file_and_read_it(hypothesis_file_content)
91 | 
92 |         hypothesis_subtitles = time_align_hypothesis_to_reference(hypothesis_subtitles, reference_subtitles)
93 | 
94 |         self.assertEqual(len(hypothesis_subtitles), 2)
95 | 
96 |         first_subtitle_text = " ".join(word.string for word in hypothesis_subtitles[0].word_list)
97 |         self.assertEqual(first_subtitle_text, "This is a simple first frame.")
98 | 
99 |         second_subtitle_text = " ".join(word.string for word in hypothesis_subtitles[1].word_list)
100 |         self.assertEqual(second_subtitle_text, "This is another frame having two lines.")
101 | 
102 |     def test_partial_overlap(self):
103 |         reference_file_content = """
104 | 1
105 | 00:00:01,000 --> 00:00:02,000
106 | This is a simple first frame.
107 | 
108 | 2
109 | 00:00:03,000 --> 00:00:04,000
110 | This is another frame
111 | having two lines."""
112 | 
113 |         hypothesis_file_content = """
114 | 1
115 | 00:00:00,000 --> 00:00:02,000
116 | This is a simple first frame. 
117 | 
118 | 2
119 | 00:00:02,500 --> 00:00:03,500
120 | This is another frame
121 | 
122 | 3
123 | 00:00:03,500 --> 00:00:04,500
124 | having two lines."""
125 | 
126 |         reference_subtitles = create_temporary_file_and_read_it(reference_file_content)
127 |         hypothesis_subtitles = create_temporary_file_and_read_it(hypothesis_file_content)
128 | 
129 |         hypothesis_subtitles = time_align_hypothesis_to_reference(hypothesis_subtitles, reference_subtitles)
130 | 
131 |         self.assertEqual(len(hypothesis_subtitles), 2)
132 | 
133 |         first_subtitle_text = " ".join(word.string for word in hypothesis_subtitles[0].word_list)
134 |         self.assertEqual(first_subtitle_text, "simple first frame.")
135 | 
136 |         second_subtitle_text = " ".join(word.string for word in hypothesis_subtitles[1].word_list)
137 |         self.assertEqual(second_subtitle_text, "another frame having")
138 | 
139 |     def test_gap_in_overlap(self):
140 |         reference_file_content = """
141 | 1
142 | 00:00:00,000 --> 00:00:01,000
143 | This is a simple first frame.
144 | 
145 | 2
146 | 00:00:02,000 --> 00:00:03,000
147 | This is another frame
148 | having two lines."""
149 | 
150 |         hypothesis_file_content = """
151 | 1
152 | 00:00:00,000 --> 00:00:03,000
153 | This is a simple first frame.
154 | This is another frame
155 | having two lines."""
156 | 
157 |         reference_subtitles = create_temporary_file_and_read_it(reference_file_content)
158 |         hypothesis_subtitles = create_temporary_file_and_read_it(hypothesis_file_content)
159 | 
160 |         hypothesis_subtitles = time_align_hypothesis_to_reference(hypothesis_subtitles, reference_subtitles)
161 | 
162 |         self.assertEqual(len(hypothesis_subtitles), 2)
163 | 
164 |         first_subtitle_text = " ".join(word.string for word in hypothesis_subtitles[0].word_list)
165 |         self.assertEqual(first_subtitle_text, "This is a simple")
166 | 
167 |         second_subtitle_text = " ".join(word.string for word in hypothesis_subtitles[1].word_list)
168 |         self.assertEqual(second_subtitle_text, "frame having two lines.")
169 | 
170 | 
171 | class LevenshteinAlignmentTests(unittest.TestCase):
172 |     def test_identical_files(self):
173 |         file_content = """This is a line.
174 | That is another one."""
175 | 
176 |         segments = create_temporary_file_and_read_it(file_content, file_format="plain")
177 | 
178 |         aligned_segments = levenshtein_align_hypothesis_to_reference(hypothesis=segments, reference=segments)
179 | 
180 |         self.assertEqual(segments, aligned_segments)
181 | 
182 |     def test_identical_words(self):
183 |         reference_file_content = """This is a line.
184 | That is another one.
185 | And a third segment."""
186 | 
187 |         hypothesis_file_content = """This is a line. That
188 | is another
189 | one. 
And a third segment.""" 190 | 191 | reference_segments = create_temporary_file_and_read_it(reference_file_content, file_format="plain") 192 | hypothesis_segments = create_temporary_file_and_read_it(hypothesis_file_content, file_format="plain") 193 | 194 | hypothesis_segments = levenshtein_align_hypothesis_to_reference(hypothesis_segments, reference_segments) 195 | 196 | self.assertEqual(len(hypothesis_segments), 3) 197 | 198 | first_segment_text = " ".join(word.string for word in hypothesis_segments[0].word_list) 199 | self.assertEqual(first_segment_text, "This is a line.") 200 | 201 | second_segment_text = " ".join(word.string for word in hypothesis_segments[1].word_list) 202 | self.assertEqual(second_segment_text, "That is another one.") 203 | 204 | third_segment_text = " ".join(word.string for word in hypothesis_segments[2].word_list) 205 | self.assertEqual(third_segment_text, "And a third segment.") 206 | 207 | def test_with_edits(self): 208 | reference_file_content = """This is a line. 209 | That is another one. 210 | And a third segment.""" 211 | 212 | hypothesis_file_content = """This is a lines. That this 213 | is another one. And third segment.""" 214 | 215 | reference_segments = create_temporary_file_and_read_it(reference_file_content, file_format="plain") 216 | hypothesis_segments = create_temporary_file_and_read_it(hypothesis_file_content, file_format="plain") 217 | 218 | hypothesis_segments = levenshtein_align_hypothesis_to_reference(hypothesis_segments, reference_segments) 219 | 220 | self.assertEqual(len(hypothesis_segments), 3) 221 | 222 | first_segment_text = " ".join(word.string for word in hypothesis_segments[0].word_list) 223 | self.assertEqual(first_segment_text, "This is a lines.") 224 | 225 | second_segment_text = " ".join(word.string for word in hypothesis_segments[1].word_list) 226 | self.assertEqual(second_segment_text, "That this is another one.") 227 | 228 | third_segment_text = " ".join(word.string for word in hypothesis_segments[2].word_list) 229 | self.assertEqual(third_segment_text, "And third segment.") 230 | 231 | def test_with_edits_at_segment_boundary(self): 232 | reference_file_content = """This is a line. 233 | That is another one. 234 | And a third segment.""" 235 | 236 | hypothesis_file_content = """Some words at the start. This is a line. Where do these 237 | words belong to? another one. 238 | And a third 239 | segment. Some words in the end.""" 240 | 241 | reference_segments = create_temporary_file_and_read_it(reference_file_content, file_format="plain") 242 | hypothesis_segments = create_temporary_file_and_read_it(hypothesis_file_content, file_format="plain") 243 | 244 | hypothesis_segments = levenshtein_align_hypothesis_to_reference(hypothesis_segments, reference_segments) 245 | 246 | self.assertEqual(len(hypothesis_segments), 3) 247 | 248 | first_segment_text = " ".join(word.string for word in hypothesis_segments[0].word_list) 249 | self.assertEqual(first_segment_text, "Some words at the start. This is a line. Where do these words") 250 | 251 | second_segment_text = " ".join(word.string for word in hypothesis_segments[1].word_list) 252 | self.assertEqual(second_segment_text, "belong to? another one.") 253 | 254 | third_segment_text = " ".join(word.string for word in hypothesis_segments[2].word_list) 255 | self.assertEqual(third_segment_text, "And a third segment. Some words in the end.") 256 | 257 | def test_with_very_few_words(self): 258 | reference_file_content = """This is a line. 259 | That is another one. 
260 | And a third segment."""
261 | 
262 |         hypothesis_file_content = """Very few words."""
263 | 
264 |         reference_segments = create_temporary_file_and_read_it(reference_file_content, file_format="plain")
265 |         hypothesis_segments = create_temporary_file_and_read_it(hypothesis_file_content, file_format="plain")
266 | 
267 |         hypothesis_segments = levenshtein_align_hypothesis_to_reference(hypothesis_segments, reference_segments)
268 | 
269 |         self.assertEqual(len(hypothesis_segments), 3)
270 | 
271 | 
272 | if __name__ == '__main__':
273 |     unittest.main()
274 | 
--------------------------------------------------------------------------------
/tests/test_jiwer_interface.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | 
3 | from suber.metrics.jiwer_interface import calculate_word_error_rate
4 | from .utilities import create_temporary_file_and_read_it
5 | 
6 | 
7 | class JiWERInterfaceTest(unittest.TestCase):
8 | 
9 |     def test_wer(self):
10 |         reference_file_content = """
11 | 1
12 | 00:00:00,000 --> 00:00:01,000
13 | This is a simple first frame.
14 | 
15 | 2
16 | 00:00:01,000 --> 00:00:02,000
17 | This is another frame
18 | having two lines."""
19 | 
20 |         hypothesis_file_content = """
21 | 1
22 | 00:00:00,000 --> 00:00:01,000
23 | This is a simple first frame,
24 | 
25 | 2
26 | 00:00:01,000 --> 00:00:02,000
27 | this is another
28 | frame having two lines."""
29 | 
30 |         reference_subtitles = create_temporary_file_and_read_it(reference_file_content)
31 |         hypothesis_subtitles = create_temporary_file_and_read_it(hypothesis_file_content)
32 | 
33 |         wer_score = calculate_word_error_rate(
34 |             hypothesis=hypothesis_subtitles, reference=reference_subtitles, metric="WER")
35 | 
36 |         self.assertAlmostEqual(wer_score, 0.0)
37 | 
38 |         wer_cased_score = calculate_word_error_rate(
39 |             hypothesis=hypothesis_subtitles, reference=reference_subtitles, metric="WER-cased")
40 | 
41 |         # 2 substitutions (one casing and one punctuation error) / 15 tokenized words
42 |         self.assertAlmostEqual(wer_cased_score, 13.333)
43 | 
44 |         wer_seg_score = calculate_word_error_rate(
45 |             hypothesis=hypothesis_subtitles, reference=reference_subtitles, metric="WER-seg")
46 | 
47 |         # (1 break deletion + 1 break insertion) / (13 words + 3 breaks)
48 |         self.assertAlmostEqual(wer_seg_score, 12.5)
49 | 
50 |         wer_seg_score = calculate_word_error_rate(
51 |             hypothesis=hypothesis_subtitles, reference=reference_subtitles, metric="WER-seg",
52 |             score_break_at_segment_end=False)
53 | 
54 |         # (1 break deletion + 1 break insertion) / (13 words + 1 break)
55 |         self.assertAlmostEqual(wer_seg_score, 14.286)
56 | 
57 | 
58 | if __name__ == '__main__':
59 |     unittest.main()
60 | 
--------------------------------------------------------------------------------
/tests/test_length_ratio.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | 
3 | from suber.metrics.length_ratio import calculate_length_ratio
4 | from .utilities import create_temporary_file_and_read_it
5 | 
6 | 
7 | class LengthRatioTest(unittest.TestCase):
8 |     def setUp(self):
9 |         # Punctuation marks should count as separate tokens.
10 |         reference_file_content = """
11 | 1
12 | 00:00:00,000 --> 00:00:01,000
13 | One two three.
14 | 
15 | 2
16 | 00:00:01,000 --> 00:00:02,000
17 | Five six
18 | seven eight?"""
19 | 
20 |         hypothesis_file_content = """
21 | 1
22 | 00:00:00,000 --> 00:00:01,000
23 | One two. 
24 | 25 | 2 26 | 00:00:01,000 --> 00:00:01,500 27 | Four five 28 | 29 | 3 30 | 00:00:01,500 --> 00:00:02,000 31 | six?""" 32 | 33 | self._reference_subtitles = create_temporary_file_and_read_it(reference_file_content) 34 | self._hypothesis_subtitles = create_temporary_file_and_read_it(hypothesis_file_content) 35 | 36 | def test_length_ratio(self): 37 | length_ratio = calculate_length_ratio( 38 | hypothesis=self._hypothesis_subtitles, reference=self._reference_subtitles) 39 | 40 | self.assertAlmostEqual(length_ratio, 7 / 9 * 100, places=3) 41 | -------------------------------------------------------------------------------- /tests/test_main.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import tempfile 3 | import subprocess 4 | import json 5 | 6 | from typing import List 7 | from contextlib import ExitStack 8 | 9 | 10 | class MainFunctionTests(unittest.TestCase): 11 | 12 | def _run_main(self, hypothesis_files_contents: List[str], reference_files_contents: List[str]): 13 | """ 14 | Creates temporary hypothesis and reference files, runs the SubER tool and returns the metric scores. 15 | """ 16 | 17 | with ExitStack() as stack: 18 | def write_files(files_contents): 19 | files = [stack.enter_context(tempfile.NamedTemporaryFile(mode="w", suffix=".srt")) 20 | for _ in files_contents] 21 | 22 | for i, file_content in enumerate(files_contents): 23 | files[i].write(file_content) 24 | files[i].flush() 25 | 26 | file_names = " ".join(file.name for file in files) 27 | return file_names 28 | 29 | hypothesis_file_names = write_files(hypothesis_files_contents) 30 | reference_file_names = write_files(reference_files_contents) 31 | 32 | # Check all metrics, including hyp-to-ref-alignment. 33 | completed_process = subprocess.run( 34 | f"python3 -m suber " 35 | f"--hypothesis {hypothesis_file_names} --reference {reference_file_names} " 36 | f"--metrics SubER WER CER BLEU TER chrF TER-br WER-seg BLEU-seg AS-BLEU t-BLEU".split(), 37 | check=True, stdout=subprocess.PIPE) 38 | 39 | metric_scores = json.loads(completed_process.stdout.decode("utf-8")) 40 | 41 | return metric_scores 42 | 43 | def test_main_function(self): 44 | file_content = """ 45 | 1 46 | 00:00:00,000 --> 00:00:01,000 47 | This is a simple first frame. 48 | 49 | 2 50 | 00:00:01,000 --> 00:00:02,000 51 | This is another frame 52 | having two lines.""" 53 | 54 | metric_scores = self._run_main( 55 | hypothesis_files_contents=[file_content], reference_files_contents=[file_content]) 56 | 57 | # Just check that it runs through. 58 | self.assertTrue(metric_scores) 59 | 60 | def test_multiple_files(self): 61 | """ 62 | We support multiple input files, see 'suber.concat_input_files'. 
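Files are also concatenated along the time axis: subtitles of each later file are shifted by the duration of the preceding files, as verified against the manually concatenated files below.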
63 | """ 64 | hypothesis_file1_content = """ 65 | 1 66 | 00:00:00,000 --> 00:00:00,800 67 | This is a first frame.""" 68 | 69 | hypothesis_file2_content = """ 70 | 2 71 | 00:00:00,400 --> 00:00:01,200 72 | This is another frame which should have two lines.""" 73 | 74 | reference_file1_content = """ 75 | 1 76 | 00:00:00,000 --> 00:00:01,000 77 | This is a simple first frame.""" 78 | 79 | reference_file2_content = """ 80 | 2 81 | 00:00:00,000 --> 00:00:01,000 82 | This is another frame 83 | having two lines.""" 84 | 85 | metric_scores_split_files = self._run_main( 86 | hypothesis_files_contents=[hypothesis_file1_content, hypothesis_file2_content], 87 | reference_files_contents=[reference_file1_content, reference_file2_content]) 88 | 89 | # Note: also concatenated in time, second subtitle is shifted by duration of first. 90 | concatenated_hypothesis_file_content = """ 91 | 1 92 | 00:00:00,000 --> 00:00:00,800 93 | This is a first frame. 94 | 95 | 2 96 | 00:00:01,400 --> 00:00:02,200 97 | This is another frame which should have two lines.""" 98 | 99 | concatenated_reference_file_content = """ 100 | 1 101 | 00:00:00,000 --> 00:00:01,000 102 | This is a simple first frame. 103 | 104 | 2 105 | 00:00:01,000 --> 00:00:02,000 106 | This is another frame 107 | having two lines.""" 108 | 109 | metric_scores_concatenated_files = self._run_main( 110 | hypothesis_files_contents=[concatenated_hypothesis_file_content], 111 | reference_files_contents=[concatenated_reference_file_content]) 112 | 113 | # We expect manual concatenation and giving multiple files to be equivalent. 114 | self.assertEqual(metric_scores_split_files, metric_scores_concatenated_files) 115 | 116 | 117 | if __name__ == '__main__': 118 | unittest.main() 119 | -------------------------------------------------------------------------------- /tests/test_sacrebleu_interface.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from suber.metrics.sacrebleu_interface import calculate_sacrebleu_metric 4 | from .utilities import create_temporary_file_and_read_it 5 | 6 | 7 | class SacreBleuInterfaceTest(unittest.TestCase): 8 | def setUp(self): 9 | reference_file_content = """ 10 | 1 11 | 00:00:00,000 --> 00:00:01,000 12 | This is a simple first frame. 13 | 14 | 2 15 | 00:00:01,000 --> 00:00:02,000 16 | This is another frame 17 | having two lines.""" 18 | 19 | hypothesis_file_content = """ 20 | 1 21 | 00:00:00,000 --> 00:00:01,000 22 | This is a simple first frame. 
23 | 
24 | 2
25 | 00:00:01,000 --> 00:00:02,000
26 | This is another
27 | frame having two lines."""
28 | 
29 |         self._reference_subtitles = create_temporary_file_and_read_it(reference_file_content)
30 |         self._hypothesis_subtitles = create_temporary_file_and_read_it(hypothesis_file_content)
31 | 
32 |     def test_bleu(self):
33 |         bleu_score = calculate_sacrebleu_metric(
34 |             hypothesis=self._hypothesis_subtitles, reference=self._reference_subtitles, metric="BLEU")
35 | 
36 |         self.assertAlmostEqual(bleu_score, 100.0)
37 | 
38 |         bleu_seg_score = calculate_sacrebleu_metric(
39 |             hypothesis=self._hypothesis_subtitles, reference=self._reference_subtitles, metric="BLEU-seg")
40 | 
41 |         self.assertAlmostEqual(bleu_seg_score, 76.279)
42 | 
43 |         bleu_seg_score = calculate_sacrebleu_metric(
44 |             hypothesis=self._hypothesis_subtitles, reference=self._reference_subtitles, metric="BLEU-seg",
45 |             score_break_at_segment_end=False)
46 | 
47 |         self.assertAlmostEqual(bleu_seg_score, 71.538)
48 | 
49 |     def test_TER(self):
50 |         ter_score = calculate_sacrebleu_metric(
51 |             hypothesis=self._hypothesis_subtitles, reference=self._reference_subtitles, metric="TER")
52 | 
53 |         self.assertAlmostEqual(ter_score, 0.0)
54 | 
55 |         ter_seg_score = calculate_sacrebleu_metric(
56 |             hypothesis=self._hypothesis_subtitles, reference=self._reference_subtitles, metric="TER-seg")
57 | 
58 |         # 1 break shift / (13 words + 3 breaks)
59 |         self.assertAlmostEqual(ter_seg_score, 6.25)
60 | 
61 |         ter_seg_score = calculate_sacrebleu_metric(
62 |             hypothesis=self._hypothesis_subtitles, reference=self._reference_subtitles, metric="TER-seg",
63 |             score_break_at_segment_end=False)
64 | 
65 |         # 1 break shift / (13 words + 1 break)
66 |         self.assertAlmostEqual(ter_seg_score, 7.143)
67 | 
68 |     def test_TER_br(self):
69 |         reference_file_content = "This is one sentence with a line break and a frame break. "
70 |         hypothesis_file_content = "This is a sentence with a line break and a block break. "
71 | 
72 |         reference_subtitles = create_temporary_file_and_read_it(reference_file_content, file_format="plain")
73 |         hypothesis_subtitles = create_temporary_file_and_read_it(hypothesis_file_content, file_format="plain")
74 | 
75 |         ter_br_score = calculate_sacrebleu_metric(
76 |             hypothesis=hypothesis_subtitles, reference=reference_subtitles, metric="TER-br")
77 | 
78 |         # 12 real words, 3 line breaks, so reference length is 15. Edit operations should be 2 shifts and 1 insertion
79 |         # of break symbols.
80 |         expected_ter_br_score = round(3 / (12 + 3) * 100, 3)
81 |         self.assertAlmostEqual(ter_br_score, expected_ter_br_score)
82 | 
83 |     def test_chrF(self):
84 |         chrF_score = calculate_sacrebleu_metric(
85 |             hypothesis=self._hypothesis_subtitles, reference=self._reference_subtitles, metric="chrF")
86 | 
87 |         self.assertAlmostEqual(chrF_score, 100.0)
88 | 
89 | 
90 | if __name__ == '__main__':
91 |     unittest.main()
92 | 
--------------------------------------------------------------------------------
/tests/test_sentence_segmentation.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | 
3 | from suber.sentence_segmentation import resegment_based_on_punctuation
4 | from .utilities import create_temporary_file_and_read_it
5 | 
6 | 
7 | class SentenceSegmentationTests(unittest.TestCase):
8 | 
9 |     def test_sentence_segmentation(self):
10 |         file_content = """This is a first sentence. This is
11 | another one… That is a question? 'That is a quoted
12 | exclamation!' Ellipsis... that continues. even further. 1. 
sentence 13 | starting with a number... 2. one.""" 14 | 15 | segments = create_temporary_file_and_read_it(file_content, file_format="plain") 16 | 17 | sentences = resegment_based_on_punctuation(segments) 18 | 19 | self.assertEqual(len(sentences), 6) 20 | 21 | expected_sentences = [ 22 | "This is a first sentence.", 23 | "This is another one…", 24 | "That is a question?", 25 | "'That is a quoted exclamation!'", 26 | "Ellipsis... that continues. even further.", 27 | "1. sentence starting with a number... 2. one."] 28 | 29 | for index in range(len(sentences)): 30 | sentence_text = " ".join(word.string for word in sentences[index].word_list) 31 | self.assertEqual(sentence_text, expected_sentences[index]) 32 | 33 | 34 | if __name__ == '__main__': 35 | unittest.main() 36 | -------------------------------------------------------------------------------- /tests/test_suber_metric.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from suber.data_types import Subtitle 4 | from suber.metrics.suber import calculate_SubER, _get_independent_parts 5 | from .utilities import create_temporary_file_and_read_it 6 | 7 | 8 | class SubERMetricTests(unittest.TestCase): 9 | def setUp(self): 10 | self._reference1 = """ 11 | 1 12 | 0:00:01.000 --> 0:00:02.000 13 | This is a subtitle.""" 14 | 15 | self._reference2 = """ 16 | 1 17 | 0:00:01.000 --> 0:00:02.000 18 | This is a subtitle. 19 | 20 | 2 21 | 0:00:03.000 --> 0:00:04.000 22 | And another one!""" 23 | 24 | def _run_test(self, hypothesis, reference, expected_score): 25 | hypothesis_subtitles = create_temporary_file_and_read_it(hypothesis) 26 | reference_subtitles = create_temporary_file_and_read_it(reference) 27 | 28 | SubER_score = calculate_SubER(hypothesis_subtitles, reference_subtitles) 29 | 30 | self.assertAlmostEqual(SubER_score, expected_score) 31 | 32 | def test_empty(self): 33 | self._run_test(hypothesis="", reference="", expected_score=0.0) 34 | self._run_test(hypothesis="", reference=self._reference1, expected_score=100.0) 35 | self._run_test(hypothesis=self._reference1, reference="", expected_score=100.0) 36 | 37 | def test_identical(self): 38 | self._run_test(hypothesis=self._reference1, reference=self._reference1, expected_score=0.0) 39 | self._run_test(hypothesis=self._reference2, reference=self._reference2, expected_score=0.0) 40 | 41 | def test_one_shift(self): 42 | hypothesis = """ 43 | 1 44 | 0:00:01.000 --> 0:00:02.000 45 | a subtitle. This is""" 46 | # 1 shift / (4 words + 1 break) 47 | self._run_test(hypothesis, self._reference1, expected_score=20.0) 48 | 49 | def test_no_overlap(self): 50 | hypothesis = """ 51 | 1 52 | 0:00:00.000 --> 0:00:01.000 53 | This is a subtitle.""" 54 | # All words + breaks count as deletion + insertion. 55 | self._run_test(hypothesis, self._reference1, expected_score=200.0) 56 | 57 | def test_with_overlap(self): 58 | hypothesis = """ 59 | 1 60 | 0:00:00.500 --> 0:00:01.500 61 | This is a subtitle.""" 62 | self._run_test(hypothesis, self._reference1, expected_score=0.0) 63 | 64 | def test_split_subtitle(self): 65 | hypothesis = """ 66 | 1 67 | 0:00:01.000 --> 0:00:01.500 68 | This is 69 | 70 | 2 71 | 0:00:01.500 --> 0:00:02.000 72 | a subtitle.""" 73 | # 1 break insertion / (4 words + 1 break) 74 | self._run_test(hypothesis, self._reference1, expected_score=20.0) 75 | 76 | def test_split_subtitle_with_shift(self): 77 | hypothesis = """ 78 | 1 79 | 0:00:00.100 --> 0:00:01.500 80 | This is 81 | 82 | 2 83 | 0:00:01.500 --> 0:00:02.000 84 | subtitle. 
a""" 85 | # (1 shift + 1 break insertion) / (4 words + 1 break) 86 | self._run_test(hypothesis, self._reference1, expected_score=40.0) 87 | 88 | def test_split_subtitle_no_overlap(self): 89 | hypothesis = """ 90 | 1 91 | 0:00:00.000 --> 0:00:00.500 92 | This is 93 | 94 | 2 95 | 0:00:02.500 --> 0:00:03.000 96 | a subtitle.""" 97 | # (4 word insertions + 2 break insertions + 4 word deletions + 1 break deletion) / (4 words + 1 break) 98 | self._run_test(hypothesis, self._reference1, expected_score=220.0) 99 | 100 | def test_split_subtitle_one_overlap(self): 101 | hypothesis = """ 102 | 1 103 | 0:00:00.000 --> 0:00:00.500 104 | This is 105 | 106 | 2 107 | 0:00:01.500 --> 0:00:02.000 108 | a subtitle.""" 109 | # (2 word insertions + 1 break insertions + 2 word deletions) / (4 words + 1 break) 110 | self._run_test(hypothesis, self._reference1, expected_score=100.0) 111 | 112 | def test_merged_subtitle(self): 113 | hypothesis = """ 114 | 1 115 | 0:00:01.000 --> 0:00:04.000 116 | This is a subtitle. 117 | And another one!""" 118 | # 1 break substitution / (7 words + 2 breaks) 119 | self._run_test(hypothesis, self._reference2, expected_score=11.111) 120 | 121 | def test_merged_subtitle_with_shift(self): 122 | hypothesis = """ 123 | 1 124 | 0:00:01.000 --> 0:00:04.000 125 | This is a another one! subtitle. 126 | And""" 127 | # (1 shift + 1 break substitution) / (7 words + 2 breaks) 128 | self._run_test(hypothesis, self._reference2, expected_score=22.222) 129 | 130 | def test_split_into_three(self): 131 | hypothesis = """ 132 | 1 133 | 0:00:01.000 --> 0:00:01.500 134 | This is a 135 | 136 | 2 137 | 0:00:01.500 --> 0:00:03.500 138 | subtitle. 139 | And 140 | 141 | 2 142 | 0:00:03.500 --> 0:00:04.000 143 | another one!""" 144 | # (2 break insertions + 1 break substitution) / (7 words + 2 breaks) 145 | self._run_test(hypothesis, self._reference2, expected_score=33.333) 146 | 147 | def test_split_into_three_with_one_shift(self): 148 | hypothesis = """ 149 | 1 150 | 0:00:01.000 --> 0:00:01.500 151 | This is a 152 | 153 | 2 154 | 0:00:01.500 --> 0:00:03.500 155 | another 156 | subtitle. 157 | 158 | 2 159 | 0:00:03.500 --> 0:00:04.000 160 | And one!""" 161 | # (1 shift + 2 break insertions) / (7 words + 2 breaks) 162 | self._run_test(hypothesis, self._reference2, expected_score=33.333) 163 | 164 | 165 | class SubERCasedMetricTests(unittest.TestCase): 166 | def test_SubER_cased(self): 167 | reference = """ 168 | 1 169 | 0:00:01.000 --> 0:00:02.000 170 | This is a subtitle. 171 | 172 | 2 173 | 0:00:03.000 --> 0:00:04.000 174 | And another one!""" 175 | 176 | hypothesis = """ 177 | 1 178 | 0:00:01.000 --> 0:00:01.500 179 | This is a 180 | 181 | 2 182 | 0:00:01.500 --> 0:00:03.500 183 | another 184 | subtitle, 185 | 186 | 2 187 | 0:00:03.500 --> 0:00:04.000 188 | and one!""" 189 | 190 | hypothesis_subtitles = create_temporary_file_and_read_it(hypothesis) 191 | reference_subtitles = create_temporary_file_and_read_it(reference) 192 | 193 | SubER_score = calculate_SubER(hypothesis_subtitles, reference_subtitles, metric="SubER-cased") 194 | 195 | # After tokenization there should be 9 reference words + 2 reference break tokens. 
196 |         # 1 shift and 2 break insertions as above for SubER, plus 2 substitutions: "," -> "."; "and" -> "And". That is 5 edits / 11 reference tokens = 45.455%.
197 |         self.assertAlmostEqual(SubER_score, 45.455)
198 | 
199 | 
200 | class SubERHelperFunctionTests(unittest.TestCase):
201 | 
202 |     def test_get_independent_parts_empty_input(self):
203 |         parts = list(_get_independent_parts(hypothesis=[], reference=[]))
204 |         self.assertFalse(parts)
205 | 
206 |     def test_get_independent_parts_only_hypothesis(self):
207 |         hypothesis = [
208 |             Subtitle(word_list=[], index=1, start_time=0, end_time=1),
209 |             Subtitle(word_list=[], index=2, start_time=1, end_time=2),
210 |             Subtitle(word_list=[], index=3, start_time=3, end_time=4)]
211 | 
212 |         parts = list(_get_independent_parts(hypothesis=hypothesis, reference=[]))
213 |         self.assertEqual(len(parts), 3)
214 |         self.assertEqual(parts[0], ([hypothesis[0]], []))
215 |         self.assertEqual(parts[1], ([hypothesis[1]], []))
216 |         self.assertEqual(parts[2], ([hypothesis[2]], []))
217 | 
218 |     def test_get_independent_parts_only_reference(self):
219 |         reference = [
220 |             Subtitle(word_list=[], index=1, start_time=0, end_time=1),
221 |             Subtitle(word_list=[], index=2, start_time=1, end_time=2),
222 |             Subtitle(word_list=[], index=3, start_time=3, end_time=4)]
223 | 
224 |         parts = list(_get_independent_parts(hypothesis=[], reference=reference))
225 |         self.assertEqual(len(parts), 3)
226 |         self.assertEqual(parts[0], ([], [reference[0]]))
227 |         self.assertEqual(parts[1], ([], [reference[1]]))
228 |         self.assertEqual(parts[2], ([], [reference[2]]))
229 | 
230 |     def test_get_independent_parts_all_overlaps(self):
231 |         hypothesis = [
232 |             Subtitle(word_list=[], index=1, start_time=0, end_time=1),
233 |             Subtitle(word_list=[], index=2, start_time=1, end_time=2),
234 |             Subtitle(word_list=[], index=3, start_time=3, end_time=4)]
235 | 
236 |         parts = list(_get_independent_parts(hypothesis=hypothesis, reference=hypothesis))
237 |         self.assertEqual(len(parts), 3)
238 |         self.assertEqual(parts[0], ([hypothesis[0]], [hypothesis[0]]))
239 |         self.assertEqual(parts[1], ([hypothesis[1]], [hypothesis[1]]))
240 |         self.assertEqual(parts[2], ([hypothesis[2]], [hypothesis[2]]))
241 | 
242 |     def test_get_independent_parts_overlap_with_one_big(self):
243 |         hypothesis = [
244 |             Subtitle(word_list=[], index=1, start_time=0.25, end_time=1),
245 |             Subtitle(word_list=[], index=2, start_time=1, end_time=2),
246 |             Subtitle(word_list=[], index=3, start_time=3, end_time=4)]
247 | 
248 |         reference = [
249 |             Subtitle(word_list=[], index=1, start_time=0.5, end_time=3.5)]
250 | 
251 |         parts = list(_get_independent_parts(hypothesis=hypothesis, reference=reference))
252 |         self.assertEqual(len(parts), 1)
253 |         self.assertEqual(parts[0], (hypothesis, reference))
254 | 
255 |         reference = [
256 |             Subtitle(word_list=[], index=1, start_time=0, end_time=4.5)]
257 | 
258 |         parts = list(_get_independent_parts(hypothesis=hypothesis, reference=reference))
259 |         self.assertEqual(len(parts), 1)
260 |         self.assertEqual(parts[0], (hypothesis, reference))
261 | 
262 |     def test_get_independent_parts(self):
263 |         hypothesis = [
264 |             Subtitle(word_list=[], index=1, start_time=0, end_time=0.25),
265 |             Subtitle(word_list=[], index=2, start_time=0.25, end_time=0.5),
266 |             Subtitle(word_list=[], index=3, start_time=0.5, end_time=1),
267 |             Subtitle(word_list=[], index=4, start_time=1, end_time=1.5),
268 |             Subtitle(word_list=[], index=5, start_time=1.75, end_time=2),
269 |             Subtitle(word_list=[], index=6, start_time=4.1, end_time=4.9),
270 |             Subtitle(word_list=[], index=7, start_time=5.1, end_time=6), 
271 |             Subtitle(word_list=[], index=8, start_time=6, end_time=7),
272 |             Subtitle(word_list=[], index=9, start_time=7, end_time=8)]
273 | 
274 |         reference = [
275 |             Subtitle(word_list=[], index=1, start_time=0.75, end_time=1.1),
276 |             Subtitle(word_list=[], index=2, start_time=1.4, end_time=2.2),
277 |             Subtitle(word_list=[], index=3, start_time=3, end_time=3.5),
278 |             Subtitle(word_list=[], index=4, start_time=3.5, end_time=4),
279 |             Subtitle(word_list=[], index=5, start_time=4, end_time=5),
280 |             Subtitle(word_list=[], index=6, start_time=6, end_time=6.5),
281 |             Subtitle(word_list=[], index=7, start_time=6.5, end_time=7.5),
282 |             Subtitle(word_list=[], index=8, start_time=8, end_time=9)]
283 | 
284 |         parts = list(_get_independent_parts(hypothesis=hypothesis, reference=reference))
285 |         self.assertEqual(len(parts), 9)
286 |         self.assertEqual(parts[0], (hypothesis[0:1], []))
287 |         self.assertEqual(parts[1], (hypothesis[1:2], []))
288 |         self.assertEqual(parts[2], (hypothesis[2:5], reference[:2]))
289 |         self.assertEqual(parts[3], ([], reference[2:3]))
290 |         self.assertEqual(parts[4], ([], reference[3:4]))
291 |         self.assertEqual(parts[5], (hypothesis[5:6], reference[4:5]))
292 |         self.assertEqual(parts[6], (hypothesis[6:7], []))
293 |         self.assertEqual(parts[7], (hypothesis[7:9], reference[5:7]))
294 |         self.assertEqual(parts[8], ([], reference[7:8]))
295 | 
296 | 
297 | if __name__ == '__main__':
298 |     unittest.main()
299 | 
--------------------------------------------------------------------------------
/tests/test_suber_statistics.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | 
3 | from typing import Dict, Any
4 | 
5 | from suber.metrics.suber import calculate_SubER
6 | from suber.metrics.suber_statistics import SubERStatisticsCollector
7 | from .utilities import create_temporary_file_and_read_it
8 | 
9 | 
10 | class SubERStatisticsTests(unittest.TestCase):
11 |     """
12 |     Mostly copied from SubERMetricTests, but now testing the statistics collection.
13 |     """
14 | 
15 |     def setUp(self):
16 |         self._reference1 = """
17 | 1
18 | 0:00:01.000 --> 0:00:02.000
19 | This is a subtitle."""
20 | 
21 |         self._reference2 = """
22 | 1
23 | 0:00:01.000 --> 0:00:02.000
24 | This is a subtitle. 
25 | 26 | 2 27 | 0:00:03.000 --> 0:00:04.000 28 | And another one!""" 29 | 30 | self._statistics_template = { 31 | } 32 | 33 | def _run_test(self, hypothesis, reference, expected_statistics: Dict[str, Any]): 34 | hypothesis_subtitles = create_temporary_file_and_read_it(hypothesis) 35 | reference_subtitles = create_temporary_file_and_read_it(reference) 36 | 37 | statistics_collector = SubERStatisticsCollector() 38 | 39 | _ = calculate_SubER( 40 | hypothesis_subtitles, 41 | reference_subtitles, 42 | statistics_collector=statistics_collector) 43 | 44 | statistics = statistics_collector.get_statistics() 45 | 46 | for key, value in expected_statistics.items(): 47 | self.assertIn(key, statistics) 48 | self.assertEqual(value, statistics[key], msg=f"key: {key}") 49 | 50 | def test_empty(self): 51 | expected_statistics = { 52 | "num_reference_words": 0, 53 | "num_reference_breaks": 0, 54 | "num_shifts": 0, 55 | "num_word_deletions": 0, 56 | "num_break_deletions": 0, 57 | "num_word_insertions": 0, 58 | "num_break_insertions": 0, 59 | "num_word_substitutions": 0, 60 | "num_break_substitutions": 0, 61 | } 62 | 63 | self._run_test(hypothesis="", reference="", expected_statistics=expected_statistics) 64 | 65 | def test_split_subtitle_no_overlap(self): 66 | hypothesis = """ 67 | 1 68 | 0:00:00.000 --> 0:00:00.500 69 | This is 70 | 71 | 2 72 | 0:00:02.500 --> 0:00:03.000 73 | a subtitle.""" 74 | 75 | expected_statistics = { 76 | "num_reference_words": 4, 77 | "num_reference_breaks": 1, 78 | "num_shifts": 0, 79 | "num_word_deletions": 4, 80 | "num_break_deletions": 1, 81 | "num_word_insertions": 4, 82 | "num_break_insertions": 2, 83 | "num_word_substitutions": 0, 84 | "num_break_substitutions": 0, 85 | } 86 | 87 | self._run_test(hypothesis, self._reference1, expected_statistics=expected_statistics) 88 | 89 | def test_split_subtitle_one_overlap(self): 90 | hypothesis = """ 91 | 1 92 | 0:00:00.000 --> 0:00:00.500 93 | This is 94 | 95 | 2 96 | 0:00:01.500 --> 0:00:02.000 97 | a subtitle.""" 98 | 99 | expected_statistics = { 100 | "num_reference_words": 4, 101 | "num_reference_breaks": 1, 102 | "num_shifts": 0, 103 | "num_word_deletions": 2, 104 | "num_break_deletions": 0, 105 | "num_word_insertions": 2, 106 | "num_break_insertions": 1, 107 | "num_word_substitutions": 0, 108 | "num_break_substitutions": 0, 109 | } 110 | 111 | self._run_test(hypothesis, self._reference1, expected_statistics=expected_statistics) 112 | 113 | def test_merged_subtitle_with_shift_and_substitution(self): 114 | hypothesis = """ 115 | 1 116 | 0:00:01.000 --> 0:00:04.000 117 | That is a another one! subtitle. 118 | And""" 119 | 120 | expected_statistics = { 121 | "num_reference_words": 7, 122 | "num_reference_breaks": 2, 123 | "num_shifts": 1, 124 | "num_word_deletions": 0, 125 | "num_break_deletions": 0, 126 | "num_word_insertions": 0, 127 | "num_break_insertions": 0, 128 | "num_word_substitutions": 1, 129 | "num_break_substitutions": 1, 130 | } 131 | self._run_test(hypothesis, self._reference2, expected_statistics=expected_statistics) 132 | 133 | def test_split_into_three_with_one_shift(self): 134 | hypothesis = """ 135 | 1 136 | 0:00:01.000 --> 0:00:01.500 137 | This is a 138 | 139 | 2 140 | 0:00:01.500 --> 0:00:03.500 141 | another 142 | subtitle. 
143 | 
144 | 3
145 | 0:00:03.500 --> 0:00:04.000
146 | And one!"""
147 | 
148 |         expected_statistics = {
149 |             "num_reference_words": 7,
150 |             "num_reference_breaks": 2,
151 |             "num_shifts": 1,
152 |             "num_word_deletions": 0,
153 |             "num_break_deletions": 0,
154 |             "num_word_insertions": 0,
155 |             "num_break_insertions": 2,
156 |             "num_word_substitutions": 0,
157 |             "num_break_substitutions": 0,
158 |         }
159 | 
160 |         self._run_test(hypothesis, self._reference2, expected_statistics=expected_statistics)
161 | 
162 | 
163 | if __name__ == '__main__':
164 |     unittest.main()
165 | 
--------------------------------------------------------------------------------
/tests/test_tools.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import tempfile
3 | import subprocess
4 | 
5 | from suber.constants import END_OF_LINE_SYMBOL, END_OF_BLOCK_SYMBOL
6 | 
7 | 
8 | class ToolsTests(unittest.TestCase):
9 | 
10 |     def test_srt_to_plain(self):
11 |         input_file_content = """
12 | 1
13 | 00:00:00,000 --> 00:00:01,000
14 | This is a simple first frame. This
15 | 
16 | 2
17 | 00:00:01,000 --> 00:00:02,000
18 | is another frame
19 | having two lines."""
20 | 
21 |         with tempfile.NamedTemporaryFile(mode="w", suffix=".srt") as temporary_input_file, \
22 |                 tempfile.NamedTemporaryFile(mode="w+", suffix=".srt") as temporary_output_file:
23 |             temporary_input_file.write(input_file_content)
24 |             temporary_input_file.flush()
25 | 
26 |             subprocess.run(
27 |                 f"python3 -m suber.tools.srt_to_plain "
28 |                 f"--input-file {temporary_input_file.name} --output-file {temporary_output_file.name}".split(),
29 |                 check=True)
30 | 
31 |             output_file_content = temporary_output_file.readlines()
32 | 
33 |             self.assertEqual(len(output_file_content), 2)
34 | 
35 |             self.assertEqual(output_file_content[0].strip(),
36 |                              f"This is a simple first frame. This {END_OF_BLOCK_SYMBOL}")
37 |             self.assertEqual(output_file_content[1].strip(),
38 |                              f"is another frame {END_OF_LINE_SYMBOL} having two lines. {END_OF_BLOCK_SYMBOL}")
39 | 
40 |             # Again, now with sentence segmentation.
41 |             subprocess.run(
42 |                 f"python3 -m suber.tools.srt_to_plain --sentence-segmentation "
43 |                 f"--input-file {temporary_input_file.name} --output-file {temporary_output_file.name}".split(),
44 |                 check=True)
45 | 
46 |             temporary_output_file.seek(0)
47 |             output_file_content = temporary_output_file.readlines()
48 | 
49 |             self.assertEqual(len(output_file_content), 2)
50 | 
51 |             self.assertEqual(output_file_content[0].strip(),
52 |                              "This is a simple first frame.")
53 |             self.assertEqual(output_file_content[1].strip(),
54 |                              f"This {END_OF_BLOCK_SYMBOL} is another frame {END_OF_LINE_SYMBOL} "
55 |                              f"having two lines. {END_OF_BLOCK_SYMBOL}")
56 | 
57 |     def test_levenshtein_align_hyp_to_ref(self):
58 |         reference_file_content = """This is a line.
59 | That is another one.
60 | And a third segment."""
61 | 
62 |         hypothesis_file_content = """This is a lines. That this
63 | is another one. 
And third segment.""" 64 | 65 | with tempfile.NamedTemporaryFile(mode="w", suffix=".srt") as temporary_reference_file, \ 66 | tempfile.NamedTemporaryFile(mode="w", suffix=".srt") as temporary_hypothesis_file, \ 67 | tempfile.NamedTemporaryFile(mode="w+", suffix=".srt") as temporary_output_file: 68 | temporary_reference_file.write(reference_file_content) 69 | temporary_reference_file.flush() 70 | 71 | temporary_hypothesis_file.write(hypothesis_file_content) 72 | temporary_hypothesis_file.flush() 73 | 74 | subprocess.run( 75 | f"python3 -m suber.tools.align_hyp_to_ref --method levenshtein " 76 | f"--hypothesis {temporary_hypothesis_file.name} --reference {temporary_reference_file.name} " 77 | f"--hypothesis-format plain --reference-format plain " 78 | f"--aligned-hypothesis {temporary_output_file.name}".split(), 79 | check=True) 80 | 81 | output_file_content = temporary_output_file.readlines() 82 | 83 | self.assertEqual(len(output_file_content), 3) 84 | self.assertEqual(output_file_content[0].strip(), "This is a lines.") 85 | self.assertEqual(output_file_content[1].strip(), "That this is another one.") 86 | self.assertEqual(output_file_content[2].strip(), "And third segment.") 87 | 88 | def test_time_align_hyp_to_ref(self): 89 | reference_file_content = """ 90 | 1 91 | 00:00:01,000 --> 00:00:02,000 92 | This is a simple first frame. 93 | 94 | 2 95 | 00:00:03,000 --> 00:00:04,000 96 | This is another frame 97 | having two lines.""" 98 | 99 | hypothesis_file_content = """ 100 | 1 101 | 00:00:00,000 --> 00:00:02,000 102 | This is a simple first frame. 103 | 104 | 2 105 | 00:00:02,500 --> 00:00:03,500 106 | This is another frame 107 | 108 | 3 109 | 00:00:03,500 --> 00:00:04,500 110 | having two lines.""" 111 | 112 | with tempfile.NamedTemporaryFile(mode="w", suffix=".srt") as temporary_reference_file, \ 113 | tempfile.NamedTemporaryFile(mode="w", suffix=".srt") as temporary_hypothesis_file, \ 114 | tempfile.NamedTemporaryFile(mode="w+", suffix=".srt") as temporary_output_file: 115 | temporary_reference_file.write(reference_file_content) 116 | temporary_reference_file.flush() 117 | 118 | temporary_hypothesis_file.write(hypothesis_file_content) 119 | temporary_hypothesis_file.flush() 120 | 121 | subprocess.run( 122 | f"python3 -m suber.tools.align_hyp_to_ref --method time " 123 | f"--hypothesis {temporary_hypothesis_file.name} --reference {temporary_reference_file.name} " 124 | f"--hypothesis-format SRT --reference-format SRT " 125 | f"--aligned-hypothesis {temporary_output_file.name}".split(), 126 | check=True) 127 | 128 | output_file_content = temporary_output_file.readlines() 129 | 130 | self.assertEqual(len(output_file_content), 2) 131 | self.assertEqual(output_file_content[0].strip(), "simple first frame.") 132 | self.assertEqual(output_file_content[1].strip(), "another frame having") 133 | 134 | 135 | if __name__ == '__main__': 136 | unittest.main() 137 | -------------------------------------------------------------------------------- /tests/utilities.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | from suber.file_readers import PlainFileReader, SRTFileReader 3 | 4 | 5 | def create_temporary_file_and_read_it(file_content, file_format="SRT"): 6 | with tempfile.NamedTemporaryFile(mode="w", suffix=".srt") as temporary_file: 7 | temporary_file.write(file_content) 8 | temporary_file.flush() 9 | 10 | if file_format == "SRT": 11 | file_reader = SRTFileReader(temporary_file.name) 12 | elif file_format == "plain": 13 | file_reader = 
PlainFileReader(temporary_file.name) 14 | else: 15 | raise ValueError(f"Invalid file format '{file_format}'") 16 | 17 | segments = file_reader.read() 18 | 19 | return segments 20 | --------------------------------------------------------------------------------
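For readers who want to try the metric outside of the unittest framework, here is a minimal usage sketch (not part of the repository). It reuses the create_temporary_file_and_read_it helper from tests/utilities.py above and reproduces the split-subtitle example from test_split_subtitle in tests/test_suber_metric.py; it assumes it is run from the repository root so that both the suber and tests packages are importable.

from suber.metrics.suber import calculate_SubER
from tests.utilities import create_temporary_file_and_read_it

# Parse SRT content into subtitle objects, exactly as the tests do.
reference = create_temporary_file_and_read_it("""
1
0:00:01.000 --> 0:00:02.000
This is a subtitle.""")

hypothesis = create_temporary_file_and_read_it("""
1
0:00:01.000 --> 0:00:01.500
This is

2
0:00:01.500 --> 0:00:02.000
a subtitle.""")

# 1 break insertion / (4 words + 1 break), as in test_split_subtitle above.
print(calculate_SubER(hypothesis, reference))  # expected: 20.0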