├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── mt_metrics_eval
│   ├── __init__.py
│   ├── codalab
│   │   ├── eval.py
│   │   └── metadata
│   ├── converters
│   │   ├── __init__.py
│   │   ├── evalset_ratings_to_standalone.py
│   │   ├── score_mqm.py
│   │   ├── standalone_ratings_to_evalset.py
│   │   └── verify_scores_file.py
│   ├── data.py
│   ├── data_test.py
│   ├── meta_info.py
│   ├── mt_metrics_eval.ipynb
│   ├── mtme.py
│   ├── pce.py
│   ├── pce_test.py
│   ├── ratings.py
│   ├── ratings_test.py
│   ├── standalone_ratings.py
│   ├── stats.py
│   ├── stats_test.py
│   ├── tasks.py
│   ├── tasks_test.py
│   ├── tau_optimization.py
│   ├── tau_optimization_test.py
│   ├── ties_matter.ipynb
│   ├── wmt22_metrics.ipynb
│   └── wmt23_metrics.ipynb
└── setup.py
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows [Google's Open Source Community
28 | Guidelines](https://opensource.google/conduct/).
29 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MT Metrics Eval V2
2 |
3 | MTME is a simple toolkit to evaluate the performance of Machine Translation
4 | metrics on standard test sets such as those from the
5 | [WMT Metrics Shared Tasks](https://wmt-metrics-task.github.io).
6 | It bundles data relevant to metric development and evaluation for a
7 | given test set and language pair, and lets you do the following:
8 |
9 | - Access source, reference, and MT output text, along with associated
10 |   meta-info, for the WMT metrics tasks from 2019 onward. This can be done
11 |   via the API or by directly accessing the files, which are laid out in a
12 |   straightforward directory structure.
13 | - Access human and automatic metric scores for the above data, and MQM ratings
14 | for some language pairs.
15 | - Reproduce the official results from the WMT metrics tasks. From WMT22
16 |   onward there are colabs to do this; earlier years require more work.
17 | - Compute various correlations and perform significance tests on correlation
18 | differences between two metrics.
19 |
20 | All of these can be done from the command line using a Python script, or
21 | through the API.
22 |
23 | ## Installation
24 |
25 | You need Python 3.10 or later. To install:
26 |
27 | ```bash
28 | git clone https://github.com/google-research/mt-metrics-eval.git
29 | cd mt-metrics-eval
30 | pip install .
31 | ```
32 |
33 | ## Downloading the data
34 |
35 | This must be done before using the toolkit. You can either use the mtme script:
36 |
37 | ```bash
38 | alias mtme='python3 -m mt_metrics_eval.mtme'
39 | mtme --download # Puts ~2G of data into $HOME/.mt-metrics-eval.
40 | ```
41 |
42 | Or download directly, if you're only interested in the data:
43 |
44 | ```bash
45 | mkdir $HOME/.mt-metrics-eval
46 | cd $HOME/.mt-metrics-eval
47 | wget https://storage.googleapis.com/mt-metrics-eval/mt-metrics-eval-v2.tgz
48 | tar xfz mt-metrics-eval-v2.tgz
49 | ```
50 |
51 | Once data is downloaded, you can optionally test the install:
52 |
53 | ```bash
54 | python3 -m unittest discover mt_metrics_eval "*_test.py" # Takes ~70 seconds.
55 | ```
56 |
57 | ## Running from the command line
58 |
59 | Here are some examples of things you can do with the mtme script. They assume
60 | that the mtme alias above has been set up.
61 |
62 | Get information about test sets:
63 |
64 | ```bash
65 | mtme --list # List available test sets.
66 | mtme --list -t wmt22 # List language pairs for wmt22.
67 | mtme --list -t wmt22 -l en-de # List details for wmt22 en-de.
68 | ```
69 |
70 | Get contents of test sets. Paste doc-id, source, standard reference,
71 | alternative reference to stdout:
72 |
73 | ```bash
74 | mtme -t wmt22 -l en-de --echo doc,src,refA,refB
75 | ```
76 |
77 | Outputs from all systems, sequentially, pasted with doc-ids, source, and
78 | reference:
79 |
80 | ```bash
81 | mtme -t wmt22 -l en-de --echosys doc,src,refA
82 | ```
83 |
84 | Human and metric scores for all systems, at all granularities:
85 |
86 | ```bash
87 | mtme -t wmt22 -l en-de --scores > wmt22.en-de.tsv
88 | ```
89 |
90 | Evaluate metric score files containing tab-separated `system-name score`
91 | entries. For system-level correlations, supply one score per system. For
92 | document-level or segment-level correlations, supply one score per document or
93 | segment, grouped by system, in the same order as text generated using `--echo`
94 | (the same order as the WMT test-set file). Granularity is determined
95 | automatically. Domain-level scores are currently not supported by
96 | this command.
97 |
98 | ```bash
99 | examples=$HOME/.mt-metrics-eval/mt-metrics-eval-v2/wmt22/metric-scores/en-de
100 |
101 | mtme -t wmt22 -l en-de < $examples/BLEU-refA.sys.score
102 | mtme -t wmt22 -l en-de < $examples/BLEU-refA.seg.score
103 | ```
104 |
105 | Compare to WMT appraise gold scores instead of MQM gold scores:
106 |
107 | ```bash
108 | mtme -t wmt22 -l en-de -g wmt-appraise < $examples/BLEU-refA.sys.score
109 | mtme -t wmt22 -l en-de -g wmt-appraise < $examples/BLEU-refA.seg.score
110 | ```
111 |
112 | Compute correlations for two metrics files, and perform tests to determine
113 | whether they are significantly different:
114 |
115 | ```bash
116 | mtme -t wmt22 -l en-de -i $examples/BLEU-refA.sys.score -c $examples/COMET-22-refA.sys.score
117 | ```
118 |
119 | Compare all known metrics under specified conditions. This corresponds to one of
120 | the "tasks" in the WMT22 metrics evaluation. The first output line contains all
121 | relevant parameter settings, and subsequent lines show metrics in descending
122 | order of performance, followed by the rank of their significance cluster, the
123 | value of the selected correlation statistic, and a vector of flags to indicate
124 | significant differences with lower-ranked metrics. These examples use k_block=5
125 | for demo purposes; using k_block=100 will approximately match official results
126 | but can take minutes to hours to complete, depending on the task.
127 |
128 | ```bash
129 | # System-level Pearson
130 | mtme -t wmt22 -l en-de --matrix --k_block 5
131 |
132 | # System-level paired-rank accuracy, pooling results across all MQM languages
133 | mtme -t wmt22 -l en-de,zh-en,en-ru --matrix \
134 | --matrix_corr accuracy --k_block 5
135 |
136 | # Segment-level item-wise averaged Kendall-Tau-Acc23 with optimal tie threshold
137 | # using sampling rate of 1.0 (disabling significance testing for demo).
138 | mtme -t wmt22 -l en-de --matrix --matrix_level seg --avg item \
139 | --matrix_corr KendallWithTiesOpt --matrix_perm_test pairs \
140 | --matrix_corr_args "{'variant':'acc23', 'sample_rate':1.0}" --k 0
141 | ```
142 |
143 | ## API and Colabs
144 |
145 | The colab notebook `mt_metrics_eval.ipynb` contains examples that show how to
146 | use the API to load and summarize data, and compare stored metrics (ones that
147 | participated in the metrics shared tasks) using different criteria. It also
148 | demonstrates how you can incorporate new metrics into these comparisons.
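
For example, here is a minimal sketch of loading an EvalSet and correlating a
stored metric against the standard human scores, following the patterns in
`mt_metrics_eval.ipynb`. The metric name `COMET-22-refA` and the exact method
signatures (`Scores`, `StdHumanScoreName`, `Correlation`, `Pearson`) are
assumptions here; check the notebook for the authoritative usage.

```python
from mt_metrics_eval import data

# Load the WMT22 en-de EvalSet (requires the data download described above).
evs = data.EvalSet('wmt22', 'en-de')
print(evs.sys_names)

# System-level gold and stored metric scores (assumed to be dicts keyed by system).
gold = evs.Scores('sys', evs.StdHumanScoreName('sys'))
metric = evs.Scores('sys', 'COMET-22-refA')

# Correlate over the systems that have both gold and metric scores.
common_systems = sorted(set(gold) & set(metric))
corr = evs.Correlation(gold, metric, common_systems)
print(corr.Pearson())
```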
149 |
150 | The notebooks `wmt22_metrics.ipynb` and `wmt23_metrics.ipynb` document how the
151 | official results for these tasks were generated.
152 | We will try to provide similar notebooks for future evaluations.
153 |
154 | The notebook `ties_matter.ipynb` contains the code to reproduce the results
155 | from [Ties Matter: Meta-Evaluating Modern Metrics with Pairwise Accuracy and Tie Calibration](https://arxiv.org/abs/2305.14324).
156 | It also contains examples for how to calculate the proposed pairwise accuracy
157 | with tie calibration.
158 |
159 | ## MQM Ratings
160 |
161 | MTME also supports representing MQM ratings.
162 | The ratings are stored as `ratings.Rating` objects in the `EvalSet`.
163 | They can be accessed via the `EvalSet.Ratings()` function.
164 | `Ratings()` returns a dictionary that maps the name of a set of ratings to
165 | the ratings themselves, one per segment.
166 | Each entry can represent one of the following:
167 |
168 | - An individual rater's ratings, in which the key is the ID of the rater
169 | - A metric's ratings, in which the key is the ID of the system that predicted the rating
170 | - A combined set of ratings that come from different raters, in which the key
171 | is the name for this group of ratings. This could be used if there was a logical
172 | "round" of ratings from different raters, like a full round of ratings collected
173 | as part of a WMT evaluation.
174 |
175 | The IDs of the raters who rated each segment can be accessed via
176 | `EvalSet.RaterIdsPerSeg()`. It returns a dict, parallel to an entry in
177 | `EvalSet.Ratings()`, that lists the individual rater ID for each rating, or
178 | `None` if there was no rating.
179 | For an individual rater's ratings or a metric's ratings, the IDs are typically
180 | just that rater's ID or the name of the metric. For a combined set of ratings,
181 | they are the per-segment rater IDs.
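
Below is a small sketch of accessing ratings from the API. Depending on the
toolkit version, the `EvalSet` constructor may need an extra argument to load
stored ratings, and `RaterIdsPerSeg()` may take the collection name as an
argument, so treat the exact calls here as assumptions.

```python
from mt_metrics_eval import data

evs = data.EvalSet('wmt23', 'en-de')

# Names of the available rating collections, e.g. 'mqm.rater1' ... 'mqm.merged'.
print(list(evs.Ratings()))

# Rater IDs behind the ratings, parallel to an entry of Ratings();
# None marks segments that were not rated.
print(evs.RaterIdsPerSeg())
```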
182 |
183 | For each year of WMT for which ratings are included in MTME, there is a rating
184 | entry for each individual rater. If there was a logical grouping of ratings,
185 | like a round of ratings that were collected at the same time, those are also
186 | included.
187 | Here are the ratings that are currently available:
188 |
189 | | Dataset | Language Pair | Ratings |
190 | | ------- | ------------- | ------- |
191 | | wmt20 | en-de | "mqm.rater1"-"mqm.rater6": The individual raters' ratings. Each segment was rated up to 3 times, and there is no clear definition of a round of ratings, so no combined set of ratings is included. |
192 | | wmt20 | zh-en | "mqm.rater1"-"mqm.rater6": The individual raters' ratings. Each segment was rated up to 3 times, and there is no clear definition of a round of ratings, so no combined set of ratings is included. |
193 | | wmt21.news | en-de | "mqm.rater1"-"mqm.rater14": The individual raters' ratings<br>"mqm.merged": The combined ratings of rater1-14 that were used in the WMT evaluation |
194 | | wmt21.news | zh-en | "mqm.rater1"-"mqm.rater9": The individual raters' ratings<br>"mqm.merged": The combined ratings of rater1-9 that were used in the WMT evaluation |
195 | | wmt21.tedtalks | en-de | "mqm.rater1"-"mqm.rater4": The individual raters' ratings<br>"mqm.merged": The combined ratings of rater1-4 that were used in the WMT evaluation |
196 | | wmt21.tedtalks | zh-en | "mqm.rater1"-"mqm.rater9": The individual raters' ratings<br>"mqm.merged": The combined ratings of rater1-9 that were used in the WMT evaluation |
197 | | wmt22 | en-de | "mqm.rater1"-"mqm.rater7": The individual raters' ratings (from all rounds; see below)<br>"mqm.merged": The combined ratings of rater1-7 that were used in the WMT evaluation<br>"round2.mqm.merged": A second round of ratings collected from rater1-7 (not part of the WMT evaluation)<br>"round3.mqm.merged": A third round of ratings collected from rater1-7 (not part of the WMT evaluation) |
198 | | wmt22 | en-ru | "mqm.rater1"-"mqm.rater4": The individual raters' ratings<br>"mqm.merged": The combined ratings of rater1-4 that were used in the WMT evaluation |
199 | | wmt22 | zh-en | "mqm.rater1"-"mqm.rater12": The individual raters' ratings<br>"mqm.merged": The combined ratings of rater1-12 that were used in the WMT evaluation |
200 | | wmt23 | en-de | "mqm.rater1"-"mqm.rater10": The individual raters' ratings<br>"mqm.merged": The combined ratings of rater1-10 that were used in the WMT evaluation |
201 | | wmt23 | zh-en | "mqm.rater1"-"mqm.rater8": The individual raters' ratings. Only a small subset of segments was rated by all of the raters, so there is no clear definition of a round of ratings and no merged set of ratings is included. |
202 |
203 | Note that the ratings might differ slightly from the ratings that were released
204 | as part of the original WMT evaluations. The released data and the translations
205 | in MTME were sometimes different (e.g., punctuation was introduced or removed,
206 | whitespace inserted, etc.), which made it difficult to map the MQM ratings to
207 | character offsets in the MTME translations.
208 | We wrote scripts to fix the ratings so they would match the MTME versions, but
209 | this was sometimes lossy and not always possible, so some ratings might be
210 | different or even missing.
211 | This is less of a problem with more recent WMT years.
212 |
213 |
214 | ## WMT24++ Data
215 |
216 | This package also contains the system outputs and metric scores that were
217 | collected as part of the paper [WMT24++: Expanding the Language Coverage of WMT24 to 55 Languages & Dialects](https://arxiv.org/abs/2502.12404).
218 | The data can be accessed via the EvalSet for `"wmt24pp"` (e.g., `EvalSet("wmt24pp", "en-de_DE")`).
219 | Due to restrictions, we are unable to release the outputs from all of the MT
220 | systems and LLMs that were reported in the paper.
221 | The set of systems that were released can be accessed via the `sys_names` property
222 | of the EvalSet.
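
For example (a minimal sketch, reusing the language-pair code shown above):

```python
from mt_metrics_eval import data

# WMT24++ outputs and metric scores; note the region-tagged language pair.
evs = data.EvalSet('wmt24pp', 'en-de_DE')

# Only the systems whose outputs could be released are listed here.
print(evs.sys_names)
```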
223 |
224 | If you use this data, please cite
225 |
226 | ```
227 | @misc{wmt24pp,
228 | title={{WMT24++: Expanding the Language Coverage of WMT24 to 55 Languages & Dialects}},
229 | author={Daniel Deutsch and Eleftheria Briakou and Isaac Caswell and Mara Finkelstein and Rebecca Galor and Juraj Juraska and Geza Kovacs and Alison Lui and Ricardo Rei and Jason Riesa and Shruti Rijhwani and Parker Riley and Elizabeth Salesky and Firas Trabelsi and Stephanie Winkler and Biao Zhang and Markus Freitag},
230 | year={2025},
231 | eprint={2502.12404},
232 | archivePrefix={arXiv},
233 | primaryClass={cs.CL},
234 | url={https://arxiv.org/abs/2502.12404},
235 | }
236 | ```
237 |
238 |
239 | ## Conversion scripts
240 |
241 | The `converters` module contains scripts to convert between different formats
242 | for ratings and scores.
243 |
244 | For example, to convert MQM annotations from [Google's tsv annotation format](
245 | https://github.com/google/wmt-mqm-human-evaluation) into scores:
246 |
247 | ```bash
248 | git clone https://github.com/google/wmt-mqm-human-evaluation
249 | python3 -m mt_metrics_eval.converters.score_mqm \
250 | --weights "major:5 minor:1 No-error:0 minor/Fluency/Punctuation:0.1" \
251 | < wmt-mqm-human-evaluation/generalMT2022/ende/mqm_generalMT2022_ende.tsv \
252 | > mqm.ende.seg.score
253 | ```
254 |
255 | To convert MTME-format MQM annotations into standalone json files that bundle
256 | all relevant information:
257 |
258 | ```bash
259 | python3 -m mt_metrics_eval.converters.evalset_ratings_to_standalone \
260 | --evalset_ratings_files $HOME/.mt-metrics-eval/mt-metrics-eval-v2/wmt23/human-scores/en-de.mqm.merged.seg.rating \
261 | --language_pair en-de \
262 | --test_set wmt23 \
263 | --ratings_file en-de.mqm.standalone.jsonl
264 | ```
265 |
266 | ## File organization and naming convention
267 |
268 | ### Overview
269 |
270 | There is one top-level directory for each test set (e.g. `wmt22`).
271 | Each top-level directory contains the following sub-directories (whose contents
272 | should be obvious from their names):
273 | `documents`, `human-scores`, `metric-scores`, `references`, `sources`, and
274 | `system-outputs`.
275 |
276 | In general, a test set contains data from many language pairs. Each combination
277 | of test set and language pair (e.g., wmt22 + de-en) is called an **EvalSet**. This
278 | is the main unit of computation in the toolkit. Each EvalSet consists of a
279 | source text (divided into one or more documents, optionally with domain
280 | membership), reference translations, system outputs to be scored, human gold
281 | scores, and metric scores.
282 |
283 | Meta information is encoded into directory and file names as specified below.
284 | The convention is intended to be straightforward, but there are a few
285 | subtleties:
286 |
287 | - Reference translations can be scored as system outputs. When this is the case,
288 | **the reference files should be copied into the system-outputs directory with
289 | matching names**. For example:
290 | ```
291 | references/de-en.refb.txt → system-outputs/de-en/refb.txt
292 | ```
293 | - Metrics can come in different variants according to which reference(s) they
294 | used. This information is encoded into their filenames. To facilitate parsing,
295 | reference names can't contain dashes or dots, as outlined below.
296 | - Metric files must contain scores for all files in the system output directory,
297 | except those that were used as references.
298 | - Human score files don’t have to contain entries for all systems, or even for
299 | all segments for a given system. Missing entries are marked with ‘None’ strings.
300 |
301 | ### Specification
302 |
303 | The filename format and content specification for each kind of file are
304 | described below. Paths are relative to the top-level directory corresponding to
305 | a test set, e.g. wmt20. SRC and TGT designate abbreviations for the
306 | source and target language, e.g. ‘en’. Blanks designate any amount of
307 | whitespace.
308 |
309 | - source text:
310 | - filename: `sources/SRC-TGT.txt`
311 | - per-line contents: text segment
312 | - document meta-info:
313 | - filename: `documents/SRC-TGT.docs`
314 | - per-line contents: DOMAIN DOCNAME
315 | - lines match those in the source file
316 | - documents are assumed to be contiguous blocks of segments
317 | - DOMAIN tags can be repurposed for categories other than domain, but
318 | each document must belong to only one category
319 | - references:
320 | - filename: `references/SRC-TGT.NAME.txt`
321 | - NAME is the name of this reference, e.g. `refb`. Names cannot be the
322 | reserved strings `all` or `src`, or contain `.` or `-` characters.
323 | - per-line contents: text segment
324 | - lines match those in the source file
325 | - system outputs:
326 | - filename: `system-outputs/SRC-TGT/NAME.txt`
327 | - NAME is the name of an MT system or reference
328 | - per-line contents: text segment
329 | - lines match those in the source file
330 | - human scores:
331 | - filename: `human-scores/SRC-TGT.NAME.LEVEL.score`
332 | - NAME describes the scoring method, e.g. `mqm` or `wmt-z`.
333 | - LEVEL indicates the granularity of the scores, one of `sys`, `domain`,
334 | `doc`, or `seg`.
335 | - per-line contents: [DOMAIN] SYSNAME SCORE
336 | - DOMAIN is present only if granularity is `domain`
337 | - SYSNAME must match a NAME in system outputs
338 | - SCORE may be `None` to indicate a missing score
339 | - System-level (`sys`) files contain exactly one score per system.
340 | - Domain-level (`domain`) files contain one score per domain and system.
341 | - Document-level (`doc`) files contain a block of scores for each system.
342 | Each block contains the scores for successive documents, in the same order
343 | they occur in the document info file.
344 | - Segment-level (`seg`) files contain a block of scores for each system.
345 | Each block contains the scores for all segments in the system output file,
346 | in order.
347 | - human MQM ratings:
348 | - filename: `human-scores/SRC-TGT.RATING_NAME.seg.rating`
349 |     - RATING_NAME names the collection of ratings. This can be the name of an individual rater, or a name like "mqm.merged", which means multiple raters' ratings have been merged into a single collection.
350 | - per-line contents: SYSNAME RATING [RATER_ID]
351 | - SYSNAME must match a NAME in system outputs
352 |       - RATING is a JSON-serialized `ratings.Rating` object, or "None" if there is no rating for the given segment.
353 | - RATER_ID (optional) marks which rater did the rating. If not provided, RATING_NAME is used.
354 | - metric scores:
355 | - filename `metric-scores/SRC-TGT/NAME-REF.LEVEL.score`
356 | - NAME is the metric’s base name.
357 | - REF describes the reference(s) used for this version of the metric,
358 | either:
359 | - A list of one or more names separated by `.`, eg `refa` or
360 | `refa.refb`.
361 | - The special string `src` to indicate that no reference was used.
362 | - The special string `all` to indicate that all references were used.
363 | - LEVEL indicates the granularity of the scores, one of `sys`, `domain`,
364 | `doc`, or `seg`.
365 | - per-line contents: [DOMAIN] SYSNAME SCORE
366 | - Format is identical to human scores, except that `None` entries aren't
367 | permitted.
368 |
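As an illustration of the naming convention, here is a hypothetical helper
(not part of the toolkit) that splits a metric-score filename into its metric
name, reference spec, and granularity, relying on the rule that reference
names contain no `-` or `.` characters:

```python
# Hypothetical helper, not part of mt_metrics_eval: parse a filename such as
# 'COMET-22-refA.sys.score' into (metric name, reference spec, granularity).
def parse_metric_filename(filename: str) -> tuple[str, str, str]:
  stem, level, ext = filename.rsplit('.', 2)
  if ext != 'score' or level not in {'sys', 'domain', 'doc', 'seg'}:
    raise ValueError(f'Not a metric-score filename: {filename}')
  # Reference names contain no '-', so the last '-'-separated token of the
  # stem is the reference spec ('refA', 'refa.refb', 'src', or 'all').
  name, ref = stem.rsplit('-', 1)
  return name, ref, level

print(parse_metric_filename('COMET-22-refA.sys.score'))     # ('COMET-22', 'refA', 'sys')
print(parse_metric_filename('Metric-refa.refb.seg.score'))  # ('Metric', 'refa.refb', 'seg')
```
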
369 | ## Credits
370 |
371 | Inspired by and loosely modeled on
372 | [SacreBLEU](https://github.com/mjpost/sacrebleu).
373 |
374 | This is not an official Google product.
375 |
--------------------------------------------------------------------------------
/mt_metrics_eval/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
--------------------------------------------------------------------------------
/mt_metrics_eval/codalab/eval.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #!/usr/bin/env python3
16 |
17 | """Script to score codalab submissions to PHASE 1 (open) at WMT Metrics 2023.
18 |
19 | Usage: eval.py INPUT OUTPUT
20 |
21 | INPUT and OUTPUT are directories assumed to be structured as follows:
22 | INPUT/ref - subdirectory containing reference files:
23 | SEG_REF_FILE - segment-level reference scores
24 | SYS_REF_FILE - system-level reference scores
25 | REF_FREE_SEG_REF_FILE - segment-level reference scores for QE metrics
26 | REF_FREE_SYS_REF_FILE - system-level reference scores for QE metrics
27 | INPUT/res - subdirectory containing submitted files matching the pattern:
28 | *SEG_RES_SUFFIX - segment-level metric scores
29 | *SYS_RES_SUFFIX - system-level metric scores
30 | META_FILE - meta info file
31 | OUTPUT/OUT_FILE - where results get written
32 |
33 | The filenames are defined as global variables below. See comments above the
34 | definitions for details about expected formats.
35 |
36 | Assumptions and limitations:
37 |
38 | - Accepts one system-level score file, or one segment-level score file, or
39 | both.
40 | - Score files are identified by the suffixes defined below. The rest of the
41 | filename is ignored; the metric name is read from the file contents.
42 | - Each submission must contain only one metric name, the same name for system-
43 | and segment-level scores if both are provided.
44 | - Each submission may be either reference-free or reference-based. Reference-
45 | free submissions are indicated by 'src' values in the reference field. If a
46 | submission is reference-free, it must be reference-free for all language pairs
47 | and granularities.
48 | - Currently, only entries pertaining to the official language pairs, test-set,
49 | and reference are read and checked in detail. Only basic checks are performed
50 | on other entries for the standard test set. Challenge-set entries are ignored.
51 | - Error checking ensures that the submission is consistent with information
52 | from the reference score files (system names, number of segments, etc). The
53 | program throws an exception when the first error is encountered.
54 | - The segment-level reference score files are assumed to contain scores for all
55 | segments. In general, this will not be the case for actual gold scores; this
56 | program is currently intended for the initial submission phase only.
57 | - Unlike the submission files, reference score files are allowed to have
58 | multiple metric names, in order to accommodate different gold score names for
59 | different language pairs (eg, MQM or DA scores). But there can be at most one
60 | name per LP.
61 | - System-level Pearson correlations between reference and submission scores are
62 | written to the output file. If no system-level scores are provided, these are
63 | derived from averaged segment-level scores. The system-level scores pertain to
64 | the 'all' (whole-test-set) domain only.
65 | - Segment-level Kendall correlations between reference and submission scores are
66 | written to the output file. If no segment-level scores are provided, these are
67 | just 0. The correlations are computed by flattening system x segment scores
68 | into single vectors.
69 | """
70 |
71 | import collections
72 | import dataclasses
73 | import os
74 | import sys
75 | import numpy as np
76 | import scipy.stats
77 |
78 | # Globals
79 |
80 | # Files containing pseudo-reference scores, at segment and system granularity,
81 | # in standard metrics submission format. See https://wmt-metrics-task.github.io
82 | SEG_REF_FILE = 'goldlabels.seg.score'
83 | SYS_REF_FILE = 'goldlabels.sys.score'
84 | REF_FREE_SEG_REF_FILE = 'goldlabels.reffree.seg.score'
85 | REF_FREE_SYS_REF_FILE = 'goldlabels.reffree.sys.score'
86 |
87 | # Suffixes of the files containing metric scores, in standard metrics
88 | # submission format. The prefix is an arbitrary name for the submitted metric.
89 | SEG_RES_SUFFIX = '.seg.score'
90 | SYS_RES_SUFFIX = '.sys.score'
91 |
92 | # Meta-info filename for current submission. If the file does not exist, no
93 | # metadata will be read. Otherwise, it must contain at least the following
94 | # entries (any other entries are ignored):
95 | # team: NAME
96 | # primary: Y[es]|N[o]
97 | META_FILE = 'metadata.txt'
98 |
99 | # Place to write results. Fields written (last two are expected by codalab,
100 | # team and primary only written if META_FILE was found):
101 | # - team: NAME
102 | # - metric_name: NAME
103 | # - ref_less: Y|N
104 | # - primary: Y|N
105 | # - LP_pearson: PEARSON
106 | # - LP_kendalltau: KENDALL
107 | # where LP is one of zhen, heen or ende (note, no hyphen).
108 | OUT_FILE = 'scores.txt'
109 |
110 | # Map official language pairs to standard references.
111 | LANG_PAIR_TO_REF = {
112 | 'en-de': 'refA',
113 | 'he-en': 'refA',
114 | 'zh-en': 'refA'
115 | }
116 |
117 | # Test set we're reading.
118 | TEST_SET = 'generaltest2023'
119 |
120 | # Name for the global domain.
121 | GLOBAL_DOMAIN = 'all'
122 |
123 |
124 | @dataclasses.dataclass
125 | class BasicInfo:
126 | """Collection of basic test set info for a given language pair."""
127 |
128 | # pylint: disable=g-bare-generic
129 | domains: set = dataclasses.field(default_factory=set)
130 | docs: set = dataclasses.field(default_factory=set)
131 | refs: set = dataclasses.field(default_factory=set)
132 | systems: set = dataclasses.field(default_factory=set)
133 | num_segs: int = 0
134 |
135 | def add(self, testset, domain, doc, ref, sysname, segno) -> None:
136 | if testset != TEST_SET:
137 | return
138 | self.domains.add(domain)
139 | self.docs.add(doc)
140 | self.refs.add(ref)
141 | self.systems.add(sysname)
142 | if segno is not None:
143 | self.num_segs = max(int(segno), self.num_segs)
144 |
145 | def check(self, ref, lp) -> bool:
146 | """Check contents against reference info."""
147 | if self.domains != ref.domains:
148 | raise ValueError(
149 | f'{lp} domains don\'t match std: {self.domains} vs {ref.domains}')
150 | if self.docs != ref.docs:
151 | raise ValueError(
152 | f'{lp} documents don\'t match standard: {self.docs} vs {ref.docs}')
153 | if self.refs != ref.refs:
154 | raise ValueError(
155 | f'{lp} references don\'t match standard: {self.refs} vs {ref.refs}')
156 | if self.num_segs != ref.num_segs:
157 | raise ValueError(
158 | f'{lp} segment count doesn\'t match standard: '
159 | f'{self.num_segs} vs {ref.num_segs}')
160 | return True
161 |
162 |
163 | def read_metadata(filename: str, required_keys_only=True):
164 | """Read and check metadata file."""
165 |
166 | metadata = {}
167 | if os.path.exists(filename):
168 | required = {'team', 'primary'}
169 | with open(filename) as f:
170 | for line in f:
171 | line = line.strip()
172 | if not line: continue
173 | k, v = line.split(maxsplit=1)
174 | if k.endswith(':'):
175 | k = k[:-1]
176 | k = k.lower()
177 | if required_keys_only and k not in required:
178 | continue
179 | metadata[k] = v
180 |
181 | missing = [k for k in required if k not in metadata]
182 | if missing:
183 | missing = ', '.join(f'"{k}"' for k in missing)
184 | raise ValueError(f'Missing entries in {META_FILE}: {missing}')
185 |
186 | for k in ['primary']:
187 | if metadata[k].lower() in ['y', 'yes']:
188 | metadata[k] = 'Y'
189 | elif metadata[k].lower() in ['n', 'no']:
190 | metadata[k] = 'N'
191 | else:
192 | raise ValueError(f'Value for "{k}" must be Y or N')
193 |
194 | primary = metadata['primary'] == 'Y'
195 | primary_msg = f'{"" if primary else "non-"}primary submission'
196 | print(f'Read metadata from {META_FILE} - {primary_msg}')
197 | else:
198 | print(f'{META_FILE} not found')
199 |
200 | return metadata
201 |
202 |
203 | def get_result_filenames(res_dir):
204 | """Find and check result file names for this submission."""
205 | submitted_files = os.listdir(res_dir)
206 | seg_level = [f for f in submitted_files if f.endswith(SEG_RES_SUFFIX)]
207 | sys_level = [f for f in submitted_files if f.endswith(SYS_RES_SUFFIX)]
208 | if len(seg_level) > 1 or len(sys_level) > 1:
209 | raise ValueError(
210 | 'Submission has multiple system- or segment-level score files')
211 | elif not seg_level and not sys_level:
212 | raise ValueError(
213 | f'At least one of METRIC{SEG_RES_SUFFIX} or METRIC{SYS_RES_SUFFIX} '
214 | 'must be supplied.')
215 | seg_res_file = seg_level[0] if seg_level else None
216 | sys_res_file = sys_level[0] if sys_level else None
217 | return seg_res_file, sys_res_file
218 |
219 |
220 | def in_scope(lp, ref, sysname, testset):
221 | """Return True if a score-file entry with these attributes is in scope."""
222 | return (lp in LANG_PAIR_TO_REF and
223 | ref in {LANG_PAIR_TO_REF[lp], 'src'} and
224 | ref != sysname and
225 | testset == TEST_SET)
226 |
227 |
228 | def read_seg_scores(filename: str):
229 | """Read and check standard-format segment-level scores."""
230 |
231 | scores = {} # lp -> sys -> seg -> score
232 | metrics = {} # lp -> metric-name, ref
233 | infos = collections.defaultdict(BasicInfo) # lp -> BasicInfo
234 | with open(filename) as f:
235 | for line in f:
236 | fields = line.strip().split('\t')
237 | if len(fields) != 9:
238 | raise ValueError(f'Expecting 9 tab-separated fields: {line}')
239 | metric, lp, testset, domain, doc, ref, sysname, segno, score = fields
240 | infos[lp].add(testset, domain, doc, ref, sysname, segno)
241 | if not in_scope(lp, ref, sysname, testset): continue
242 | if lp not in scores:
243 | scores[lp] = {}
244 | metrics[lp] = (metric, ref)
245 | if metric != metrics[lp][0]:
246 | raise ValueError(f'Multiple metric names provided for {lp}')
247 | if ref != metrics[lp][1]:
248 | raise ValueError(
249 |             f'Metric has both source- and reference-based versions for {lp}')
250 | if sysname not in scores[lp]:
251 | scores[lp][sysname] = {}
252 | segno = int(segno) - 1 # Original numbers are 1-based
253 | if segno in scores[lp][sysname]:
254 | raise ValueError(f'Duplicate segment number in {line}: {segno + 1}')
255 | scores[lp][sysname][segno] = float(score)
256 |
257 | # Convert to np format for convenience
258 | new_scores = {}
259 | for lp, syslist in scores.items():
260 | matrix = [] # system x segment scores
261 | syslist = sorted(syslist)
262 | for sysname in syslist:
263 | num_segs = max(scores[lp][sysname]) + 1
264 | ordered_scores = [None] * num_segs
265 | for i, s in scores[lp][sysname].items():
266 | ordered_scores[i] = s
267 | if None in ordered_scores:
268 | m = ordered_scores.count(None)
269 | raise ValueError(f'Missing {m} segment score(s) for {lp}/{sysname}')
270 | matrix.append(ordered_scores)
271 | if len(matrix[-1]) != len(matrix[0]):
272 | raise ValueError(f'Length mismatch for {lp}/{sysname} segment scores')
273 | new_scores[lp] = metrics[lp], syslist, np.array(matrix)
274 |
275 | # Return lp -> ((metric, ref), syslist, sys_x_seg score matrix)
276 | # NB: syslist corresponds to matrix rows, and is sorted so as to be comparable
277 | # across different score files.
278 | return new_scores, infos
279 |
280 |
281 | def read_sys_scores(filename: str):
282 | """Read and check standard-format system-level scores."""
283 |
284 | scores = {} # lp -> sys -> domain -> score
285 | metrics = {} # lp -> metric name, ref
286 | infos = collections.defaultdict(BasicInfo) # lp -> BasicInfo
287 | with open(filename) as f:
288 | for line in f:
289 | fields = line.strip().split('\t')
290 | if len(fields) != 7:
291 | raise ValueError(f'Expecting 7 tab-separated fields: {line}')
292 | metric, lp, testset, domain, ref, sysname, score = fields
293 | infos[lp].add(testset, domain, None, ref, sysname, None)
294 | if not in_scope(lp, ref, sysname, testset): continue
295 | if lp not in scores:
296 | scores[lp] = {}
297 | metrics[lp] = (metric, ref)
298 | if metric != metrics[lp][0]:
299 | raise ValueError(f'Multiple metric names provided for {lp}')
300 | if ref != metrics[lp][1]:
301 | raise ValueError(
302 | f'Metric has both source- and reference-based versions for {lp}')
303 | if sysname not in scores[lp]:
304 | scores[lp][sysname] = {}
305 | if domain in scores[lp][sysname]:
306 | raise ValueError(f'Duplicate domain in {line}: {domain}')
307 | scores[lp][sysname][domain] = float(score)
308 |
309 | # Convert to np format for convenience
310 | new_scores = {}
311 | for lp, syslist in scores.items():
312 | matrix = [] # system x domain scores
313 | syslist, domainlist = sorted(syslist), None
314 | for sysname in syslist:
315 | if domainlist is None:
316 | domainlist = sorted(scores[lp][sysname])
317 | elif domainlist != sorted(scores[lp][sysname]):
318 | raise ValueError(f'Mismatched domains in {lp}/{sysname}: {domainlist}')
319 | matrix.append([float(scores[lp][sysname][d]) for d in domainlist])
320 | new_scores[lp] = metrics[lp], syslist, domainlist, np.array(matrix)
321 |
322 | # Return lp -> ((metric, ref), syslist, domainlist, sys_x_domain score matrix)
323 | # NB: syslist, domainlist are sorted, so comparable across different score
324 | # files.
325 | return new_scores, infos
326 |
327 |
328 | def check_uniqueness(results_scores):
329 | """Check uniqueness for metric/reference combinations across languages.
330 |
331 | We allow different metrics/refs for different LPs in reference scores (eg,
332 | DA vs MQM), but require only one metric across all LPs in submissions. The
333 |   metric can use different official references for different languages, but
334 |   if it is reference-free (ref = 'src'), it must be reference-free across all LPs.
335 |
336 | Args:
337 | results_scores: Return from read_*_scores(), maps lp -> (metric, ref), ...
338 |
339 | Returns:
340 | metric_name, is_ref_free
341 | """
342 | metrics = set(x[0][0] for x in results_scores.values())
343 | refs = set(x[0][1] for x in results_scores.values())
344 |   if len(metrics) > 1:
345 | raise ValueError(f'Found multiple metrics: {metrics}')
346 | metric = list(metrics)[0]
347 | if 'src' in refs and len(refs) > 1:
348 | raise ValueError(
349 | 'Metric has both source- and reference-based segment-level versions')
350 | return metric, 'src' in refs
351 |
352 |
353 | def check_coverage(results_scores, primary):
354 | """Check language-pair coverage for results."""
355 | if primary:
356 | if not set(LANG_PAIR_TO_REF).issubset(results_scores):
357 | raise ValueError(
358 | f'Primary metrics must provide results for {set(LANG_PAIR_TO_REF)}')
359 | else:
360 | pass # Currently anything goes for non-primary submissions.
361 |
362 |
363 | def main(argv):
364 | _, input_dir, output_dir = argv
365 | ref_dir = os.path.join(input_dir, 'ref')
366 | res_dir = os.path.join(input_dir, 'res')
367 |
368 | def read_ref_scores(level, ref_free):
369 | if level == 'seg' and not ref_free:
370 | return read_seg_scores(os.path.join(ref_dir, SEG_REF_FILE))
371 |     elif level == 'seg' and ref_free:
372 | return read_seg_scores(os.path.join(ref_dir, REF_FREE_SEG_REF_FILE))
373 | elif level == 'sys' and not ref_free:
374 | return read_sys_scores(os.path.join(ref_dir, SYS_REF_FILE))
375 | elif level == 'sys' and ref_free:
376 | return read_sys_scores(os.path.join(ref_dir, REF_FREE_SYS_REF_FILE))
377 | else:
378 | assert False
379 |
380 | def print_summary(metric, ref_free, res_scores, infos):
381 | print(f'- Metric is {metric}, ref-{"free" if ref_free else "based"}')
382 | print(f'- Read scores for official languages: {",".join(res_scores)}')
383 | others = set(infos) - set(res_scores)
384 | print(f'- Read scores for other languages: {others if others else "None"}',
385 | flush=True)
386 |
387 | # Read metadata
388 | metainfo = read_metadata(
389 | os.path.join(res_dir, META_FILE), required_keys_only=True)
390 | primary = 'primary' in metainfo and metainfo['primary'] == 'Y'
391 |
392 | seg_ref_scores, sys_ref_scores = None, None
393 |
394 | # Read and check submission files.
395 | #
396 | seg_res_file, sys_res_file = get_result_filenames(res_dir)
397 | #
398 | seg_metric, seg_res_scores, seg_ref_free = None, None, None
399 | sys_metric, sys_res_scores, sys_ref_free = None, None, None
400 | if seg_res_file:
401 | print(f'Reading and checking {seg_res_file}:', flush=True)
402 | seg_res_scores, seg_infos = read_seg_scores(
403 | os.path.join(res_dir, seg_res_file))
404 | seg_metric, seg_ref_free = check_uniqueness(seg_res_scores)
405 | check_coverage(seg_res_scores, primary)
406 | seg_ref_scores, seg_ref_infos = read_ref_scores('seg', seg_ref_free)
407 | sys_ref_scores, sys_ref_infos = read_ref_scores('sys', seg_ref_free)
408 | for lp, (_, syslist, matrix) in seg_res_scores.items():
409 | if syslist != seg_ref_scores[lp][1]:
410 | raise ValueError(f'System list for {lp} doesn\'t match reference: '
411 | f'{syslist} vs {seg_ref_scores[lp][1]}')
412 | num_segs = matrix.shape[1]
413 |       if num_segs != seg_ref_scores[lp][2].shape[1]:
414 |         raise ValueError(f'Num segments for {lp} doesn\'t match reference '
415 |                          f'{num_segs} vs {seg_ref_scores[lp][2].shape[1]}')
416 | for lp, info in seg_infos.items():
417 | if lp not in seg_ref_infos:
418 | raise ValueError(f'Unknown segment-level language pair: {lp}')
419 | info.check(seg_ref_infos[lp], lp)
420 | print_summary(seg_metric, seg_ref_free, seg_res_scores, seg_infos)
421 | #
422 | if sys_res_file:
423 | print(f'Reading and checking {sys_res_file}:', flush=True)
424 | sys_res_scores, sys_infos = read_sys_scores(
425 | os.path.join(res_dir, sys_res_file))
426 | sys_metric, sys_ref_free = check_uniqueness(sys_res_scores)
427 | check_coverage(sys_res_scores, primary)
428 | sys_ref_scores, sys_ref_infos = read_ref_scores('sys', sys_ref_free)
429 | if seg_ref_scores is None:
430 | seg_ref_scores, seg_ref_infos = read_ref_scores('seg', sys_ref_free)
431 | for _, _, domainlist, _ in sys_ref_scores.values():
432 | assert GLOBAL_DOMAIN in domainlist
433 | for lp, (_, syslist, domainlist, _) in sys_res_scores.items():
434 | if syslist != sys_ref_scores[lp][1]:
435 | raise ValueError(f'System list for {lp} doesn\'t match reference: '
436 | f'{syslist} vs {sys_ref_scores[lp][1]}')
437 | # Currently only using GLOBAL_DOMAIN, but ensure we have scores for all
438 | # domains in the reference.
439 | if domainlist != sys_ref_scores[lp][2]:
440 | raise ValueError(f'Domain list for {lp} doesn\'t match reference: '
441 | f'{domainlist} vs {sys_ref_scores[lp][2]}')
442 | for lp, info in sys_infos.items():
443 | if lp not in sys_ref_infos:
444 | raise ValueError(f'Unknown system-level language pair: {lp}')
445 | info.check(sys_ref_infos[lp], lp)
446 | print_summary(sys_metric, sys_ref_free, sys_res_scores, sys_infos)
447 | #
448 | if seg_res_file and sys_res_file:
449 | if seg_metric != sys_metric:
450 | raise ValueError(
451 | f'System/segment metric name mismatch: {sys_metric} vs {seg_metric}')
452 | if seg_ref_free != sys_ref_free:
453 | raise ValueError(
454 | f'System/segment ref-free mismatch: {sys_ref_free} vs {seg_ref_free}')
455 | metric_name = seg_metric if seg_res_file else sys_metric
456 | ref_free = seg_ref_free if seg_res_file else sys_ref_free
457 |
458 | # Create sys-level scores by averaging segment-level scores if no sys-level
459 | # scores supplied.
460 | if not sys_res_file:
461 | print('No system-level scores supplied - averaging segment-level scores')
462 | sys_res_scores = {}
463 | for lp, (metric, syslist, matrix) in seg_res_scores.items():
464 | sys_res_scores[lp] = (
465 | metric, syslist, [GLOBAL_DOMAIN], matrix.mean(axis=1, keepdims=True))
466 |
467 | # Compute results
468 |
469 | def make_key(lp, corr):
470 | lp = lp.replace('-', '')
471 | return f'{lp}_{corr}'
472 |
473 | results = {}
474 | print('Computing system-level Pearson correlations with pseudo gold scores')
475 | for lp, (_, _, domainlist, matrix) in sys_res_scores.items():
476 | _, _, ref_domainlist, ref_matrix = sys_ref_scores[lp]
477 | scores = matrix[:, domainlist.index(GLOBAL_DOMAIN)]
478 | ref_scores = ref_matrix[:, ref_domainlist.index(GLOBAL_DOMAIN)]
479 | results[make_key(lp, 'pearson')] = scipy.stats.pearsonr(
480 | scores, ref_scores)[0]
481 |
482 | if seg_res_file:
483 | print('Computing seg-level Kendall correlations with pseudo gold scores')
484 | for lp, (_, _, matrix) in seg_res_scores.items():
485 | _, _, ref_matrix = seg_ref_scores[lp]
486 | scores, ref_scores = matrix.flatten(), ref_matrix.flatten()
487 | results[make_key(lp, 'kendalltau')] = scipy.stats.kendalltau(
488 | scores, ref_scores)[0]
489 | else:
490 | for lp in seg_ref_scores:
491 | results[make_key(lp, 'kendalltau')] = 0
492 |
493 | # Write results
494 | with open(os.path.join(output_dir, OUT_FILE), 'w') as f:
495 | for k, v in metainfo.items():
496 | f.write(f'{k}: {v}\n')
497 | f.write(f'metric_name: {metric_name}\n')
498 | f.write(f'ref_less: {"Y" if ref_free else "N"}\n')
499 | for c, s in results.items():
500 | f.write(f'{c}: {s:f}\n')
501 |
502 | # Run
503 | if __name__ == '__main__':
504 | main(sys.argv)
505 |
--------------------------------------------------------------------------------
/mt_metrics_eval/codalab/metadata:
--------------------------------------------------------------------------------
1 | command: python3 $program/eval.py $input $output
2 | description: Competition evaluation program. Computes Pearson and Kendall correlation.
3 |
--------------------------------------------------------------------------------
/mt_metrics_eval/converters/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
--------------------------------------------------------------------------------
/mt_metrics_eval/converters/evalset_ratings_to_standalone.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert standalone ratings file to EvalSet format files."""
16 |
17 | import os
18 | from absl import app
19 | from absl import flags
20 | from mt_metrics_eval import data
21 | from mt_metrics_eval import ratings
22 | from mt_metrics_eval import standalone_ratings
23 | import glob
24 |
25 | flags.DEFINE_list(
26 | 'evalset_ratings_files', None,
27 | 'Comma-separated list of evalset-format ratings files to read. Filenames '
28 | 'are assumed to be in the form {language_pair}.{name}.seg.rating, where '
29 | 'name is of the form {prefix}.{rater}.',
30 | required=True)
31 | flags.DEFINE_string(
32 | 'ratings_file', None,
33 | 'Standalone ratings jsonl file to write.', required=True)
34 | flags.DEFINE_string(
35 | 'test_set', None, 'Test set, eg wmt20.', required=True)
36 | flags.DEFINE_string(
37 | 'language_pair', None, 'Language pair, eg en-de.', required=True)
38 | flags.DEFINE_string(
39 | 'rater_key_file', None,
40 | 'Use a rater_key_file previously written by standalone_ratings_to_evalset '
41 | 'to deanonymize raters in the output file.')
42 |
43 | FLAGS = flags.FLAGS
44 |
45 |
46 | def main(argv):
47 | if len(argv) > 1:
48 | raise app.UsageError('Too many command-line arguments.')
49 |
50 | raters_key = {}
51 | if FLAGS.rater_key_file:
52 | with open(FLAGS.rater_key_file, 'r') as f:
53 | for line in f:
54 | k, v = line.rstrip().split('\t')
55 | raters_key[k] = v
56 |     # Reverse the mapping: anonymized name -> original rater name.
57 | raters_key = {v: k for k, v in raters_key.items()}
58 |
59 | evs = data.EvalSet(FLAGS.test_set, FLAGS.language_pair)
60 |
61 | evalset_ratings: dict[str, dict[str, list[ratings.Rating | None]]] = {}
62 | evalset_rater_ids: dict[str, dict[str, list[str | None]]] = {}
63 | for filename in FLAGS.evalset_ratings_files:
64 | if not filename: continue
65 | _, name, _ = evs.ParseHumanScoreFilename(
66 | os.path.basename(filename), rating_file=True)
67 | rating_name = name.rsplit('.', maxsplit=1)[-1]
68 | evalset_ratings[rating_name], evalset_rater_ids[rating_name] = (
69 | ratings.ReadRatingFile(filename, rating_name)
70 | )
71 |
72 | ratings_list = standalone_ratings.EvalSetRatingsToRatingsList(
73 | evalset_ratings, evs, evalset_rater_ids, raters_key)
74 |
75 | standalone_ratings.WriteRatingFile(ratings_list, FLAGS.ratings_file)
76 |
77 |
78 | if __name__ == '__main__':
79 | app.run(main)
80 |
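To make the filename convention and rater-key handling above concrete, here is a small sketch with invented names. The script itself delegates parsing to evs.ParseHumanScoreFilename; the string handling below only mirrors the documented {language_pair}.{name}.seg.rating pattern.

# Hypothetical evalset-format ratings filename.
filename = 'en-de.mqm.rater3.seg.rating'
stem = filename[:-len('.seg.rating')]           # 'en-de.mqm.rater3'
lp, name = stem.split('.', maxsplit=1)          # 'en-de', 'mqm.rater3'
rating_name = name.rsplit('.', maxsplit=1)[-1]  # 'rater3'

# A rater_key_file written by standalone_ratings_to_evalset maps original
# rater names to anonymized ones, one old-name<TAB>new-name pair per line.
# Reversing it lets this converter restore the original names.
raters_key = {'alice': 'rater1', 'bob': 'rater2'}   # invented entries
raters_key = {v: k for k, v in raters_key.items()}  # {'rater1': 'alice', ...}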
--------------------------------------------------------------------------------
/mt_metrics_eval/converters/score_mqm.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Produce MQM scores from MQM ratings tsv file."""
16 |
17 | import collections
18 | import csv
19 | import json
20 | from absl import app
21 | from absl import flags
22 | import glob
23 |
24 | flags.DEFINE_string('input', '/dev/stdin', 'Input MQM ratings tsv file.')
25 | flags.DEFINE_string('output', '/dev/stdout', 'Output MQM score file.')
26 | flags.DEFINE_string(
27 | 'weights', 'Major:5 Minor:1 Neutral:0 '
28 | 'Major/Non-translation!:25 Minor/Fluency/Punctuation:0.1',
29 | 'List of weight specs, in format: "severity[/category[/subcategory]]:wt". '
30 | 'The most specific match is applied to each error.')
31 | flags.DEFINE_string(
32 | 'weights_sep', ' ', 'Separator character between items in weights lists.')
33 | flags.DEFINE_bool('unbabel', False, 'Input tsv is in Unbabel format.')
34 | flags.DEFINE_bool(
35 | 'recompute_unbabel', False,
36 | 'Apply Google-style weights to Unbabel ratings rather than reading scores '
37 | 'directly from mqm field in last column of tsv.')
38 | flags.DEFINE_bool(
39 | 'force_contiguous', True,
40 | 'Raise an error if annotated segments within a doc aren\'t contiguous')
41 | flags.DEFINE_string(
42 | 'doc_id', 'doc_id',
43 | 'Name of field containing 1-based id of segment within document')
44 |
45 | FLAGS = flags.FLAGS
46 |
47 |
48 | def Score(weights, items):
49 | items = [x.lower() for x in items]
50 | while items:
51 | if '/'.join(items) in weights:
52 | return weights['/'.join(items)]
53 | items = items[:-1]
54 | return 0
55 |
56 |
57 | def main(argv):
58 | if len(argv) > 1:
59 | raise app.UsageError('Too many command-line arguments.')
60 |
61 | weights = {}
62 | for e in FLAGS.weights.split(FLAGS.weights_sep):
63 | c, w = e.split(':')
64 | weights[c.lower()] = float(w)
65 |
66 |   scores = {} # sys -> doc -> doc_id -> rater -> [score]
67 | quoting = csv.QUOTE_MINIMAL if FLAGS.unbabel else csv.QUOTE_NONE
68 | with open(FLAGS.input) as f:
69 | for row in csv.DictReader(f, delimiter='\t', quoting=quoting):
70 | system, doc, doc_id = row['system'], row['doc'], int(row[FLAGS.doc_id])
71 | if FLAGS.unbabel and not FLAGS.recompute_unbabel:
72 | score = json.loads(row['misc'])['mqm']
73 | else:
74 | score = Score(weights, [row['severity']] + row['category'].split('/'))
75 | if system not in scores:
76 | scores[system] = {}
77 | if doc not in scores[system]:
78 | scores[system][doc] = {}
79 | if doc_id not in scores[system][doc]:
80 | scores[system][doc][doc_id] = collections.defaultdict(list)
81 | scores[system][doc][doc_id][row['rater']].append(score)
82 |
83 | if FLAGS.force_contiguous:
84 | for system in scores:
85 | for doc in scores[system]:
86 | ids = sorted(scores[system][doc])
87 | if ids != list(range(min(ids), max(ids) + 1)):
88 | raise ValueError(f'Non-contiguous segments for {system}/{doc}')
89 |
90 | with open(FLAGS.output, 'w') as f:
91 | for system in scores:
92 | for doc in scores[system]:
93 | for doc_id in sorted(scores[system][doc]):
94 | rater_scores = {}
95 | for rater, vals in scores[system][doc][doc_id].items():
96 | if FLAGS.unbabel and not FLAGS.recompute_unbabel:
97 | rater_scores[rater] = sum(vals) / len(vals)
98 | else:
99 | rater_scores[rater] = sum(vals)
100 | global_score = sum(rater_scores.values()) / len(rater_scores)
101 | if not FLAGS.unbabel or FLAGS.recompute_unbabel:
102 | global_score *= -1
103 | f.write(f'{system}\t{doc}\t{doc_id}\t{global_score}')
104 | for rater in sorted(rater_scores):
105 | f.write(f'\t{rater}={rater_scores[rater]}')
106 | f.write('\n')
107 |
108 | if __name__ == '__main__':
109 | app.run(main)
110 |
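The weight lookup in Score backs off from the most specific key to the least specific one. A short worked example using the default --weights values; the error rows below are invented, but the lookup logic is the same as above.

# Default weights from the --weights flag, lower-cased as in main().
weights = {'major': 5, 'minor': 1, 'neutral': 0,
           'major/non-translation!': 25,
           'minor/fluency/punctuation': 0.1}

def score(weights, items):
  items = [x.lower() for x in items]
  while items:
    if '/'.join(items) in weights:
      return weights['/'.join(items)]
    items = items[:-1]  # back off to a less specific key
  return 0

print(score(weights, ['Minor', 'Fluency', 'Punctuation']))     # 0.1 (exact match)
print(score(weights, ['Minor', 'Fluency', 'Grammar']))         # 1 (falls back to 'minor')
print(score(weights, ['Major', 'Accuracy', 'Mistranslation'])) # 5 (falls back to 'major')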
--------------------------------------------------------------------------------
/mt_metrics_eval/converters/standalone_ratings_to_evalset.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert standalone ratings file to EvalSet format files."""
16 |
17 | import os
18 | from absl import app
19 | from absl import flags
20 | from mt_metrics_eval import data
21 | from mt_metrics_eval import ratings
22 | from mt_metrics_eval import standalone_ratings
23 | import glob
24 |
25 | flags.DEFINE_multi_string(
26 | 'ratings_file', None,
27 | 'Ratings jsonl file to be converted.', required=True)
28 | flags.DEFINE_string(
29 | 'test_set', None, 'Test set, eg wmt20.', required=True)
30 | flags.DEFINE_string(
31 | 'language_pair', None, 'Language pair, eg en-de.', required=True)
32 | flags.DEFINE_string(
33 | 'output_dir', None,
34 | 'Directory in which to write output files.', required=True)
35 | flags.DEFINE_string(
36 | 'prefix', '',
37 | 'Prefix for output files. Full name is {prefix}{rater}.seg.rating.')
38 | flags.DEFINE_bool(
39 | 'anonymize_raters', False, 'Anonymize rater names.')
40 | flags.DEFINE_bool(
41 | 'merge_raters', False,
42 | 'By default, conversion produces a separate rating file for each rater, '
43 | 'even when raters annotate disjoint sets of items. This option will write '
44 | 'only a single file {prefix}merged.seg.rating by merging contributions '
45 | 'when possible, ie when rater contributions are disjoint; otherwise it '
46 | 'will write separate files. Note that it is not possible to recover '
47 | 'original rater names with this option.')
48 | flags.DEFINE_bool(
49 | 'strict', True, 'Ensure text-level matches with the EvalSet.')
50 | flags.DEFINE_string(
51 | 'rater_key_file', None,
52 | 'Write rater rename key to this file, with entries of the form '
53 | 'old-name\tnew-name. New names are identical to old names unless '
54 | 'anonymize_raters is True or the original rater names are None.')
55 | flags.DEFINE_string(
56 | 'echo_ratings_file', None,
57 |     'Write ratings in standalone format to this file. These may not be '
58 |     'identical to the original entries due to dropping some fields and changing '
59 | 'field order.')
60 |
61 | FLAGS = flags.FLAGS
62 |
63 |
64 | def main(argv):
65 | if len(argv) > 1:
66 | raise app.UsageError('Too many command-line arguments.')
67 |
68 | ratings_list = []
69 | for ratings_file in FLAGS.ratings_file:
70 | ratings_list.extend(standalone_ratings.ReadRatingFile(ratings_file))
71 | evs = data.EvalSet(FLAGS.test_set, FLAGS.language_pair)
72 | ratings_dict, raters_key, rater_ids_dict = (
73 | standalone_ratings.RatingsListToEvalSetRatings(
74 | ratings_list, evs, FLAGS.anonymize_raters, FLAGS.strict
75 | )
76 | )
77 |
78 | if FLAGS.echo_ratings_file:
79 | standalone_ratings.WriteRatingFile(ratings_list, FLAGS.echo_ratings_file)
80 |
81 | if FLAGS.merge_raters:
82 | new_ratings, new_rater_ids = standalone_ratings.MergeEvalSetRaters(
83 | ratings_dict, evs, rater_ids_dict
84 | )
85 | if new_ratings:
86 | ratings_dict = {'merged': new_ratings}
87 | rater_ids_dict = {'merged': new_rater_ids}
88 |
89 | for rater, evs_ratings in ratings_dict.items():
90 | filename = os.path.join(
91 | FLAGS.output_dir, f'{FLAGS.prefix}{rater}.seg.rating')
92 | ratings.WriteRatingFile(evs_ratings, filename, rater_ids_dict[rater])
93 |
94 | if FLAGS.rater_key_file:
95 | with open(FLAGS.rater_key_file, 'w') as f:
96 | for rater, new_rater in raters_key.items():
97 | f.write(f'{rater}\t{new_rater}\n')
98 |
99 |
100 | if __name__ == '__main__':
101 | app.run(main)
102 |
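For orientation, here is a sketch of the filenames and the rater-key format this converter produces; the output directory, prefix, and rater names are hypothetical.

import os

output_dir = '/tmp/evalset_ratings'  # hypothetical --output_dir
prefix = 'mqm.'                      # hypothetical --prefix

# Default: one file per rater.
for rater in ('rater1', 'rater2'):
  print(os.path.join(output_dir, f'{prefix}{rater}.seg.rating'))
# /tmp/evalset_ratings/mqm.rater1.seg.rating
# /tmp/evalset_ratings/mqm.rater2.seg.rating

# With --merge_raters (and disjoint rater contributions): a single file.
print(os.path.join(output_dir, f'{prefix}merged.seg.rating'))
# /tmp/evalset_ratings/mqm.merged.seg.rating

# --rater_key_file entries are written one per line as old-name<TAB>new-name,
# for example (invented names):
#   alice\trater1
#   bob\trater2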
--------------------------------------------------------------------------------
/mt_metrics_eval/converters/verify_scores_file.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Verify human and metric scores files, optionally repair metrics files."""
16 |
17 | import os
18 | from absl import app
19 | from absl import flags
20 | from mt_metrics_eval import data
21 | import glob
22 |
23 | flags.DEFINE_string(
24 | 'scores_file', None,
25 | 'Scores file to be verified. If not supplied, check all scores files for '
26 |     'given test_set, doing on-the-fly repair, ie only reporting non-repairable '
27 | 'errors.')
28 | flags.DEFINE_bool('human_scores', False, 'File contains human scores.')
29 | flags.DEFINE_string(
30 | 'data_dir', None, 'Optional root directory for mt_metrics_eval data.')
31 | flags.DEFINE_string(
32 | 'test_set', None,
33 | 'Name of test_set to which metric pertains.', required=True)
34 | flags.DEFINE_string(
35 | 'language_pair', None,
36 | 'Language pair, must exist for test_set.', required=True)
37 | flags.DEFINE_string(
38 | 'repair', None,
39 | 'Write a repaired version of scores_file to this file. This will be a '
40 | 'verbatim copy if scores_file is correct. No action if --human_scores is '
41 | 'set.')
42 |
43 | FLAGS = flags.FLAGS
44 |
45 |
46 | def main(argv):
47 | if len(argv) > 1:
48 | raise app.UsageError('Too many command-line arguments.')
49 |
50 | if FLAGS.scores_file:
51 | scores_file = os.path.basename(FLAGS.scores_file)
52 | else:
53 | scores_file = None
54 | read_all_scores = scores_file is None
55 | evs = data.EvalSet(FLAGS.test_set, FLAGS.language_pair,
56 | read_stored_metric_scores=read_all_scores,
57 | path=FLAGS.data_dir,
58 | strict=False)
59 |
60 | if read_all_scores:
61 | return
62 |
63 | # Check filename conventions, fail with error if incorrect.
64 | if FLAGS.human_scores:
65 | lp, name, level = evs.ParseHumanScoreFilename(scores_file)
66 | if lp != FLAGS.language_pair:
67 | raise ValueError(
68 | f'Language pair {lp} from scores file doesn\'t match flag.')
69 | else:
70 | name, level = evs.ParseMetricFilename(scores_file)
71 | evs.ParseMetricName(name)
72 |
73 | # Check contents, optionally repair missing-system errors.
74 | scores_map = data.ReadScoreFile(FLAGS.scores_file)
75 | added = evs.CheckScores(
76 | scores_map, name, level, FLAGS.human_scores, FLAGS.repair)
77 | if added:
78 | print(f'Added dummy scores (0s) for missing outputs: {added}')
79 |
80 | if FLAGS.repair and not FLAGS.human_scores:
81 | with open(FLAGS.repair, 'w') as f:
82 | for sysname, scores in scores_map.items():
83 | f.write('\n'.join([f'{sysname}\t{s}' for s in scores]) + '\n')
84 |
85 |
86 | if __name__ == '__main__':
87 | app.run(main)
88 |
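The repair step at the end writes metric scores in the usual per-system layout, one system<TAB>score line per segment. A minimal sketch with invented systems and scores and an assumed output path:

# Hypothetical repaired segment-level metric scores: system -> list of scores.
scores_map = {
    'ONLINE-A': [0.71, 0.65, 0.80],
    'ONLINE-B': [0.69, 0.70, 0.77],
}

with open('/tmp/MyMetric-refA.seg.score', 'w') as f:  # hypothetical filename
  for sysname, scores in scores_map.items():
    f.write('\n'.join(f'{sysname}\t{s}' for s in scores) + '\n')

# Resulting file contents:
#   ONLINE-A<TAB>0.71
#   ONLINE-A<TAB>0.65
#   ONLINE-A<TAB>0.8
#   ONLINE-B<TAB>0.69
#   ...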
--------------------------------------------------------------------------------
/mt_metrics_eval/data_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tests for data module."""
16 |
17 | from mt_metrics_eval import data
18 | from mt_metrics_eval import meta_info
19 | from mt_metrics_eval import ratings
20 | import numpy as np
21 | import unittest
22 |
23 |
24 | class EvalSetTest(unittest.TestCase):
25 |
26 | def _std_sys_names(self, evs):
27 | return evs.sys_names - evs.human_sys_names - evs.outlier_sys_names
28 |
29 | def testWMT20EnDeSysCorrelations(self):
30 | evs = data.EvalSet('wmt20', 'en-de', True)
31 | # Spot-checking table 6 in www.statmt.org/wmt20/pdf/2020.wmt-1.77.pdf
32 | results = {}
33 | sys_names = self._std_sys_names(evs)
34 | gold_scores = evs.Scores('sys', evs.StdHumanScoreName('sys'))
35 | for m in 'BLEU', 'sentBLEU', 'COMET', 'BLEURT-extended', 'prism', 'YiSi-0':
36 | metric_scores = evs.Scores('sys', m + '-ref')
37 | results[m] = evs.Correlation(
38 | gold_scores, metric_scores, sys_names).Pearson()[0]
39 | self.assertAlmostEqual(results['BLEU'], 0.825, places=3)
40 | self.assertAlmostEqual(results['sentBLEU'], 0.823, places=3)
41 | self.assertAlmostEqual(results['COMET'], 0.863, places=3)
42 | self.assertAlmostEqual(results['BLEURT-extended'], 0.870, places=3)
43 | self.assertAlmostEqual(results['prism'], 0.851, places=3)
44 | self.assertAlmostEqual(results['YiSi-0'], 0.889, places=3)
45 |
46 | # Spot-checking table 7 in www.statmt.org/wmt20/pdf/2020.wmt-1.77.pdf, 3rd
47 | # column.
48 | sys_names.add('ref')
49 | for m in 'BLEU', 'sentBLEU', 'COMET', 'BLEURT-extended', 'prism', 'YiSi-0':
50 | variant_name = m + '-refb'
51 | metric_scores = evs.Scores('sys', variant_name)
52 | results[variant_name] = evs.Correlation(
53 | gold_scores, metric_scores, sys_names).Pearson()[0]
54 | self.assertAlmostEqual(results['BLEU-refb'], 0.672, places=3)
55 | self.assertAlmostEqual(results['sentBLEU-refb'], 0.639, places=3)
56 | self.assertAlmostEqual(results['COMET-refb'], 0.879, places=3)
57 | self.assertAlmostEqual(results['BLEURT-extended-refb'], 0.883, places=3)
58 | self.assertAlmostEqual(results['prism-refb'], 0.731, places=3)
59 | self.assertAlmostEqual(results['YiSi-0-refb'], 0.728, places=3)
60 |
61 | def testWMT20EnDeDocCorrelations(self):
62 | evs = data.EvalSet('wmt20', 'en-de', True)
63 | # Spot-checking table 12 in www.statmt.org/wmt20/pdf/2020.wmt-1.77.pdf, 4th
64 |     # column (numbers do not match the table; the ones here are correct).
65 | results = {}
66 | sys_names = self._std_sys_names(evs)
67 | gold_scores = evs.Scores('doc', evs.StdHumanScoreName('doc'))
68 | for m in 'sentBLEU', 'COMET', 'BLEURT-extended', 'prism', 'YiSi-0':
69 | metric_scores = evs.Scores('doc', m + '-ref')
70 | corr = evs.Correlation(gold_scores, metric_scores, sys_names)
71 | c, _, num_pairs = corr.KendallLike()
72 | self.assertEqual(num_pairs, 275)
73 | results[m] = c
74 | self.assertAlmostEqual(results['sentBLEU'], 0.411, places=3)
75 | self.assertAlmostEqual(results['COMET'], 0.433, places=3)
76 | self.assertAlmostEqual(results['BLEURT-extended'], 0.396, places=3)
77 | self.assertAlmostEqual(results['prism'], 0.389, places=3)
78 | self.assertAlmostEqual(results['YiSi-0'], 0.360, places=3)
79 |
80 | def testWMT20EnDeSegCorrelations(self):
81 | evs = data.EvalSet('wmt20', 'en-de', True)
82 | # Spot-checking table 10 in www.statmt.org/wmt20/pdf/2020.wmt-1.77.pdf, 4th
83 | # column.
84 | results = {}
85 | sys_names = self._std_sys_names(evs)
86 | gold_scores = evs.Scores('seg', evs.StdHumanScoreName('seg'))
87 | for m in 'sentBLEU', 'COMET', 'BLEURT-extended', 'prism', 'YiSi-0':
88 | metric_scores = evs.Scores('seg', m + '-ref')
89 | corr = evs.Correlation(gold_scores, metric_scores, sys_names)
90 | c, _, num_pairs = corr.KendallLike()
91 | results[m] = c
92 | self.assertEqual(num_pairs, 4637)
93 | self.assertAlmostEqual(results['sentBLEU'], 0.155, places=3)
94 | self.assertAlmostEqual(results['COMET'], 0.324, places=3)
95 | self.assertAlmostEqual(results['BLEURT-extended'], 0.278, places=3)
96 | self.assertAlmostEqual(results['prism'], 0.280, places=3)
97 | self.assertAlmostEqual(results['YiSi-0'], 0.212, places=3)
98 |
99 | def testWMT20EnDeMQMScores(self):
100 | evs = data.EvalSet('wmt20', 'en-de')
101 | results = {}
102 | for level in 'sys', 'doc', 'seg':
103 | scores = evs.Scores(level, 'mqm')
104 | n = len(scores)
105 | self.assertEqual(n, 10)
106 | results[level] = (scores['OPPO.1535'][0], scores['OPPO.1535'][-1])
107 | self.assertAlmostEqual(results['sys'][0], -2.24805, places=5)
108 | self.assertAlmostEqual(results['sys'][1], -2.24805, places=5)
109 | self.assertAlmostEqual(results['doc'][0], -1.55128, places=5)
110 | self.assertAlmostEqual(results['doc'][1], -1.26429, places=5)
111 | self.assertAlmostEqual(results['seg'][0], -2.66667, places=5)
112 | self.assertAlmostEqual(results['seg'][1], -1.33333, places=5)
113 |
114 | def testWMT20EnDePSQMScores(self):
115 | evs = data.EvalSet('wmt20', 'en-de')
116 | results = {}
117 | for level in 'sys', 'doc', 'seg':
118 | scores = evs.Scores(level, 'psqm')
119 | n = len(scores)
120 | self.assertEqual(n, 10)
121 | results[level] = (scores['OPPO.1535'][0], scores['OPPO.1535'][-1])
122 | self.assertAlmostEqual(results['sys'][0], 3.78561, places=5)
123 | self.assertAlmostEqual(results['sys'][1], 3.78561, places=5)
124 | self.assertAlmostEqual(results['doc'][0], 4.41026, places=5)
125 | self.assertAlmostEqual(results['doc'][1], 4.38095, places=5)
126 | self.assertAlmostEqual(results['seg'][0], 4.0, places=5)
127 | self.assertAlmostEqual(results['seg'][1], 4.66667, places=5)
128 |
129 | def testWMT20EnDeCSQMScores(self):
130 | evs = data.EvalSet('wmt20', 'en-de')
131 | results = {}
132 | for level in 'sys', 'doc', 'seg':
133 | scores = evs.Scores(level, 'csqm')
134 | if scores:
135 | n = len(scores)
136 | self.assertEqual(n, 10)
137 | results[level] = (scores['OPPO.1535'][0], scores['OPPO.1535'][-1])
138 | if results:
139 | self.assertAlmostEqual(results['sys'][0], 5.02116, places=5)
140 | self.assertAlmostEqual(results['sys'][1], 5.02116, places=5)
141 | self.assertAlmostEqual(results['doc'][0], 4.71795, places=5)
142 | self.assertAlmostEqual(results['doc'][1], 5.66667, places=5)
143 | self.assertAlmostEqual(results['seg'][0], 5.00000, places=5)
144 | self.assertAlmostEqual(results['seg'][1], 6.00000, places=5)
145 |
146 | def testWMT20SentBLEUSysScores(self):
147 | # All sentBLEU results from tables 5 and 6 in
148 | # www.statmt.org/wmt20/pdf/2020.wmt-1.77.pdf. (full, no-outlier)
149 | expected = {
150 | 'cs-en': (0.844, 0.800),
151 | 'de-en': (0.978, 0.786),
152 | 'en-cs': (0.840, 0.436),
153 | 'en-de': (0.934, 0.823),
154 | 'en-iu': (0.129, 0.047),
155 | 'en-ja': (0.946, 0.976),
156 | 'en-pl': (0.950, 0.772),
157 | 'en-ru': (0.981, 0.981),
158 | 'en-ta': (0.881, 0.852),
159 | 'en-zh': (0.927, 0.927),
160 | 'iu-en': (0.649, 0.469),
161 | 'ja-en': (0.974, 0.851),
162 | 'km-en': (0.969, 0.969),
163 | 'pl-en': (0.502, 0.284),
164 | 'ps-en': (0.888, 0.888),
165 | 'ru-en': (0.916, 0.833),
166 | 'ta-en': (0.925, 0.829),
167 | 'zh-en': (0.948, 0.950),
168 | }
169 | for lp in meta_info.DATA['wmt20']:
170 | evs = data.EvalSet('wmt20', lp, True)
171 | all_sys = evs.sys_names - evs.human_sys_names
172 | gold_scores = evs.Scores('sys', evs.StdHumanScoreName('sys'))
173 | sent_bleu = evs.Scores('sys', 'sentBLEU-ref')
174 | pearson_full = evs.Correlation(gold_scores, sent_bleu, all_sys).Pearson()
175 | pearson_no_outlier = evs.Correlation(
176 | gold_scores, sent_bleu, all_sys - evs.outlier_sys_names).Pearson()
177 | self.assertAlmostEqual(pearson_full[0], expected[lp][0], places=3)
178 | self.assertAlmostEqual(pearson_no_outlier[0], expected[lp][1], places=3)
179 |
180 | def testWMT20SentBLEUSegScores(self):
181 | # All sentBLEU results from tables 9 and 10 in
182 | # www.statmt.org/wmt20/pdf/2020.wmt-1.77.pdf (full, no-outlier)
183 | expected = {
184 | 'cs-en': (0.068, 0.057),
185 | 'de-en': (0.413, -0.025),
186 | 'en-cs': (0.432, 0.194),
187 | 'en-de': (0.303, 0.155),
188 | 'en-iu': (0.206, -0.084),
189 | 'en-ja': (0.480, 0.390), # Numbers in the table don't include outliers.
190 | 'en-pl': (0.153, 0.067),
191 | 'en-ru': (0.051, 0.051),
192 | 'en-ta': (0.398, 0.206),
193 | 'en-zh': (0.396, 0.396),
194 | 'iu-en': (0.182, 0.170),
195 | 'ja-en': (0.188, 0.061),
196 | 'km-en': (0.226, 0.226),
197 | 'pl-en': (-0.024, -0.046),
198 | 'ps-en': (0.096, 0.096),
199 | 'ru-en': (-0.005, -0.038),
200 | 'ta-en': (0.162, 0.069),
201 | 'zh-en': (0.093, 0.060),
202 | }
203 | for lp in meta_info.DATA['wmt20']:
204 | evs = data.EvalSet('wmt20', lp, True)
205 | all_sys = evs.sys_names - evs.human_sys_names
206 | gold_scores = evs.Scores('seg', evs.StdHumanScoreName('seg'))
207 | sent_bleu = evs.Scores('seg', 'sentBLEU-ref')
208 | kendall_full = evs.Correlation(
209 | gold_scores, sent_bleu, all_sys).KendallLike()
210 | kendall_no_outlier = evs.Correlation(
211 | gold_scores, sent_bleu, all_sys - evs.outlier_sys_names).KendallLike()
212 | self.assertAlmostEqual(kendall_full[0], expected[lp][0], places=3)
213 | self.assertAlmostEqual(kendall_no_outlier[0], expected[lp][1], places=3)
214 |
215 | def testWMT19BLEUSysScores(self):
216 | # All sys-level BLEU results from tables 3, 4, 5, in
217 | # https://www.aclweb.org/anthology/W19-5302.pdf
218 | expected = {
219 | 'de-cs': 0.941,
220 | 'de-en': 0.849,
221 | 'de-fr': 0.891,
222 | 'en-cs': 0.897,
223 | 'en-de': 0.921,
224 | 'en-fi': 0.969,
225 | 'en-gu': 0.737,
226 | 'en-kk': 0.852,
227 | 'en-lt': 0.989,
228 | 'en-ru': 0.986,
229 | 'en-zh': 0.901,
230 | 'fi-en': 0.982,
231 | 'fr-de': 0.864,
232 | 'gu-en': 0.834,
233 | 'kk-en': 0.946,
234 | 'lt-en': 0.961,
235 | 'ru-en': 0.879,
236 | 'zh-en': 0.899,
237 | }
238 | for lp in meta_info.DATA['wmt19']:
239 | evs = data.EvalSet('wmt19', lp, True)
240 | gold_scores = evs.Scores('sys', evs.StdHumanScoreName('sys'))
241 | bleu_scores = evs.Scores('sys', 'BLEU-ref')
242 | # Need to filter here because not all lps have wmt-z score for all
243 | # systems.
244 | sys_names = self._std_sys_names(evs).intersection(gold_scores)
245 | pearson = evs.Correlation(gold_scores, bleu_scores, sys_names).Pearson()
246 | self.assertAlmostEqual(pearson[0], expected[lp], places=3)
247 |
248 | def testWMT19BEERSegScores(self):
249 | # All seg-level BEER results from tables 6, 7, 8 in
250 | # https://www.aclweb.org/anthology/W19-5302.pdf
251 | expected = {
252 | 'de-cs': 0.337,
253 | 'de-en': 0.128,
254 | 'de-fr': 0.293,
255 | 'en-cs': 0.443,
256 | 'en-de': 0.316,
257 | 'en-fi': 0.514,
258 | 'en-gu': 0.537,
259 | 'en-kk': 0.516,
260 | 'en-lt': 0.441,
261 | 'en-ru': 0.542,
262 | 'en-zh': 0.232,
263 | 'fi-en': 0.283,
264 | 'fr-de': 0.265,
265 | 'gu-en': 0.260,
266 | 'kk-en': 0.421,
267 | 'lt-en': 0.315,
268 | 'ru-en': 0.189,
269 | 'zh-en': 0.371,
270 | }
271 | for lp in meta_info.DATA['wmt19']:
272 | evs = data.EvalSet('wmt19', lp, True)
273 | gold_scores = evs.Scores('seg', evs.StdHumanScoreName('seg'))
274 | beer_scores = evs.Scores('seg', 'BEER-ref')
275 | # Need to filter here because not all lps have wmt-z score for all
276 | # systems.
277 | sys_names = self._std_sys_names(evs).intersection(gold_scores)
278 | kl = evs.Correlation(gold_scores, beer_scores, sys_names).KendallLike()
279 | self.assertAlmostEqual(kl[0], expected[lp], places=3)
280 |
281 | def testWMT23EnDeRatings(self):
282 | evs = data.EvalSet('wmt23', 'en-de', read_stored_ratings=True)
283 | self.assertEqual(evs.human_rating_names, {'mqm.merged'})
284 |
285 | expected_error = ratings.Error(
286 | start=0,
287 | end=23,
288 | category='style/unnatural or awkward',
289 | severity='minor',
290 | score=1.0,
291 | is_source_error=False
292 | )
293 | expected = ratings.Rating([expected_error])
294 | self.assertEqual(evs.Ratings('mqm.merged')['AIRC'][5], expected)
295 | self.assertEqual(evs.RaterIdsPerSeg('mqm.merged')['AIRC'][5], 'rater3')
296 |
297 | def testWMT23SentEnDeRatings(self):
298 | evs = data.EvalSet('wmt23.sent', 'en-de', read_stored_ratings=True)
299 | self.assertEqual(evs.human_rating_names, {'mqm.merged'})
300 |
301 | expected_error = ratings.Error(
302 | start=212,
303 | end=237,
304 | category='accuracy/addition',
305 | severity='major',
306 | score=5.0,
307 | is_source_error=False
308 | )
309 | expected = ratings.Rating([expected_error])
310 | self.assertEqual(evs.Ratings('mqm.merged')['ONLINE-Y'][2], expected)
311 | # TODO(dandeutsch): Update this test to check for the actual rater ID
312 | # once it has been propagated to the sentence-level file.
313 | self.assertEqual(
314 | evs.RaterIdsPerSeg('mqm.merged')['ONLINE-Y'][2], 'mqm.merged'
315 | )
316 |
317 | def testWMT23ZhEnRatings(self):
318 | evs = data.EvalSet('wmt23', 'zh-en', read_stored_ratings=True)
319 | self.assertEqual(
320 | evs.human_rating_names,
321 | {f'mqm.rater{i}' for i in range(1, 9)})
322 |
323 | expected_error = ratings.Error(
324 | start=148,
325 | end=149,
326 | category='fluency/punctuation',
327 | severity='minor',
328 | score=0.1,
329 | is_source_error=False
330 | )
331 | expected = ratings.Rating([expected_error])
332 | self.assertEqual(evs.Ratings('mqm.rater1')['GPT4-5shot'][318], expected)
333 |
334 |
335 | class DataTest(unittest.TestCase):
336 |
337 | def testAssignRanks(self):
338 |
339 | sig_matrix = np.array([
340 | 0, 1, 1, 1,
341 | 0, 0, 1, 1,
342 | 0, 0, 0, 1,
343 | 0, 0, 0, 0]).reshape((4, 4))
344 | self.assertEqual(data.AssignRanks(sig_matrix, 0.5), [1, 1, 1, 1])
345 | self.assertEqual(data.AssignRanks(sig_matrix, 1.0), [1, 2, 3, 4])
346 |
347 | sig_matrix = np.array([
348 | 0, 1, 0, 1,
349 | 0, 0, 1, 1,
350 | 0, 0, 0, 1,
351 | 0, 0, 0, 0]).reshape((4, 4))
352 | self.assertEqual(data.AssignRanks(sig_matrix, 0.5), [1, 1, 2, 2])
353 |
354 | sig_matrix = np.array([
355 | 0, 0, 1, 1,
356 | 0, 0, 1, 1,
357 | 0, 0, 0, 0,
358 | 0, 0, 0, 0]).reshape((4, 4))
359 | self.assertEqual(data.AssignRanks(sig_matrix, 0.5), [1, 2, 2, 3])
360 |
361 | def testMapPositions(self):
362 | items = ['aa', 'aa', 'bb', 'bb', 'bb', 'aa']
363 | d = data._MapPositions(items)
364 | self.assertEqual(d['aa'], [[0, 2], [5, 6]])
365 | self.assertEqual(d['bb'], [[2, 5]])
366 | self.assertEqual(data._UnmapPositions(d), items)
367 |
368 | items = ['aa', 'aa', 'bb', 'cc', 'cc', 'cc']
369 | d = data._MapPositions(items, True)
370 | self.assertEqual(d['aa'], [0, 2])
371 | self.assertEqual(d['bb'], [2, 3])
372 | self.assertEqual(d['cc'], [3, 6])
373 | self.assertEqual(data._UnmapPositions(d, True), items)
374 |
375 |
376 | if __name__ == '__main__':
377 | unittest.main()
378 |
--------------------------------------------------------------------------------
/mt_metrics_eval/meta_info.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Meta-information for standard datasets."""
16 |
17 | import dataclasses
18 |
19 |
20 | @dataclasses.dataclass
21 | class MetaInfo:
22 | """Meta information for test-sets and language pairs."""
23 | std_ref: str
24 | std_gold: dict[str, str] # Map level to name of human gold scores.
25 | outlier_systems: set[str]
26 | # Base names (not including -reference extensions) of metrics considered to be
27 | # primary submissions, or baselines like BLEU. When primary submissions can
28 | # include both reference-based and reference-free versions, these must have
29 | # distinct basenames, eg MyMetric and MyMetric-QE.
30 | primary_metrics: set[str]
31 |   # For backward compatibility, baselines should be a subset of primary metrics.
32 | baseline_metrics: set[str] | None = None
33 |
34 | WMT19 = MetaInfo('ref', {'sys': 'wmt-z', 'seg': 'wmt-raw'}, set(), set())
35 | WMT20 = MetaInfo('ref', {'sys': 'wmt-z', 'doc': 'wmt-raw', 'seg': 'wmt-raw'},
36 | set(), set())
37 | WMT21_PRIMARIES = {
38 | 'bleurt-20', 'COMET-MQM_2021',
39 | 'COMET-QE-MQM_2021', 'C-SPECpn', 'MEE2',
40 | 'MTEQA', 'OpenKiwi-MQM', 'regEMT', 'YiSi-1',
41 | 'YiSi-2', 'BERTScore', 'sentBLEU', 'BLEU', 'chrF', 'Prism', 'TER'
42 | }
43 | WMT21 = MetaInfo('refA', {'sys': 'wmt-z', 'seg': 'wmt-raw'}, set(),
44 | WMT21_PRIMARIES)
45 |
46 | WMT22_PRIMARIES = {
47 | 'BERTScore', 'BLEURT-20', 'BLEU', 'chrF', 'COMET-20', 'COMET-22',
48 | 'COMETKiwi', 'COMET-QE', 'f200spBLEU',
49 | 'HuaweiTSC_EE_BERTScore_0.3_With_Human', 'HWTSC-Teacher-Sim', 'MATESE-QE',
50 | 'MATESE', 'MEE4', 'metricx_xxl_MQM_2020', 'MS-COMET-22', 'MS-COMET-QE-22',
51 | 'REUSE', 'SEScore', 'UniTE', 'UniTE-src', 'YiSi-1'
52 | }
53 | WMT22 = MetaInfo(
54 | 'refA',
55 | {'sys': 'mqm', 'domain': 'mqm', 'seg': 'mqm'}, set(),
56 | WMT22_PRIMARIES)
57 | WMT22_DA = MetaInfo(
58 | 'refA',
59 | {'sys': 'wmt', 'domain': 'wmt', 'seg': 'wmt'}, set(),
60 | WMT22_PRIMARIES)
61 | WMT22_DA_NODOMAIN = MetaInfo(
62 | 'refA',
63 | {'sys': 'wmt', 'seg': 'wmt'}, set(),
64 | WMT22_PRIMARIES)
65 | WMT22_APPRAISE = MetaInfo(
66 | 'refA',
67 | {'sys': 'wmt-appraise', 'domain': 'wmt-appraise', 'seg': 'wmt-appraise'},
68 | set(), WMT22_PRIMARIES)
69 | WMT22_NODOMAIN = MetaInfo(
70 | 'refA',
71 | {'sys': 'wmt-appraise', 'seg': 'wmt-appraise'},
72 | set(), WMT22_PRIMARIES)
73 |
74 | WMT23_PRIMARIES = {
75 | 'Calibri-COMET22', 'Calibri-COMET22-QE', 'cometoid22-wmt22', 'eBLEU',
76 | 'embed_llama', 'GEMBA-MQM', 'KG-BERTScore', 'MaTESe', 'MEE4', 'MetricX-23',
77 | 'MetricX-23-QE', 'mre-score-labse-regular', 'mbr-metricx-qe', 'sescoreX',
78 | 'tokengram_F', 'XCOMET-Ensemble', 'XCOMET-QE-Ensemble', 'XLsim',
79 | 'BERTscore', 'BLEU', 'BLEURT-20', 'chrF', 'COMET', 'CometKiwi',
80 | 'docWMT22CometDA', 'docWMT22CometKiwiDA', 'f200spBLEU', 'MS-COMET-QE-22',
81 | 'prismRef', 'prismSrc', 'Random-sysname', 'YiSi-1'
82 | }
83 |
84 | WMT23_BASELINES = {
85 | 'BERTscore', 'BLEU', 'BLEURT-20', 'chrF', 'COMET', 'CometKiwi',
86 | 'docWMT22CometDA', 'docWMT22CometKiwiDA', 'f200spBLEU', 'MS-COMET-QE-22',
87 | 'prismRef', 'prismSrc', 'Random-sysname', 'YiSi-1'
88 | }
89 |
90 | WMT23 = MetaInfo(
91 | 'refA',
92 | {'sys': 'mqm', 'domain': 'mqm', 'seg': 'mqm'},
93 | set(), WMT23_PRIMARIES, WMT23_BASELINES)
94 |
95 | WMT23_DA = MetaInfo(
96 | 'refA',
97 | {'sys': 'da-sqm', 'domain': 'da-sqm', 'seg': 'da-sqm'},
98 | set(), WMT23_PRIMARIES, WMT23_BASELINES)
99 |
100 | WMT24_PRIMARIES = {
101 | 'BLCOM_1', 'bright-qe', 'chrfS', 'damonmonli', 'gemba_esa', 'MEE4',
102 | 'metametrics_mt_mqm_hybrid_kendall', 'metametrics_mt_mqm_qe_kendall.seg.s',
103 | 'MetricX-24-Hybrid', 'MetricX-24-Hybrid-QE', 'XCOMET', 'XCOMET-QE',
104 | 'XLsimMqm', 'BERTScore', 'BLEU', 'BLEURT-20', 'chrF', 'COMET-22',
105 | 'CometKiwi', 'PrismRefMedium', 'PrismRefSmall', 'spBLEU', 'YiSi-1',
106 | 'sentinel-cand-mqm', 'sentinel-ref-mqm', 'sentinel-src-mqm',
107 | }
108 |
109 | WMT24_BASELINES = {
110 | 'BERTScore', 'BLEU', 'BLEURT-20', 'chrF', 'COMET-22', 'CometKiwi',
111 | 'PrismRefMedium', 'PrismRefSmall', 'spBLEU', 'YiSi-1',
112 | 'sentinel-cand-mqm', 'sentinel-ref-mqm', 'sentinel-src-mqm',
113 | }
114 |
115 | WMT24 = MetaInfo(
116 | 'refA',
117 | {'sys': 'mqm', 'domain': 'mqm', 'seg': 'mqm'},
118 | {'MSLC'}, WMT24_PRIMARIES, WMT24_BASELINES)
119 |
120 | WMT24_ESA = MetaInfo(
121 | 'refA',
122 | {'sys': 'esa', 'domain': 'esa', 'seg': 'esa'},
123 | set(), WMT24_PRIMARIES, WMT24_BASELINES
124 | )
125 |
126 | WMT24PP_PRIMARIES = {
127 | 'BLEU', 'ChrF', 'MetricX-24', 'MetricX-24-QE', 'XCOMET', 'XCOMET-QE',
128 | 'COMETKiwi-23', 'Gemini-DA', 'Gemini-DA-QE'
129 | }
130 |
131 | WMT24PP = MetaInfo('posteditA', {}, set(), WMT24PP_PRIMARIES, None)
132 |
133 | WMT24PP_DATA = {
134 | 'en-ar_EG': WMT24PP,
135 | 'en-ar_SA': WMT24PP,
136 | 'en-bg_BG': WMT24PP,
137 | 'en-bn_IN': WMT24PP,
138 | 'en-ca_ES': WMT24PP,
139 | 'en-cs_CZ': WMT24PP,
140 | 'en-da_DK': WMT24PP,
141 | 'en-de_DE': dataclasses.replace(WMT24PP, std_ref='posteditB'),
142 | 'en-el_GR': WMT24PP,
143 | 'en-es_MX': WMT24PP,
144 | 'en-et_EE': WMT24PP,
145 | 'en-fa_IR': WMT24PP,
146 | 'en-fi_FI': WMT24PP,
147 | 'en-fil_PH': WMT24PP,
148 | 'en-fr_CA': WMT24PP,
149 | 'en-fr_FR': WMT24PP,
150 | 'en-gu_IN': WMT24PP,
151 | 'en-he_IL': WMT24PP,
152 | 'en-hi_IN': WMT24PP,
153 | 'en-hr_HR': WMT24PP,
154 | 'en-hu_HU': WMT24PP,
155 | 'en-id_ID': WMT24PP,
156 | 'en-is_IS': dataclasses.replace(WMT24PP, std_ref='refA'),
157 | 'en-it_IT': WMT24PP,
158 | 'en-ja_JP': WMT24PP,
159 | 'en-kn_IN': WMT24PP,
160 | 'en-ko_KR': WMT24PP,
161 | 'en-lt_LT': WMT24PP,
162 | 'en-lv_LV': WMT24PP,
163 | 'en-ml_IN': WMT24PP,
164 | 'en-mr_IN': WMT24PP,
165 | 'en-nl_NL': WMT24PP,
166 | 'en-no_NO': WMT24PP,
167 | 'en-pa_IN': WMT24PP,
168 | 'en-pl_PL': WMT24PP,
169 | 'en-pt_BR': WMT24PP,
170 | 'en-pt_PT': WMT24PP,
171 | 'en-ro_RO': WMT24PP,
172 | 'en-ru_RU': WMT24PP,
173 | 'en-sk_SK': WMT24PP,
174 | 'en-sl_SI': WMT24PP,
175 | 'en-sr_RS': WMT24PP,
176 | 'en-sv_SE': WMT24PP,
177 | 'en-sw_KE': WMT24PP,
178 | 'en-sw_TZ': WMT24PP,
179 | 'en-ta_IN': WMT24PP,
180 | 'en-te_IN': WMT24PP,
181 | 'en-th_TH': WMT24PP,
182 | 'en-tr_TR': WMT24PP,
183 | 'en-uk_UA': WMT24PP,
184 | 'en-ur_PK': WMT24PP,
185 | 'en-vi_VN': WMT24PP,
186 | 'en-zh_CN': WMT24PP,
187 | 'en-zh_TW': WMT24PP,
188 | 'en-zu_ZA': WMT24PP,
189 | }
190 |
191 |
192 | DATA = {
193 | 'wmt24pp': WMT24PP_DATA,
194 | 'wmt24': {
195 | 'en-de': dataclasses.replace(WMT24, std_ref='refB'),
196 | 'en-es': WMT24,
197 | 'ja-zh': WMT24,
198 | 'cs-uk': WMT24_ESA,
199 | 'en-cs': WMT24_ESA,
200 | 'en-hi': WMT24_ESA,
201 | 'en-is': WMT24_ESA,
202 | 'en-ja': WMT24_ESA,
203 | 'en-ru': WMT24_ESA,
204 | 'en-uk': WMT24_ESA,
205 | 'en-zh': WMT24_ESA,
206 | },
207 | 'wmt23.sent': {'en-de': WMT23},
208 | 'wmt23': {
209 | 'en-de': dataclasses.replace(WMT23, outlier_systems={'synthetic_ref'}),
210 | 'he-en': dataclasses.replace(WMT23, std_ref='refB'),
211 | 'zh-en': dataclasses.replace(WMT23, outlier_systems={'synthetic_ref'}),
212 | 'cs-uk': WMT23_DA,
213 | 'de-en': WMT23_DA,
214 | 'en-cs': WMT23_DA,
215 | 'en-he': dataclasses.replace(WMT23_DA, std_gold={}, std_ref='refB'),
216 | 'en-ja': WMT23_DA,
217 | 'en-ru': dataclasses.replace(WMT23_DA, std_gold={}),
218 | 'en-uk': dataclasses.replace(WMT23_DA, std_gold={}),
219 | 'en-zh': WMT23_DA,
220 | 'ja-en': WMT23_DA,
221 | 'ru-en': dataclasses.replace(WMT23_DA, std_gold={}),
222 | 'uk-en': dataclasses.replace(WMT23_DA, std_gold={}),
223 | },
224 | 'wmt22': {
225 | 'en-de': dataclasses.replace(WMT22, outlier_systems={'M2M100_1.2B-B4'}),
226 | 'en-ru': WMT22,
227 | 'zh-en': WMT22,
228 | 'cs-en': dataclasses.replace(WMT22_DA, std_ref='refB'),
229 | 'cs-uk': WMT22_NODOMAIN,
230 | 'de-en': WMT22_DA,
231 | 'de-fr': dataclasses.replace(WMT22_APPRAISE, std_gold={}),
232 | 'en-cs': dataclasses.replace(WMT22_APPRAISE, std_ref='refB'),
233 | 'en-hr': WMT22_APPRAISE,
234 | 'en-ja': WMT22_APPRAISE,
235 | 'en-liv': WMT22_NODOMAIN,
236 | 'en-uk': WMT22_APPRAISE,
237 | 'en-zh': WMT22_APPRAISE,
238 | 'fr-de': dataclasses.replace(WMT22_APPRAISE, std_gold={}),
239 | 'ja-en': WMT22_DA,
240 | 'liv-en': WMT22_NODOMAIN,
241 | 'ru-en': WMT22_DA,
242 | 'ru-sah': dataclasses.replace(WMT22_APPRAISE, std_gold={}),
243 | 'sah-ru': WMT22_NODOMAIN,
244 | 'uk-cs': WMT22_NODOMAIN,
245 | 'uk-en': WMT22_DA_NODOMAIN,
246 | },
247 | 'wmt21.news': {
248 | 'en-cs': WMT21,
249 | 'en-de': dataclasses.replace(
250 | WMT21,
251 | std_ref='refC',
252 | std_gold={'sys': 'mqm', 'seg': 'mqm'},
253 | primary_metrics=WMT21_PRIMARIES | {'cushLEPOR(LM)'}),
254 | 'en-ha': WMT21,
255 | 'en-is': WMT21,
256 | 'en-ja': WMT21,
257 | 'en-ru': dataclasses.replace(
258 | WMT21,
259 | std_ref='refA',
260 | std_gold={'sys': 'mqm', 'seg': 'mqm'},
261 | primary_metrics=WMT21_PRIMARIES | {'hLEPOR'}),
262 | 'en-zh': WMT21,
263 | 'cs-en': WMT21,
264 | 'de-en': WMT21,
265 | 'de-fr': WMT21,
266 | 'fr-de': WMT21,
267 | 'ha-en': WMT21,
268 | 'is-en': WMT21,
269 | 'ja-en': WMT21,
270 | 'ru-en': WMT21,
271 | 'zh-en': dataclasses.replace(
272 | WMT21,
273 | std_ref='refB',
274 | std_gold={'sys': 'mqm', 'seg': 'mqm'},
275 | primary_metrics=WMT21_PRIMARIES | {'cushLEPOR(LM)', 'RoBLEURT'}),
276 | },
277 | 'wmt21.tedtalks': {
278 | 'en-de': dataclasses.replace(
279 | WMT21,
280 | std_ref='refA',
281 | std_gold={'sys': 'mqm', 'seg': 'mqm'},
282 | primary_metrics=WMT21_PRIMARIES | {'cushLEPOR(LM)'}),
283 | 'en-ru': dataclasses.replace(
284 | WMT21,
285 | std_ref='refA',
286 | std_gold={'sys': 'mqm', 'seg': 'mqm'},
287 | primary_metrics=WMT21_PRIMARIES | {'hLEPOR'}),
288 | 'zh-en': dataclasses.replace(
289 | WMT21,
290 | std_ref='refB',
291 | std_gold={'sys': 'mqm', 'seg': 'mqm'},
292 | primary_metrics=WMT21_PRIMARIES | {'cushLEPOR(LM)', 'RoBLEURT'}),
293 | },
294 | 'wmt21.flores': {
295 | 'bn-hi': WMT21,
296 | 'hi-bn': WMT21,
297 | 'xh-zu': WMT21,
298 | 'zu-xh': WMT21,
299 | },
300 | 'wmt20': {
301 | 'cs-en': dataclasses.replace(
302 | WMT20,
303 | outlier_systems={'zlabs-nlp.1149', 'CUNI-DocTransformer.1457'}),
304 | 'de-en': dataclasses.replace(
305 | WMT20,
306 | outlier_systems={'yolo.1052', 'zlabs-nlp.1153',
307 | 'WMTBiomedBaseline.387'}),
308 | 'en-cs': dataclasses.replace(
309 | WMT20,
310 | outlier_systems={'zlabs-nlp.1151', 'Online-G.1555'}),
311 | 'en-de': dataclasses.replace(
312 | WMT20,
313 | outlier_systems={'zlabs-nlp.179', 'WMTBiomedBaseline.388',
314 | 'Online-G.1556'}),
315 | 'en-iu': dataclasses.replace(
316 | WMT20,
317 | outlier_systems={'UEDIN.1281', 'OPPO.722', 'UQAM_TanLe.521'}),
318 | 'en-ja': dataclasses.replace(
319 | WMT20,
320 | outlier_systems={'Online-G.1557', 'SJTU-NICT.370'}),
321 | 'en-pl': dataclasses.replace(
322 | WMT20,
323 | outlier_systems={'Online-Z.1634', 'zlabs-nlp.180',
324 | 'Online-A.1576'}),
325 | 'en-ru': WMT20,
326 | 'en-ta': dataclasses.replace(
327 | WMT20,
328 | outlier_systems={'TALP_UPC.1049', 'SJTU-NICT.386',
329 | 'Online-G.1561'}),
330 | 'en-zh': WMT20,
331 | 'iu-en': dataclasses.replace(
332 | WMT20,
333 | outlier_systems={'NiuTrans.1206', 'Facebook_AI.729'}),
334 | 'ja-en': dataclasses.replace(
335 | WMT20,
336 | outlier_systems={'Online-G.1564', 'zlabs-nlp.66', 'Online-Z.1640'}),
337 | 'km-en': WMT20,
338 | 'pl-en': dataclasses.replace(
339 | WMT20,
340 | outlier_systems={'zlabs-nlp.1162'}),
341 | 'ps-en': WMT20,
342 | 'ru-en': dataclasses.replace(
343 | WMT20,
344 | outlier_systems={'zlabs-nlp.1164'}),
345 | 'ta-en': dataclasses.replace(
346 | WMT20,
347 | outlier_systems={'Online-G.1568', 'TALP_UPC.192'}),
348 | 'zh-en': dataclasses.replace(
349 | WMT20,
350 | outlier_systems={'WMTBiomedBaseline.183'})
351 | },
352 | 'wmt19': {
353 | 'de-cs': WMT19,
354 | 'de-en': WMT19,
355 | 'de-fr': WMT19,
356 | 'en-cs': WMT19,
357 | 'en-de': WMT19,
358 | 'en-fi': WMT19,
359 | 'en-gu': WMT19,
360 | 'en-kk': WMT19,
361 | 'en-lt': WMT19,
362 | 'en-ru': WMT19,
363 | 'en-zh': WMT19,
364 | 'fi-en': WMT19,
365 | 'fr-de': WMT19,
366 | 'gu-en': WMT19,
367 | 'kk-en': WMT19,
368 | 'lt-en': WMT19,
369 | 'ru-en': WMT19,
370 | 'zh-en': WMT19,
371 | }
372 | }
373 |
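A short usage sketch of the DATA table above; the values in the comments follow directly from the definitions in this file.

from mt_metrics_eval import meta_info

info = meta_info.DATA['wmt20']['en-de']
print(info.std_ref)                              # ref
print(info.std_gold['sys'])                      # wmt-z
print('Online-G.1556' in info.outlier_systems)   # True

# WMT23 distinguishes baselines from the larger set of primary metrics.
wmt23 = meta_info.DATA['wmt23']['zh-en']
print(wmt23.baseline_metrics.issubset(wmt23.primary_metrics))  # True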
--------------------------------------------------------------------------------
/mt_metrics_eval/mt_metrics_eval.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "IGuFoP_Gq9X9"
7 | },
8 | "source": [
9 | "This is a demo colab for MTME. It assumes you have mt_metrics_eval installed on your runtime, and have downloaded the data onto that machine. Run the cells below in order."
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {
15 | "id": "gH8o8UKmUhQ8"
16 | },
17 | "source": [
18 | "# Preliminaries"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "metadata": {
25 | "cellView": "code",
26 | "id": "Cr0TM9EY7wOH"
27 | },
28 | "outputs": [],
29 | "source": [
30 | "# @title Imports\n",
31 | "\n",
32 | "import numpy as np\n",
33 | "import scipy.stats\n",
34 | "\n",
35 | "from mt_metrics_eval import meta_info\n",
36 | "from mt_metrics_eval import data\n",
37 | "from mt_metrics_eval import stats\n",
38 | "from mt_metrics_eval import tasks"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": null,
44 | "metadata": {
45 | "cellView": "code",
46 | "id": "GznnWylA8gwJ"
47 | },
48 | "outputs": [],
49 | "source": [
50 | "# @title Print all available evalsets\n",
51 | "\n",
52 | "for testset in meta_info.DATA:\n",
53 | " print(f'{testset}:', ' '.join(lp for lp in meta_info.DATA[testset]))\n"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": null,
59 | "metadata": {
60 | "id": "OfiSe4Yt8sz3"
61 | },
62 | "outputs": [],
63 | "source": [
64 | "# @title Load data for WMT21 language pairs scored with MQM\n",
65 | "\n",
66 | "all_evs = {} # name/lp -\u003e evs\n",
67 | "for testset in meta_info.DATA:\n",
68 | " if not testset.startswith('wmt21'): continue\n",
69 | " for lp in meta_info.DATA[testset]:\n",
70 | " if 'mqm' in meta_info.DATA[testset][lp].std_gold.values():\n",
71 | " all_evs[f'{testset}/{lp}'] = data.EvalSet(testset, lp, True)\n",
72 | "\n",
73 | "print('\\n'.join(all_evs.keys()))"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": null,
79 | "metadata": {
80 | "cellView": "code",
81 | "id": "jAuNRBA79Yai"
82 | },
83 | "outputs": [],
84 | "source": [
85 | "# @title Print summaries for all loaded evalsets\n",
86 | "\n",
87 | "print(f'{\"name\":\u003c20} segs sys metrics gold refs std')\n",
88 | "for name, evs in all_evs.items():\n",
89 | " nsegs = len(evs.src)\n",
90 | " nsys = len(evs.sys_names)\n",
91 | " nmetrics = len(evs.metric_basenames)\n",
92 | " gold = evs.StdHumanScoreName('sys')\n",
93 | " nrefs = len(evs.ref_names)\n",
94 | " std_ref = evs.std_ref\n",
95 | "\n",
96 | " print(f'{name:\u003c20} {nsegs:5d} {nsys:3d} {nmetrics:7d} '\n",
97 | " f'{gold:5} {nrefs:4d} {std_ref}')"
98 | ]
99 | },
100 | {
101 | "cell_type": "markdown",
102 | "metadata": {
103 | "id": "AW9Mda-jUpqh"
104 | },
105 | "source": [
106 | "# Comparing metrics"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": null,
112 | "metadata": {
113 | "id": "GSwjndy3mQeo"
114 | },
115 | "outputs": [],
116 | "source": [
117 | "# @title Set up for comparing metrics\n",
118 | "\n",
119 | "# There are many different ways to evaluate the performance of MT metrics. The\n",
120 | "# most obvious question is what correlation statistic we should use to capture\n",
121 | "# the similarity between a vector of metric scores and a vector of gold scores\n",
122 | "# (human ratings). A less obvious question is where those vectors come from.\n",
123 | "# We'll defer the choice of correlation statistic to later cells, and begin\n",
124 | "# by setting some parameters that precisely define the vectors we're interested\n",
125 | "# in comparing.\n",
126 | "\n",
127 | "# Use all evalsets that we've loaded.\n",
128 | "evs_list = all_evs.values()\n",
129 | "\n",
130 | "# Choose the version of each metric that uses the standard reference for each\n",
131 | "# evalset.\n",
132 | "main_refs = [{evs.std_ref} for evs in evs_list]\n",
133 | "\n",
134 | "# Some alternative references are known to be close to the standard reference.\n",
135 | "# Don't include these among systems to be scored if we are including 'human'\n",
136 | "# systems. The only currently known instance is refB in wmt21.news/en-de,\n",
137 | "# which is similar to the standard refC.\n",
138 | "close_refs = [{'refB'} if k == 'wmt21.news/en-de' else set() for k in all_evs]\n",
139 | "\n",
140 | "# Include 'human' systems (ie, reference translations) among systems to be\n",
141 | "# scored. This can make the task more challenging, since some metrics are\n",
142 | "# biased against less literal references.\n",
143 | "include_human = True\n",
144 | "\n",
145 | "# Don't include systems considered to be outliers. These are systems that are\n",
146 | "# much better or worse than all other systems, so they are easy for all metrics\n",
147 | "# to rank correctly.\n",
148 | "include_outliers = False\n",
149 | "\n",
150 | "# Use MQM ratings as gold scores rather than the scores provided by the main\n",
151 | "# WMT task. Metrics tasks have used MQM for main results since 2021.\n",
152 | "gold_name = 'mqm'\n",
153 | "\n",
154 | "# Only compare metrics that have been designated as primary submissions. This\n",
155 | "# removes metric variants that are similar to each other, and reduces the size\n",
156 | "# of the comparison matrix.\n",
157 | "primary_metrics = True\n",
158 | "\n",
159 | "# Don't limit the results to a particular domain. In WMT21, domains are treated\n",
160 | "# as separate test-sets, so this is a no-op (WMT22 is a different story).\n",
161 | "domain = None\n",
162 | "\n",
163 | "# Set the number of resampling runs for determining whether one metric is better\n",
164 | "# than another according to the permutation test. We'll use 5 to make the demo\n",
165 | "# finish quickly, but at least 1000 is required for stable results.\n",
166 | "k = 5\n",
167 | "\n",
168 | "# Set the size of blocks for 'early stopping' checks during resampling. If\n",
169 | "# you're using k = 1000, this can speed up the computation, usually with\n",
170 | "# only minimal changes to the results.\n",
171 | "psd = stats.PermutationSigDiffParams(block_size = 100)\n",
172 | "\n",
173 | "# Set the p-value for deciding whether metrics are considered to be significantly\n",
174 | "# different. Lower values make the test more stringent.\n",
175 | "pval = 0.05"
176 | ]
177 | },
178 | {
179 | "cell_type": "code",
180 | "execution_count": null,
181 | "metadata": {
182 | "id": "ffPW_P5yxMbu"
183 | },
184 | "outputs": [],
185 | "source": [
186 | "# @title Evaluate metrics using global accuracy\n",
187 | "\n",
188 | "# Global accuracy, introduced by Kocmi et al (https://arxiv.org/abs/2107.10821)\n",
189 | "# is a robust way to evaluate the performance of a metric across many different\n",
190 | "# settings. The idea is to count the number of pairwise system rankings where\n",
191 | "# the metric agrees with the gold ranking, and micro average this across all\n",
192 | "# settings.\n",
193 | "\n",
194 | "# The output shows the rank of each metric's significance cluster, followed\n",
195 | "# by its accuracy, and whether it is statistically tied with (=) or better than\n",
196 | "# (\u003e) each lower-ranking metric.\n",
197 | "\n",
198 | "\n",
199 | "ranks, matrix, _, _ = data.CompareMetricsWithGlobalAccuracy(\n",
200 | " evs_list, main_refs, close_refs, include_human, include_outliers,\n",
201 | " gold_name, primary_metrics, domain, k, psd, pval)\n",
202 | "\n",
203 | "data.PrintMetricComparison(ranks, matrix, pval)"
204 | ]
205 | },
206 | {
207 | "cell_type": "code",
208 | "execution_count": null,
209 | "metadata": {
210 | "id": "vCXij-xmO_E9"
211 | },
212 | "outputs": [],
213 | "source": [
214 | "# @title Evaluate metrics using system-level Pearson correlation\n",
215 | "\n",
216 | "# Pearson correlation measures the degree of linear correspondence between\n",
217 | "# metric and gold scores. Computing a single correlation across different\n",
218 | "# evalsets isn't a great idea, so the interface forces you to choose a single\n",
219 | "# set. We'll pick 'wmt21.news/en-de'. The part of the computation that extracts\n",
220 | "# relevant score vectors is factored into a separate step to allow you to\n",
221 | "# compute other correlations with these vectors.\n",
222 | "\n",
223 | "# Notice that the ranking is quite different from the accuracy ranking, partly\n",
224 | "# because we're using only a subset of the data, and partly because Pearson and\n",
225 | "# accuracy measure different things. The ranking also includes two metrics that\n",
226 | "# were automatically filtered out of the accuracy ranking because they weren't\n",
227 | "# available for all evalsets.\n",
228 | "\n",
229 | "evs = all_evs['wmt21.news/en-de']\n",
230 | "corrs = data.GetCorrelations(\n",
231 | " evs, 'sys', {evs.std_ref}, {'refB'}, include_human, include_outliers,\n",
232 | " gold_name, primary_metrics, domain)\n",
233 | "ranks, matrix, _, _ = data.CompareMetrics(\n",
234 | " corrs, scipy.stats.pearsonr, 'none', k, psd, pval)\n",
235 | "\n",
236 | "data.PrintMetricComparison(ranks, matrix, pval, evs)"
237 | ]
238 | },
239 | {
240 | "cell_type": "code",
241 | "execution_count": null,
242 | "metadata": {
243 | "id": "4Ur6XTs9hlmG"
244 | },
245 | "outputs": [],
246 | "source": [
247 | "# @title Evaluate metrics using segment-level Kendall correlation\n",
248 | "\n",
249 | "# Kendall correlation is similar to pairwise accuracy, except that it is\n",
250 | "# normalized differently. The function calls are identical to the previous one,\n",
251 | "# except that we set the 'level' parameter to 'seg', and specify Kendall rather\n",
252 | "# than Pearson. The value of the 'average_by' parameter also matters here, as it\n",
253 | "# specifies how system x segment score matrices get converted into vectors for\n",
254 | "# comparison. We will use 'none', which just flattens the matrices.\n",
255 | "\n",
256 | "# The resulting ranking is similar to the ranking from accuracy. One noticeable\n",
257 | "# difference is that the significance clusters are smaller because they are\n",
258 | "# based on more data (much larger vectors). Notice that BLEU is absent because\n",
259 | "# it isn't available at the segment level.\n",
260 | "\n",
261 | "evs = all_evs['wmt21.news/en-de']\n",
262 | "corrs = data.GetCorrelations(\n",
263 | " evs, 'seg', {evs.std_ref}, {'refB'}, include_human, include_outliers,\n",
264 | " gold_name, primary_metrics, domain)\n",
265 | "ranks, matrix, _, _ = data.CompareMetrics(\n",
266 | " corrs, scipy.stats.kendalltau, 'none', k, psd, pval)\n",
267 | "\n",
268 | "data.PrintMetricComparison(ranks, matrix, pval, evs)"
269 | ]
270 | },
271 | {
272 | "cell_type": "code",
273 | "execution_count": null,
274 | "metadata": {
275 | "id": "x1P8jUG0kUEP"
276 | },
277 | "outputs": [],
278 | "source": [
279 | "# @title Evaluate metrics using seg-level accuracy with optimized tie threshold.\n",
280 | "\n",
281 | "# This is an implementation of the acc*_eq pairwise ranking accuracy proposed in\n",
282 | "# https://arxiv.org/abs/2305.14324. This is similar to global accuracy, but it\n",
283 | "# additionally gives metrics credit for predicting ties in gold scores, which\n",
284 | "# arise frequently in MQM segment-level data. To avoid bias due to differences\n",
285 | "# in scoring precision for different metrics, an optimal threshold for assigning\n",
286 | "# ties is automatically computed for each metric and test set.\n",
287 | "\n",
288 | "# For demo purposes we disable significance testing by setting k to 0.\n",
289 | "# (Significance testing works but is currently very slow.) Note that the\n",
290 | "# optimization procedure uses sampling, so results can change across different\n",
291 | "# runs.\n",
292 | "\n",
293 | "evs = all_evs['wmt21.news/en-de']\n",
294 | "corrs = data.GetCorrelations(\n",
295 | " evs, 'seg', {evs.std_ref}, {'refB'}, include_human, include_outliers,\n",
296 | " gold_name, primary_metrics, domain)\n",
297 | "ranks, matrix, _, _ = data.CompareMetrics(\n",
298 | " corrs, stats.KendallWithTiesOpt, 'item', 0, psd, pval, variant='acc23',\n",
299 | " sample_rate=0.1)\n",
300 | "\n",
301 | "data.PrintMetricComparison(ranks, matrix, pval, evs)"
302 | ]
303 | },
304 | {
305 | "cell_type": "code",
306 | "execution_count": null,
307 | "metadata": {
308 | "id": "Z304g56JBgq0"
309 | },
310 | "outputs": [],
311 | "source": [
312 | "# @title Evaluate a new metric\n",
313 | "\n",
314 | "# New metrics can be included in the comparison of existing metrics using the\n",
315 | "# 'extern_metrics' argument to GetCorrelations(). To demonstrate this, we'll\n",
316 | "# create and evaluate a new metric consisting of the average of the top 3\n",
317 | "# metrics in the system-level Pearson ranking.\n",
318 | "\n",
319 | "# The result is a slight, non-significant improvement over C-SPECpn, the metric\n",
320 | "# with highest Pearson correlation. (The '*' before the new metric indicates\n",
321 | "# that it isn't recognized as a primary submission.)\n",
322 | "\n",
323 | "evs = all_evs['wmt21.news/en-de']\n",
324 | "\n",
325 | "# Create the new metric\n",
326 | "top3_metrics = ['C-SPECpn-refC', 'COMET-QE-MQM_2021-src', 'bleurt-20-refC']\n",
327 | "sys_scores = {}\n",
328 | "for sys_name in evs.sys_names:\n",
329 | " if sys_name == 'refC': continue\n",
330 | " scores = np.array([evs.Scores('sys', m)[sys_name] for m in top3_metrics])\n",
331 | " sys_scores[sys_name] = scores.mean(axis=0)\n",
332 | "\n",
333 | "# Run the comparison with the new metric included via the 'extern_metrics'\n",
334 | "# argument.\n",
335 | "extras = {'top3_avg-refC': sys_scores}\n",
336 | "corrs = data.GetCorrelations(\n",
337 | " evs, 'sys', {evs.std_ref}, {'refB'}, include_human, include_outliers,\n",
338 | " gold_name, primary_metrics, domain, extern_metrics=extras)\n",
339 | "ranks, matrix, _, _ = data.CompareMetrics(\n",
340 | " corrs, scipy.stats.pearsonr, 'none', k, psd, pval)\n",
341 | "\n",
342 | "data.PrintMetricComparison(ranks, matrix, pval, evs)"
343 | ]
344 | },
345 | {
346 | "cell_type": "code",
347 | "execution_count": null,
348 | "metadata": {
349 | "id": "Ti_N-stPqmN4"
350 | },
351 | "outputs": [],
352 | "source": [
353 | "# @title Evaluate a new metric using global accuracy\n",
354 | "\n",
355 | "# This requires a bit more work, since we have to produce results for multiple\n",
356 | "# evalsets. As before, the result is a slight gain over the best single metric\n",
357 | "# (note that the averaged metrics aren't quite the top 3 for the global accuracy\n",
358 | "# task).\n",
359 | "\n",
360 | "# Create the new metric, one instance per input evalset\n",
361 | "top3_metrics = ['C-SPECpn-\u003cREF\u003e', 'COMET-QE-MQM_2021-src', 'bleurt-20-\u003cREF\u003e']\n",
362 | "extras_list = []\n",
363 | "for evs in evs_list:\n",
364 | " top3 = [m.replace('\u003cREF\u003e', evs.std_ref) for m in top3_metrics]\n",
365 | " sys_scores = {}\n",
366 | " for sys_name in evs.sys_names:\n",
367 | " if sys_name == evs.std_ref: continue\n",
368 | " scores = np.array([evs.Scores('sys', m)[sys_name] for m in top3])\n",
369 | " sys_scores[sys_name] = scores.mean(axis=0)\n",
370 | " extras_list.append({f'top3_avg-{evs.std_ref}': sys_scores})\n",
371 | "\n",
372 | "# Run the comparison with the new metric included via the 'extern_metrics_list'\n",
373 | "# argument.\n",
374 | "ranks, matrix, _, _ = data.CompareMetricsWithGlobalAccuracy(\n",
375 | " evs_list, main_refs, close_refs, include_human, include_outliers,\n",
376 | " gold_name, primary_metrics, domain, k, psd, pval,\n",
377 | " extern_metrics_list=extras_list)\n",
378 | "\n",
379 | "data.PrintMetricComparison(ranks, matrix)"
380 | ]
381 | },
382 | {
383 | "cell_type": "markdown",
384 | "metadata": {
385 | "id": "isPvYjjHU7SA"
386 | },
387 | "source": [
388 | "# Ranking metrics using the task interface\n",
389 | "\n",
390 | "This is a higher-level interface designed to make it more convenient to compare\n",
391 | "a set of metrics using various different criteria called 'tasks'. The following\n",
392 | "code uses this interface to roughly duplicate the comparisons in the previous\n",
393 | "section."
394 | ]
395 | },
396 | {
397 | "cell_type": "code",
398 | "execution_count": null,
399 | "metadata": {
400 | "id": "CJAgSNrEz03r"
401 | },
402 | "outputs": [],
403 | "source": [
404 | "# @title Define a set of tasks\n",
405 | "\n",
406 | "# Create TaskSets from dicts that specify attribute/value-list combinations,\n",
407 | "# along with fixed assignments to other attributes. Concatenate these into a\n",
408 | "# single TaskSet.\n",
409 | "\n",
410 | "k = 1 # Use only a single random draw for demo.\n",
411 | "lang0 = {'test_set': ['wmt21.news'], 'lang': ['en-de,en-ru,zh-en']}\n",
412 | "langs = {'test_set': ['wmt21.news'], 'lang': ['en-de', 'en-ru', 'zh-en']}\n",
413 | "\n",
414 | "taskset = tasks.TaskSet(\n",
415 | " lang0, corr_fcn='accuracy', close_refs=[{'refB'}, set(), set()], k=k)\n",
416 | "taskset += tasks.TaskSet(langs, level='sys', corr_fcn='pearson', k=k)\n",
417 | "taskset += tasks.TaskSet(langs, level='seg', corr_fcn='pearson', k=k)\n",
418 | "taskset += tasks.TaskSet(\n",
419 | " langs, level='seg', avg_by='item', corr_fcn='KendallWithTiesOpt',\n",
420 | " perm_test='pairs', corr_fcn_args={'sample_rate': 0.1}, k=k)\n",
421 | "\n",
422 | "# A TaskSet is just a list of Tasks, so we can make arbitrary changes to\n",
423 | "# attribute values. In this case, set the correct close_refs for en-de tasks.\n",
424 | "\n",
425 | "for task in taskset:\n",
426 | " if task.lang == 'en-de': task.close_refs = {'refB'}\n",
427 | "\n",
428 | "# Print task 'names' (attribute/value strings in canonical order).\n",
429 | "\n",
430 | "for t in taskset:\n",
431 | " print(t.name)"
432 | ]
433 | },
434 | {
435 | "cell_type": "code",
436 | "execution_count": null,
437 | "metadata": {
438 | "id": "LA_oY6dq0D9m"
439 | },
440 | "outputs": [],
441 | "source": [
442 | "# @title Run the tasks\n",
443 | "\n",
444 | "# This first loads the necessary data, then runs each task in sequence to\n",
445 | "# produce a TaskSetResults object. Subsequent runs re-use the loaded data.\n",
446 | "\n",
447 | "results = taskset.Run() # Takes about 5 minutes."
448 | ]
449 | },
450 | {
451 | "cell_type": "code",
452 | "execution_count": null,
453 | "metadata": {
454 | "id": "9ouFanD20HzV"
455 | },
456 | "outputs": [],
457 | "source": [
458 | "# @title Print raw task results\n",
459 | "\n",
460 | "for result in results:\n",
461 | " print(result.name)\n",
462 | " print(result.Str())"
463 | ]
464 | },
465 | {
466 | "cell_type": "code",
467 | "execution_count": null,
468 | "metadata": {
469 | "id": "cC20eqtG0Mlh"
470 | },
471 | "outputs": [],
472 | "source": [
473 | "# @title Average ranks for metrics\n",
474 | "\n",
475 | "# To combine the performance of metrics across tasks, we average their task\n",
476 | "# ranks. The tasks are weighted to ensure that the total mass for important\n",
477 | "# attributes is evenly distributed among the different values those attributes\n",
478 | "# take on.\n",
479 | "weights = results.AssignWeights(tasks.Attributes())\n",
480 | "global_ranks = results.AverageRanks(weights)\n",
481 | "\n",
482 | "# It is also interesting to compare the metric performance on different subsets\n",
483 | "# of tasks, for instance split by language.\n",
484 | "ranks_by_lp = {}\n",
485 | "for val, subset in results.SplitByAttr('lang').items():\n",
486 | " weights = subset.AssignWeights(tasks.Attributes())\n",
487 | " ranks_by_lp[val] = subset.AverageRanks(weights)\n",
488 | "\n",
489 | "# Print out the comparison, with global ranks first, followed by a breakdown\n",
490 | "# by language pair. We only show metrics that are in the intersection of all\n",
491 | "# tasks.\n",
492 | "langs = [' all ' if lp == 'en-de,en-ru,zh-en' else lp for lp in ranks_by_lp]\n",
493 | "print(''.rjust(24), 'global', ' '.join(langs))\n",
494 | "for metric, rank in global_ranks.items():\n",
495 | " ranks_for_metric = [rank] + [d[metric] for d in ranks_by_lp.values()]\n",
496 | " print(f'{metric:\u003c25}', ' '.join(f'{r:5.2f}' for r in ranks_for_metric))\n"
497 | ]
498 | }
499 | ],
500 | "metadata": {
501 | "colab": {
502 | "last_runtime": {
503 | "build_target": "",
504 | "kind": "local"
505 | },
506 | "name": "mt_metrics_eval.ipynb",
507 | "private_outputs": true,
508 | "provenance": [
509 | {
510 | "file_id": "1gXA-HQKMF6G4IdrUob8hPnbVm6_rsfeX",
511 | "timestamp": 1656987947120
512 | }
513 | ],
514 | "toc_visible": true
515 | },
516 | "kernelspec": {
517 | "display_name": "Python 3",
518 | "name": "python3"
519 | },
520 | "language_info": {
521 | "name": "python"
522 | }
523 | },
524 | "nbformat": 4,
525 | "nbformat_minor": 0
526 | }
527 |
--------------------------------------------------------------------------------
/mt_metrics_eval/mtme.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | r"""Command-line interface for mt-metrics-eval.
16 |
17 | List info, print text, score metrics, compare metrics.
18 |
19 | Examples (omitting path to binary):
20 |
21 | # Get latest version of database (slow).
22 | mtme --download
23 |
24 | # Info about test sets:
25 | mtme --list # all available sets
26 | mtme --list -t wmt20 # language pairs for wmt20
27 | mtme --list -t wmt20 -l en-de # details for wmt20/en-de
28 |
29 | # Generate all system outputs, pasted with doc-ids, source, and reference:
30 | mtme -t wmt20 -l en-de --echosys doc,src,ref
31 |
32 | # Correlations for sys- and seg-level scores (using example files from the
33 | # database):
34 | MTME20=$HOME/.mt-metrics-eval/mt-metrics-eval-v2/wmt20/metric-scores
35 | mtme -t wmt20 -l en-de < $MTME20/en-de/COMET-ref.sys.score
36 | mtme -t wmt20 -l en-de < $MTME20/en-de/COMET-ref.seg.score
37 |
38 | # Correlations with alternative gold scores, outlier systems included:
39 | mtme -t wmt20 -l en-de -g mqm --use_outliers < $MTME20/en-de/COMET-ref.sys.score
40 |
41 | # Compare two metrics, testing whether correlations are significantly different:
42 | METRIC1=$MTME20/en-de/COMET-ref.sys.score
43 | METRIC2=$MTME20/en-de/BLEU-ref.sys.score
44 | mtme -t wmt20 -l en-de -g mqm -i $METRIC1 -c $METRIC2
45 |
46 | # Compare all metrics under specified conditions, writing ranks, correlations,
47 | # and matrix of pair-wise significance values (using small k for demo).
48 | mtme --matrix -t wmt20 -l en-de -g mqm --k 100
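
# Dump all gold and metric scores for all systems to a tsv file (see --scores
# for the column layout):
mtme -t wmt20 -l en-de --scores -o scores.tsv

# Add new metrics from a directory of .score files and include them in the
# comparison matrix (the directory path below is just a placeholder):
mtme --matrix -t wmt20 -l en-de -g mqm --k 100 --add_metrics_from_dir /path/to/scores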
49 | """
50 |
51 | import ast
52 | import os
53 | import sys
54 | from absl import app
55 | from absl import flags
56 | from mt_metrics_eval import data
57 | from mt_metrics_eval import meta_info
58 | from mt_metrics_eval import stats
59 | from mt_metrics_eval import tasks
60 | import scipy.stats
61 | import glob
62 |
63 | flags.DEFINE_bool(
64 | 'download', False, 'Download local copy of the database and quit. '
65 | 'Overwrites any existing copy.')
66 | flags.DEFINE_bool(
67 | 'list', False, 'List available test sets. With -t, list language pairs for '
68 | 'given test set. With -t and -l, list details for given test '
69 | 'set and language pair.')
70 | flags.DEFINE_string(
71 | 'echo', None,
72 | 'A comma-separated list of text names, any of "domain", "doc", "src", '
73 | '"ref" for the main reference, or an actual reference name for any other '
74 | 'reference (see --list). Pastes the corresponding tags or text to STDOUT '
75 | 'then quits.')
76 | flags.DEFINE_string(
77 | 'echosys', None,
78 | 'Like --echo, but repeats output once for each system, with "sysname txt " '
79 | 'fields prepended.')
80 | flags.DEFINE_bool(
81 | 'scores', False,
82 | 'Dump all scores to a tsv file. For each system, write the following '
83 | 'fields for each segment: system-name, domain, doc, seg-id, then '
84 | 'segment-level, doc-level, domain-level, and system-level scores '
85 | '(whichever are available). Gold scores are written first, followed by '
86 | 'metric scores. None values are written whenever scores aren\'t available '
87 | 'for the given level and/or system.')
88 | flags.DEFINE_string(
89 | 'test_set', None, 'Test set to use (see --list).', short_name='t')
90 | flags.DEFINE_string(
91 | 'language_pair', None,
92 | 'Source-target language pair (2-char ISO639-1 codes).', short_name='l')
93 | flags.DEFINE_string(
94 | 'input',
95 | None, 'Read input from a file instead of STDIN. Each line should '
96 | 'contain a system name and a score, separated by a tab. '
97 | 'The number of entries per system determines granularity: '
98 | 'one per system, document, or segment in the test set.',
99 | short_name='i')
100 | flags.DEFINE_string(
101 | 'output', None, 'Output file, defaults to STDOUT.', short_name='o')
102 | flags.DEFINE_string(
103 | 'matrix_save', None, 'File for json/npgz output from --matrix option.')
104 | flags.DEFINE_string(
105 | 'compare',
106 | None,
107 | 'File containing scores for comparison to --input scores, in the same '
108 | 'format. Comparison can be slow due to resampled significance tests for '
109 | 'document- and segment-level scores. Set --k=1 to disable resampling.',
110 | short_name='c')
111 | flags.DEFINE_string(
112 | 'gold',
113 | 'std', 'Type of gold scores to compare to, use "std" to designate official '
114 | 'gold scores.',
115 | short_name='g')
116 | flags.DEFINE_string(
117 | 'avg', 'none',
118 | 'Averaging method for segment- or doc-level correlations: "none" to pool '
119 | 'scores into vectors, "item" to average over correlations of item '
120 | 'vectors, or "sys" to average over correlations of system vectors.')
121 | flags.DEFINE_bool(
122 | 'replace_nans_with_zeros', False,
123 | 'Replace NaNs with 0 instead of discarding them. This will penalize '
124 | 'metrics that produce NaN values because they assign all items the same '
125 | 'score.')
126 | flags.DEFINE_integer(
127 | 'k', 1000, 'Number of resampling runs for PERM-BOTH significance test.')
128 | flags.DEFINE_integer(
129 | 'k_block', 1000,
130 | 'Size of blocks for early stopping checks with PERM-BOTH test. Set to >= k '
131 | 'for no early stopping.')
132 | flags.DEFINE_float(
133 | 'early_min', 0.02,
134 | 'Early stop PERM-BOTH if pval < early_min at current block boundary.')
135 | flags.DEFINE_float(
136 | 'early_max', 0.50,
137 | 'Early stop PERM-BOTH if pval > early_max at current block boundary.')
138 | flags.DEFINE_string(
139 | 'matrix_perm_test', 'scores',
140 | 'Type of permutation test to run, one of "scores" or "pairs". The pairs '
141 |     'test only works with KendallWithTiesOpt correlation, with variant set to '
142 | '"23" or "acc23".')
143 | flags.DEFINE_float(
144 | 'thresh', -1, 'Threshold for WMT Kendall-like correlation. Defaults to 25 '
145 | 'if gold scores are WMT raw, otherwise 0. (If using --matrix, set '
146 | '--matrix_corr_args to \'{"thresh": 25}\' for the same effect.)')
147 | flags.DEFINE_bool(
148 | 'use_outliers', False,
149 | 'Include scores for outlier systems in correlation. If these scores are '
150 |     'not available in the set selected with --gold, this option has no effect.')
151 | flags.DEFINE_string(
152 | 'add_systems', '',
153 | 'Comma-separated list of systems to add to the default set for '
154 | 'correlation, for instance outlier or human output. These scores must be '
155 |     'available in the set selected with --gold.')
156 | flags.DEFINE_bool(
157 | 'matrix', False, 'Compute correlations for a set of metrics, and perform '
158 | 'significance tests on their differences. Writes metrics in descending '
159 | 'order by correlation, followed by their rank (may include ties), '
160 |     'correlation, then n significance indicators (x for filler, 1 for sig, 0 '
161 | 'for non) for comparisons between current metric and all n metrics, in the '
162 | 'same order as rows. Flags that affect this operation include all '
163 | '--matrix_* flags, along with --gold, --avg, --k, --k_block, --early_min, '
164 | '--early_max, --replace_nans_with_zeros, and --use_outliers.')
165 | flags.DEFINE_string(
166 | 'matrix_parallel', None,
167 |     'Parallelize metric comparisons, and use this value as a temp file name.')
168 | flags.DEFINE_string(
169 | 'matrix_level', 'sys', 'Granularity, one of "sys", "doc" or "seg"')
170 | flags.DEFINE_string(
171 | 'matrix_domain', None,
172 | 'Limit matrix correlations to this domain, no limit if None. The string '
173 | '"None" is also interpreted as None.')
174 | flags.DEFINE_string(
175 | 'matrix_refs', 'std',
176 | 'Reference(s) to use. Metric variants that use references outside this set '
177 | 'are excluded, as are human outputs that match any of these references. '
178 | 'Use "std" to designate the standard reference.')
179 | flags.DEFINE_string(
180 | 'matrix_close_refs', '',
181 | 'Additional reference(s) to always exclude from human outputs when '
182 | 'matrix_human is True.')
183 | flags.DEFINE_bool(
184 | 'matrix_human', False,
185 | 'Include human outputs in matrix calculation, except for references '
186 | 'specified in matrix_refs and matrix_close_refs.')
187 | flags.DEFINE_bool(
188 | 'matrix_primary', True,
189 | 'Use only primary metric submissions in the matrix.')
190 | flags.DEFINE_float(
191 | 'matrix_pval', 0.05,
192 | 'p-value to use for assigning significance to metric comparisons.')
193 | flags.DEFINE_string(
194 | 'matrix_corr', 'pearson',
195 | 'Correlation to use for --matrix, one of pearson, spearman, kendall, '
196 | 'accuracy, or any of the vector-based correlation functions defined in '
197 | 'the stats module, eg KendallVariants. '
198 | 'Accuracy is valid only for system-level comparisons. It also '
199 | 'triggers special interpretation of the --language_pair, --matrix_refs, '
200 | 'and --matrix_close_refs: language pair can be a comma-separated list, '
201 | 'with corresponding lists of refs or a single ref that gets applied to '
202 |     'all languages (it\'s not possible to specify a different set of refs '
203 |     'per language with this option).')
204 | flags.DEFINE_string(
205 | 'matrix_corr_args', '{}',
206 | 'Extra arguments to the matrix_corr function, a string that can be '
207 | 'converted to a python dict, eg \'{"variant": "acc23", "epsilon": 10}\'.')
208 | flags.DEFINE_list(
209 | 'primary_metrics', None,
210 | 'List of basenames of metrics to consider primary, for example '
211 | '"BLEU,BLEURT-20,COMET". This can be used in conjunction with '
212 | '--matrix_primary to reduce the number of expensive metric pairwise '
213 | 'comparisons that need to be made.')
214 | flags.DEFINE_string(
215 | 'add_metrics_from_dir', None,
216 | 'Directory containing metric score files to add to existing metrics. This '
217 | 'may be a lower-level directory that directly contains .score files, or a '
218 | 'higher-level directory that contains language-pair sub-directories that '
219 | 'contain .score files. The latter format must be used if --language_pair '
220 | 'is a comma-separated list. New metrics are added as primary submissions, '
221 | 'and must not have the same names as existing metrics.')
222 |
223 | FLAGS = flags.FLAGS
224 |
225 |
226 | def PrintScores(evs):
227 | """Print all scores in tsv format. See doc for --scores option."""
228 |
229 | sys_names = sorted(evs.sys_names)
230 | gold_names = sorted(evs.human_score_names)
231 | metric_names = sorted(evs.metric_names)
232 |
233 | header = ''
234 | fields = ['system-name', 'domain', 'doc', 'seg-id']
235 | for level in 'seg', 'doc', 'domain', 'sys':
236 | if level in evs.levels:
237 | fields += [f'{g}:{level}' for g in gold_names]
238 | fields += [f'{m}:{level}' for m in metric_names]
239 | header = '\t'.join(fields) + '\n'
240 | docs = evs.DocsPerSeg()
241 | domains = evs.DomainsPerSeg()
242 | domain_ids = {d: i for i, d in enumerate(evs.domain_names)}
243 | doc_ids = {d: i for i, d in enumerate(evs.doc_names)}
244 |
245 | def _Score(level, scorer, sysname, ind):
246 | scores = evs.Scores(level, scorer)
247 | if scores is None or sysname not in scores:
248 | return 'None'
249 | else:
250 | return f'{scores[sysname][ind]}'
251 |
252 | fh = open(FLAGS.output, 'w') if FLAGS.output else sys.stdout
253 | with fh:
254 | fh.write(header)
255 | for n in sys_names:
256 | for i in range(len(evs.src)):
257 | doc, domain = docs[i], domains[i]
258 | doc_id, domain_id = doc_ids[doc], domain_ids[domain]
259 | fields = [n, domain, doc, f'{i + 1}']
260 | for level in 'seg', 'doc', 'domain', 'sys':
261 | if level not in evs.levels:
262 | continue
263 | ind = {'seg': i, 'doc': doc_id, 'domain': domain_id, 'sys': 0}[level]
264 | fields += [_Score(level, g, n, ind) for g in gold_names]
265 | fields += [_Score(level, m, n, ind) for m in metric_names]
266 | fh.write('\t'.join(fields) + '\n')
267 |
268 |
269 | def Flag2TaskArg(flag_val, sets=False):
270 | """Convert gold and ref flag values to task arguments."""
271 | if flag_val == 'std' or flag_val == 'None' or not flag_val:
272 | return None
273 | vals = flag_val.split(',')
274 | if sets:
275 | # Limited to singleton sets.
276 | vals = [{v} for v in vals]
277 | return vals[0] if len(vals) == 1 else vals
278 |
279 |
280 | def EvsDict(new_metric_dirs):
281 | """Make a (testset, lp)->evs dict w/ added metrics, None if not necessary."""
282 | if not new_metric_dirs and not FLAGS.primary_metrics:
283 | return None
284 | evs_dict = {}
285 | num_metrics_added = 0
286 | for lp in FLAGS.language_pair.split(','):
287 | evs = data.EvalSet(FLAGS.test_set, lp, True)
288 | evs_dict[(FLAGS.test_set, lp)] = evs
289 | if FLAGS.primary_metrics:
290 | evs.SetPrimaryMetrics(set(FLAGS.primary_metrics))
291 | if lp in new_metric_dirs:
292 | new_metrics = evs.AddMetricsFromDir(new_metric_dirs[lp], repair=True)
293 | num_metrics_added += len(new_metrics)
294 | evs.SetPrimaryMetrics(evs.primary_metrics | set(new_metrics))
295 | if new_metric_dirs and not num_metrics_added:
296 | raise ValueError('No new metrics added by --add_metrics_from_dir')
297 | return evs_dict
298 |
299 |
300 | def PrintMatrix(new_metric_dirs):
301 | """Print ranks, correlations, and comparison matrix for a set of metrics."""
302 |
303 | task = tasks.Task(
304 | test_set=FLAGS.test_set,
305 | lang=FLAGS.language_pair,
306 | domain=None if FLAGS.matrix_domain == 'None' else FLAGS.matrix_domain,
307 | level=FLAGS.matrix_level,
308 | human=FLAGS.matrix_human,
309 | avg_by=FLAGS.avg,
310 | corr_fcn=FLAGS.matrix_corr,
311 | k=FLAGS.k,
312 | gold=Flag2TaskArg(FLAGS.gold),
313 | refs=Flag2TaskArg(FLAGS.matrix_refs, sets=True),
314 | close_refs=Flag2TaskArg(FLAGS.matrix_close_refs, sets=True),
315 | use_outliers=FLAGS.use_outliers,
316 | primary=FLAGS.matrix_primary,
317 | pval=FLAGS.matrix_pval,
318 | block_size=FLAGS.k_block,
319 | early_min=FLAGS.early_min,
320 | early_max=FLAGS.early_max,
321 | replace_nans_with_zeros=FLAGS.replace_nans_with_zeros,
322 | perm_test=FLAGS.matrix_perm_test,
323 | corr_fcn_args=ast.literal_eval(FLAGS.matrix_corr_args)
324 | )
325 | evs_dict = EvsDict(new_metric_dirs)
326 | task_results = task.Run(
327 | parallel_file=FLAGS.matrix_parallel, eval_set_dict=evs_dict)
328 | fh = open(FLAGS.output, 'w') if FLAGS.output else sys.stdout
329 | with fh:
330 | fh.write(task_results.name + '\n')
331 | fh.write(task_results.Str())
332 | if FLAGS.matrix_save:
333 | task_results.Save(FLAGS.matrix_save)
334 |
335 |
336 | def PrintCorrelation(evs, scorefile, tag, outfile):
337 | """Read scores from score file, print correlation stats, return values."""
338 |
339 | scores = data.ReadScoreFile(scorefile)
340 | if not scores:
341 | raise ValueError('No systems in input file %s' % scorefile)
342 | num_scores = len(list(scores.values())[0])
343 | if num_scores == 1:
344 | level = 'sys'
345 | elif num_scores == len(evs.docs):
346 | level = 'doc'
347 | elif num_scores == len(evs.src):
348 | level = 'seg'
349 | else:
350 | raise ValueError(
351 | 'Number of scores/system (%d) doesn\'t match any known granularity in '
352 | '%s/%s' % (num_scores, FLAGS.test_set, FLAGS.language_pair))
353 |
354 | std_scorer = evs.StdHumanScoreName(level)
355 | gold_name = std_scorer if FLAGS.gold == 'std' else FLAGS.gold
356 | gold_scores = evs.Scores(level, gold_name)
357 | if gold_scores is None:
358 | raise ValueError('No scores for %s at %s level.' % (FLAGS.gold, level))
359 | sys_names = set(gold_scores) - evs.human_sys_names
360 | if not FLAGS.use_outliers:
361 | sys_names -= evs.outlier_sys_names
362 | for n in [s for s in FLAGS.add_systems.split(',') if s]:
363 | if n not in gold_scores:
364 | raise ValueError(f'No {gold_name} scores for system {n}')
365 | sys_names.add(n)
366 |
367 | avg = 'none' if level == 'sys' else FLAGS.avg
368 | corr = evs.Correlation(gold_scores, scores, sys_names)
369 | pearson = corr.Pearson(FLAGS.avg)
370 | spearman = corr.Spearman(FLAGS.avg)
371 | kendall = corr.Kendall(FLAGS.avg)
372 | # Always use item-wise averaging with KendallLike, otherwise it's very slow.
373 | if FLAGS.thresh == -1:
374 | FLAGS.thresh = 25 if gold_name == 'wmt-raw' else 0
375 | kendall_like = corr.KendallLike(thresh=FLAGS.thresh)
376 |
377 | if avg == 'none':
378 | cmp = 'flattened'
379 | elif avg == 'sys':
380 | cmp = 'rows in'
381 | else:
382 | cmp = 'columns in'
383 |
384 | print(
385 | f'{tag}{FLAGS.test_set} {FLAGS.language_pair} {level}-level: '
386 | f'scoring {corr.num_sys}/{len(evs.sys_names)} systems, '
387 | f'gold={gold_name}, '
388 | f'comparing {cmp} {corr.num_sys}x{corr.num_items} matrices '
389 | f'({corr.none_count} None vals): '
390 | f'Pearson={pearson[0]:0.3f},p{pearson[1]:0.3f} '
391 | f'Spearman={spearman[0]:0.3f},p{spearman[1]:0.3f} '
392 | f'Kendall={kendall[0]:0.3f},p{kendall[1]:0.3f} '
393 | f'Kendall-like@{FLAGS.thresh:g}={kendall_like[0]:0.3f}',
394 | file=outfile)
395 |
396 | return corr, pearson, spearman, kendall, kendall_like
397 |
398 |
399 | def PrintComparison(res_base, res_comp, outfile):
400 | """Test for difference between correlations, and print results."""
401 | corr1, pears1, spear1, kend1, kendlike1 = res_base
402 | corr2, pears2, spear2, kend2, kendlike2 = res_comp
403 | if corr1.num_items != corr2.num_items:
404 | raise ValueError('Can\'t compare score files at different granularities.')
405 |
406 | pearson = corr1.AverageCorrelation(
407 | scipy.stats.pearsonr, FLAGS.avg, FLAGS.replace_nans_with_zeros)
408 | spearman = corr1.AverageCorrelation(
409 | scipy.stats.spearmanr, FLAGS.avg, FLAGS.replace_nans_with_zeros)
410 | kendall = corr1.AverageCorrelation(
411 | scipy.stats.kendalltau, FLAGS.avg, FLAGS.replace_nans_with_zeros)
412 | # Always average KendallLike, otherwise it's very slow.
413 | kendlike = corr1.AverageCorrelation(
414 | stats.KendallLike, 'item', FLAGS.replace_nans_with_zeros,
415 | thresh=FLAGS.thresh)
416 |
417 | def _SigTest(corr1, corr2, v1, v2, corr_wrapper, corr_fcn):
418 | better = v2[0] >= v1[0]
419 | if not better:
420 | corr2, corr1 = corr1, corr2
421 | w = stats.WilliamsSigDiff(corr1, corr2, corr_wrapper)
422 | p, _, _, _ = stats.PermutationSigDiff(
423 | corr1, corr2, corr_fcn, FLAGS.avg, FLAGS.k,
424 | stats.PermutationSigDiffParams(
425 | FLAGS.k_block, FLAGS.early_min, FLAGS.early_max),
426 | FLAGS.replace_nans_with_zeros)
427 | return better, w, p
428 |
429 | pear_b, pear_w, pear_p = _SigTest(
430 | corr1, corr2, pears1, pears2, pearson, scipy.stats.pearsonr)
431 | sper_b, sper_w, sper_p = _SigTest(
432 | corr1, corr2, spear1, spear2, spearman, scipy.stats.spearmanr)
433 | kend_b, kend_w, kend_p = _SigTest(
434 | corr1, corr2, kend1, kend2, kendall, scipy.stats.kendalltau)
435 | kl_b, kl_w, kl_p = _SigTest(
436 | corr1, corr2, kendlike1, kendlike2, kendlike, stats.KendallLike)
437 |
438 | def _Summary(better, sig_williams, sig_perm):
439 | s = '2>1,' if better else '1>2,'
440 | s += f'pWilliams={sig_williams[0]:0.3f},pPERM={sig_perm:0.3f}'
441 | return s
442 |
443 | print(
444 | 'Pearson:%s Spearman:%s Kendall:%s Kendall-like@%g:%s' %
445 | (_Summary(pear_b, pear_w, pear_p), _Summary(sper_b, sper_w, sper_p),
446 | _Summary(kend_b, kend_w, kend_p), FLAGS.thresh, _Summary(
447 | kl_b, kl_w, kl_p)),
448 | file=outfile)
449 |
450 |
451 | def GetNewMetricDirs():
452 |   """Parse --add_metrics_from_dir and return a map from lp->metric_dir."""
453 | lps = FLAGS.language_pair.split(',')
454 | new_metric_dirs = {}
455 | if FLAGS.add_metrics_from_dir:
456 | for lp in lps:
457 | new_dir = os.path.join(FLAGS.add_metrics_from_dir, lp)
458 | if os.path.isdir(new_dir):
459 | new_metric_dirs[lp] = new_dir
460 | if len(lps) == 1 and lps[0] not in new_metric_dirs:
461 | new_metric_dirs[lps[0]] = FLAGS.add_metrics_from_dir
462 | if not new_metric_dirs:
463 | raise ValueError(
464 | 'No suitable directories found for --add_metrics_from_dir flag.')
465 | return new_metric_dirs
466 |
467 |
468 | def main(argv):
469 | if len(argv) > 1:
470 | raise app.UsageError('Too many command-line arguments.')
471 |
472 | if FLAGS.download:
473 | print('Downloading data into %s' % data.LocalDir())
474 | data.Download()
475 | return
476 |
477 | if FLAGS.list:
478 | if FLAGS.test_set is None:
479 | print('test-sets:', ' '.join(meta_info.DATA))
480 | elif FLAGS.language_pair is None:
481 | print(f'language pairs for {FLAGS.test_set}:',
482 | ' '.join(meta_info.DATA[FLAGS.test_set]))
483 | else:
484 | evs = data.EvalSet(FLAGS.test_set, FLAGS.language_pair)
485 | print(
486 | '%s %s:' % (FLAGS.test_set, FLAGS.language_pair),
487 | '%d segs, %d docs, %d systems (includes %d outliers + %d human), '
488 | 'outliers: {%s}, human: {%s}, refs: {%s}, gold-scores: {%s}' %
489 | (len(evs.src), len(evs.docs), len(evs.sys_names),
490 | len(evs.outlier_sys_names), len(evs.human_sys_names), ','.join(
491 | evs.outlier_sys_names), ','.join(evs.human_sys_names), ','.join(
492 | evs.all_refs), ','.join(evs.human_score_names)))
493 | return
494 |
495 | if FLAGS.test_set is None:
496 | raise ValueError('No test_set specified.')
497 | if FLAGS.language_pair is None:
498 | raise ValueError('No language_pair specified.')
499 |
500 | new_metric_dirs = GetNewMetricDirs()
501 |
502 | if FLAGS.matrix:
503 | PrintMatrix(new_metric_dirs)
504 | return
505 |
506 | evs = data.EvalSet(
507 | FLAGS.test_set, FLAGS.language_pair,
508 | read_stored_metric_scores=FLAGS.scores)
509 | if FLAGS.primary_metrics:
510 | evs.SetPrimaryMetrics(set(FLAGS.primary_metrics))
511 | if FLAGS.language_pair in new_metric_dirs:
512 | new_metrics = evs.AddMetricsFromDir(
513 | new_metric_dirs[FLAGS.language_pair], repair=True)
514 | evs.SetPrimaryMetrics(evs.primary_metrics | set(new_metrics))
515 | if not new_metrics:
516 | raise ValueError('No new metrics added by --add_metrics_from_dir')
517 |
518 | if FLAGS.scores:
519 | PrintScores(evs)
520 | return
521 |
522 | if FLAGS.echo is not None or FLAGS.echosys is not None:
523 | flag_val = FLAGS.echo or FLAGS.echosys
524 | texts = []
525 | for col in flag_val.split(','):
526 | if col == 'src':
527 | texts.append(evs.src)
528 | elif col == 'doc':
529 | texts.append(evs.DocsPerSeg())
530 | elif col == 'domain':
531 | texts.append(evs.DomainsPerSeg())
532 | elif col == 'ref':
533 | texts.append(evs.all_refs[evs.std_ref])
534 | elif col in evs.all_refs:
535 | texts.append(evs.all_refs[col])
536 | else:
537 | raise ValueError('Unknown text type for --echo: %s' % col)
538 | if FLAGS.echo is not None:
539 | for lines in zip(*texts):
540 | print('\t'.join(lines))
541 | else:
542 | for sysname, sysout in evs.sys_outputs.items():
543 | for lines in zip(sysout, *texts):
544 | print('%s\t%s' % (sysname, '\t'.join(lines)))
545 | return
546 |
547 | fh = open(FLAGS.output, 'w') if FLAGS.output else sys.stdout
548 | with fh:
549 | tag = '1: ' if FLAGS.compare else ''
550 | res_base = PrintCorrelation(evs, FLAGS.input or '/dev/stdin', tag, fh)
551 | if FLAGS.compare:
552 | res_comp = PrintCorrelation(evs, FLAGS.compare, '2: ', fh)
553 | PrintComparison(res_base, res_comp, fh)
554 |
555 |
556 | if __name__ == '__main__':
557 | app.run(main)
558 |
--------------------------------------------------------------------------------
/mt_metrics_eval/pce.py:
--------------------------------------------------------------------------------
1 |
2 | # Copyright 2024 Brian Thompson. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import numpy as np
18 |
19 |
20 | def compute_pairwise_p_values(seg_scores, num_permutations=1000, seed: int = 4):
21 | """
22 | Author: Brian Thompson
23 | Date: June 2024
24 |
25 |     Suppose we have a test set consisting of L=5 segments and two systems, systemA and systemB,
26 | for which we have segment-level scores scoresA and scoresB:
27 | scoresA = [0.8, 0.9, 0.7, 1.0, 0.6]
28 | scoresB = [0.2, 0.3, 0.1, 0.4, 0.0]
29 |
30 | Typically we would average segment-level scores to get system level scores, but for convenience later on
31 | we will define system scores to be the sum of segment-level scores. This gives us a delta system-level score of:
32 | test_delta = sum(scoresA) - sum(scoresB) = 4.0 - 1.0 = 3.0
33 |
34 | To run a paired permutation test, we first generate a new set of scores scores0,
35 | where each score0[i] is randomly selected from either scoresA[i] or scoresB[i].
36 | Let's define a random boolean mask:
37 | m = [1, 0, 0, 1, 1]
38 |
39 |     and use it to select scores0:
40 | scores0 = m.*scoresA + (1-m).*scoresB = [0.8, 0.3, 0.1, 1.0, 0.6] # selected from [A, B, B, A, A], respectively
41 |
42 | Likewise, we compose scores1 using all the scores which were not selected for scores0:
43 | scores1 = (1-m).*scoresA + m.*scoresB = [0.2, 0.9, 0.7, 0.4, 0.0] # selected from [B, A, A, B, B], respectively
44 |
45 | To get the delta system-level score for our two mock systems, we need to compute:
46 | null_delta = sum(scores0) - sum(scores1)
47 | = sum(m.*scoresA + (1-m).*scoresB) - sum((1-m).*scoresA + m.*scoresB)
48 |                   = sum((2m-1).*scoresA) - sum((2m-1).*scoresB)
49 | = (2m-1) * scoresA.T - (2m-1) * scoresB.T
50 | = [ 1, -1, -1, 1, 1] * [[0.8], - [ 1, -1, -1, 1, 1] * [[0.2], = 0.8 - 0.2 = 0.6
51 | [0.9], [0.3],
52 | [0.7], [0.1],
53 | [1.0], [0.4],
54 | [0.6]] [0.0]]
55 |
56 | To compute many different permutations, we replace the vector m with a matrix of size (num_permutations, L):
57 | null_delta = [[ 1, 1, -1, -1, -1], * [[0.8], - [[ 1, 1, -1, -1, -1], * [[0.2], = [[-0.6], - [[ 0.0], = [[-0.6]
58 | [ 1, -1, 1, -1, 1], [0.9], [ 1, -1, 1, -1, 1], [0.3], [ 0.2], [-0.4], [ 0.6],
59 | [ 1, -1, 1, 1, -1], [0.7], [ 1, -1, 1, 1, -1], [0.1], [ 1.0], [ 0.4], [ 0.6],
60 | [-1, 1, -1, -1, 1], [1.0], [-1, 1, -1, -1, 1], [0.4], [-1.0], [-0.4], [-0.6],
61 | [ 1, 1, 1, -1, 1], [0.6]] [ 1, 1, 1, -1, 1], [0.0]] [ 2.0], [ 0.2], [ 1.8],
62 | [-1, 1, -1, 1, -1], [-1, 1, -1, 1, -1], [-0.2], [ 0.4], [-0.6],
63 | [ 1, 1, 1, 1, 1], [ 1, 1, 1, 1, 1], [ 4.0], [ 1.0], [ 3.0],
64 | [ 1, -1, 1, -1, 1], [ 1, -1, 1, -1, 1], [ 0.2], [-0.4], [ 0.6],
65 | [ 1, 1, -1, -1, 1], [ 1, 1, -1, -1, 1], [ 0.6], [ 0.0], [ 0.6],
66 | [-1, 1, -1, -1, -1]] [ 1, -1, -1, 1, -1]] [-2.2]] [-0.4]] [-1.8]]
67 |
68 |     To test whether system A is significantly better than system B, we compute:
69 | null_delta >= test_delta = [[-0.6] >= 3 = [[False],
70 | [ 0.6], [False],
71 | [ 0.6], [False],
72 | [-0.6], [False],
73 | [ 1.8], [False],
74 | [-0.6], [False],
75 | [ 3.0], [True ],
76 | [ 0.6], [False],
77 | [ 0.6], [False],
78 | [-1.8]] [False]]
79 |
80 | The p value is the fraction of the time that null_delta >= test_delta, in this case 1/10 = 0.1
81 |
82 |     The above discussion was for a single system pair, but we actually need to compute p values for each pair
83 |     within a set of systems systemA, systemB, ..., systemN. In practice, the computational bottleneck is generating
84 | the random boolean vector m, so we generate m once and use it for all pairs of systems.
85 |
86 | Reusing m also allows us to avoid most of the N^2 computations by pre-computing (2m-1) * scoresA.T,
87 | (2m-1) * scoresB.T, ..., (2m-1) * scoresN.T.
88 |
89 | Test speed:
90 |     python -m timeit -s "import numpy as np; from mt_metrics_eval.pce import compute_pairwise_p_values; x=np.random.random(size=(14,1300))" "compute_pairwise_p_values(x, num_permutations=1000)"
91 |
92 | :param seg_scores: segment-level scores, with shape (num_systems, num_segments)
93 | :param num_permutations: Number of permutations for permutation test
94 | :param seed: The random seed
95 | :return: np.array of size (num_systems, num_systems), where the upper triangle has been populated
96 | with p-values for the hypothesis that system[i] > system[j]
97 | """
98 | num_systems, num_segments = seg_scores.shape
99 |
100 | rng = np.random.default_rng(seed)
101 | # initialize in range [0, 1)
102 | two_m_minus_one = rng.random(size=(num_permutations, num_segments), dtype=np.float32)
103 | # quantize to 0 or 1, in place
104 | np.rint(two_m_minus_one, out=two_m_minus_one, casting='same_kind')
105 | # scale and shift to get -1.0 and +1.0, in place
106 | two_m_minus_one *= 2.0
107 | two_m_minus_one -= 1.0
108 |
109 | seg_scores = seg_scores.astype(np.float32) # shape: (num_systems, num_segments)
110 | sys_scores = np.sum(seg_scores, axis=1) # shape: (num_systems, )
111 |
112 | partial = np.matmul(two_m_minus_one, seg_scores.T) # shape: (num_permutations, num_systems)
113 |
114 | # initialize p value matrix to NaN
115 | p_vals = np.empty((num_systems, num_systems,)) * np.nan
116 | # populate upper triangle
117 | for ii in range(num_systems):
118 | for jj in range(ii + 1, num_systems):
119 | null_delta = partial[:, ii] - partial[:, jj] # shape: (num_permutations, )
120 | test_delta = sys_scores[ii] - sys_scores[jj] # float
121 | p_vals[ii, jj] = np.sum(null_delta >= test_delta) / num_permutations
122 |
123 | return p_vals
124 |
125 |
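# Illustrative usage sketch (editor addition, not part of the original module):
# reproduce the two-system worked example from the docstring above, using the
# same made-up segment-level scores. This helper is never called by the library.
def _example_compute_pairwise_p_values():
    scores_a = [0.8, 0.9, 0.7, 1.0, 0.6]
    scores_b = [0.2, 0.3, 0.1, 0.4, 0.0]
    seg_scores = np.array([scores_a, scores_b])  # shape (num_systems, num_segments)
    p_vals = compute_pairwise_p_values(seg_scores, num_permutations=1000)
    # p_vals[0, 1] is the p-value for the hypothesis that system 0 > system 1;
    # the diagonal and lower triangle are left as NaN.
    return p_vals

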
126 | def compute_one_minus_pce(human_pairwise_p_vals, metric_pairwise_p_vals):
127 | """
128 | Author: Brian Thompson
129 | Date: June 2024
130 |
131 | Pairwise Confidence Error (PCE) is the absolute difference between
132 | the p value for the conclusion that one system is better than another given human judgements and
133 |     the p value for the same conclusion given metric scores,
134 | averaged over all system pairings for a set of systems.
135 |
136 | We return 1-PCE to be comparable with pairwise accuracy [i.e. range from 0 to 1, higher is better]
137 |
138 | :param human_pairwise_p_vals: np.array of shape (num_systems, num_systems),
139 | where the upper triangle has been populated with p-values for system[i] > system[j]
140 | computed from human judgements
141 | :param metric_pairwise_p_vals: np.array of shape (num_systems, num_systems),
142 |         where the upper triangle has been populated with p-values for system[i] > system[j]
143 | computed from metric scores
144 | :return: 1-PCE
145 | """
146 | num_systems = human_pairwise_p_vals.shape[0]
147 | upper_tri_idxs = np.triu_indices(num_systems, 1)
148 | return 1.0 - np.mean(np.abs(human_pairwise_p_vals - metric_pairwise_p_vals)[upper_tri_idxs])
149 |
150 |
151 |
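# Illustrative sketch (editor addition): compute 1-PCE for a small set of mock
# systems with random segment-level scores. The values are for demonstration
# only; real usage would pass human and metric scores for the same systems.
def _example_one_minus_pce(num_systems=4, num_segments=50, seed=0):
    rng = np.random.default_rng(seed)
    human_seg_scores = rng.random((num_systems, num_segments))
    metric_seg_scores = rng.random((num_systems, num_segments))
    human_p = compute_pairwise_p_values(human_seg_scores)
    metric_p = compute_pairwise_p_values(metric_seg_scores)
    return compute_one_minus_pce(human_p, metric_p)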
--------------------------------------------------------------------------------
/mt_metrics_eval/pce_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tests for soft pairwise accuracy (pairwise confidence error)."""
16 |
17 | from mt_metrics_eval import pce
18 | import numpy as np
19 | import unittest
20 |
21 |
22 | class PCETest(unittest.TestCase):
23 |
24 | def test_pairwise_p_values_are_deterministic(self):
25 | scores = np.random.rand(10, 100)
26 | pvalues1 = pce.compute_pairwise_p_values(scores)
27 | pvalues2 = pce.compute_pairwise_p_values(scores)
28 | np.testing.assert_array_equal(pvalues1, pvalues2)
29 |
30 |
31 | if __name__ == "__main__":
32 | unittest.main()
33 |
--------------------------------------------------------------------------------
/mt_metrics_eval/ratings.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Ratings consisting of sub-segment-level error spans."""
16 |
17 | import collections
18 | import dataclasses
19 | import json
20 | from typing import Any
21 | import dacite
22 | import glob
23 |
24 |
25 | @dataclasses.dataclass
26 | class Error:
27 | """A representation of an error span.
28 |
29 | Attributes:
30 | start: The starting character offset of the span in the original text.
31 | end: The end+1 character offset of the span in the original text.
32 | category: The category.
33 | severity: The severity.
34 | score: The original score assigned to this error by a rater or a model.
35 | is_source_error: True if the span is in the source text.
36 | """
37 |
38 | start: int
39 | end: int
40 | category: str | None = None
41 | severity: str | None = None
42 | score: float | None = None
43 | is_source_error: bool = False
44 |
45 | def ToDict(self) -> dict[str, Any]:
46 | return dataclasses.asdict(self)
47 |
48 | @classmethod
49 | def FromDict(cls, d: dict[str, Any]) -> 'Error':
50 | return dacite.from_dict(data_class=Error, data=d)
51 |
52 |
53 | @dataclasses.dataclass
54 | class Rating:
55 | """The errors assigned to a translation by a single rater/method."""
56 |
57 | errors: list[Error]
58 |
59 | def ToDict(self) -> dict[str, Any]:
60 | return dataclasses.asdict(self)
61 |
62 | @classmethod
63 | def FromDict(cls, d: dict[str, Any]) -> 'Rating':
64 | return dacite.from_dict(data_class=Rating, data=d)
65 |
66 |
67 | def ReadRatingFile(
68 | filename: str, default_rater: str
69 | ) -> tuple[dict[str, list[Rating | None]], dict[str, list[str]]]:
70 | """Read a file containing sysname/rating entries."""
71 | ratings = collections.defaultdict(list) # sys -> [ratings]
72 | rater_ids = collections.defaultdict(list) # sys -> [rater_id]
73 | with open(filename) as f:
74 | for line in f:
75 | cols = line.strip().split('\t')
76 | if len(cols) == 2:
77 | sysname, rating = cols
78 | rater = default_rater
79 | elif len(cols) == 3:
80 | sysname, rating, rater = cols
81 | else:
82 | raise ValueError(
83 | f'Expected 2 or 3 columns in rating file. Found {len(cols)}. Line:'
84 | f' {line}'
85 | )
86 | if rating == 'None':
87 | ratings[sysname].append(None)
88 | rater_ids[sysname].append(None)
89 | else:
90 | ratings[sysname].append(Rating.FromDict(json.loads(rating)))
91 | rater_ids[sysname].append(rater)
92 | return ratings, rater_ids
93 |
94 |
95 | def WriteRatingFile(
96 | ratings: dict[str, list[Rating | None]],
97 | filename: str,
98 | rater_ids_dict: dict[str, list[str | None]],
99 | ):
100 | """Write a file containing sysname/rating entries."""
101 | with open(filename, 'w') as f:
102 | for sysname, rating_list in sorted(ratings.items()):
103 | rater_ids = rater_ids_dict[sysname]
104 | for rating, rater_id in zip(rating_list, rater_ids):
105 | if rating is None:
106 | f.write(f'{sysname}\tNone\n')
107 | else:
108 | f.write(f'{sysname}\t{json.dumps(rating.ToDict())}\t{rater_id}\n')
109 |
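
# Illustrative sketch (editor addition): round-trip a single rating through the
# tab-separated format used by WriteRatingFile / ReadRatingFile. The file path,
# system name, and rater id below are placeholders.
def _example_rating_file_round_trip(path='/tmp/example_ratings.tsv'):
  error = Error(start=0, end=3, category='fluency', severity='minor')
  WriteRatingFile(
      {'sysA': [Rating([error]), None]}, path, {'sysA': ['rater1', None]})
  ratings_by_sys, rater_ids = ReadRatingFile(path, default_rater='rater1')
  return ratings_by_sys['sysA'], rater_ids['sysA']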
--------------------------------------------------------------------------------
/mt_metrics_eval/ratings_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tests for the ratings module."""
16 |
17 | from mt_metrics_eval import ratings
18 | import unittest
19 |
20 |
21 | class ErrorTest(unittest.TestCase):
22 |
23 | def test_error_serialization(self):
24 | error = ratings.Error(1, 4, 'cat', None, 5, True)
25 | serialized = error.ToDict()
26 | expected = {
27 | 'start': 1,
28 | 'end': 4,
29 | 'category': 'cat',
30 | 'severity': None,
31 | 'score': 5,
32 | 'is_source_error': True,
33 | }
34 | self.assertEqual(serialized, expected)
35 | deserialized = ratings.Error.FromDict(serialized)
36 | self.assertEqual(deserialized, error)
37 |
38 |
39 | class RatingTest(unittest.TestCase):
40 |
41 | def test_rating_serialization(self):
42 | rating = ratings.Rating(
43 | [
44 | ratings.Error(1, 4, 'cat', None, 5, True),
45 | ratings.Error(0, 1),
46 | ]
47 | )
48 | serialized = rating.ToDict()
49 | expected = {
50 | 'errors': [
51 | {
52 | 'start': 1,
53 | 'end': 4,
54 | 'category': 'cat',
55 | 'severity': None,
56 | 'score': 5,
57 | 'is_source_error': True,
58 | },
59 | {
60 | 'start': 0,
61 | 'end': 1,
62 | 'category': None,
63 | 'severity': None,
64 | 'score': None,
65 | 'is_source_error': False,
66 | },
67 | ]
68 | }
69 | self.assertEqual(serialized, expected)
70 | deserialized = ratings.Rating.FromDict(serialized)
71 | self.assertEqual(deserialized, rating)
72 |
73 |
74 | if __name__ == "__main__":
75 | unittest.main()
76 |
--------------------------------------------------------------------------------
/mt_metrics_eval/standalone_ratings.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Self-contained ratings consisting of sub-segment-level error spans."""
16 |
17 | import dataclasses
18 | import json
19 | from typing import Any
20 | import dacite
21 | from mt_metrics_eval import data
22 | from mt_metrics_eval import ratings
23 | import glob
24 |
25 |
26 | @dataclasses.dataclass
27 | class Rating:
28 | """The errors assigned to a translation by a single rater/method.
29 |
30 | Attributes:
31 | source: The source text.
32 | hypothesis: The translation.
33 | errors: The list of errors.
34 | document_id: The ID of the document where the source text comes from.
35 | segment_id: The 0-indexed offset of this segment in the test set.
36 | system_id: The ID of the system that generated the translation.
37 | rater_id: The ID of the rater/method that annotated the errors.
38 | src_lang: The source language code.
39 | tgt_lang: The target language code.
40 | """
41 |
42 | source: str
43 | hypothesis: str
44 | errors: list[ratings.Error]
45 | document_id: str | None = None
46 | segment_id: int | None = None
47 | system_id: str | None = None
48 | rater_id: str | None = None
49 | src_lang: str | None = None
50 | tgt_lang: str | None = None
51 |
52 | def ToDict(self) -> dict[str, Any]:
53 | return dataclasses.asdict(self)
54 |
55 | @classmethod
56 | def FromDict(cls, d: dict[str, Any]) -> 'Rating':
57 | return dacite.from_dict(data_class=Rating, data=d)
58 |
59 |
60 | def ReadRatingFile(filename) -> list[Rating]:
61 | """Read a file containing a list of Ratings."""
62 | ratings_list = []
63 | with open(filename) as f:
64 | for line in f:
65 | ratings_list.append(Rating.FromDict(json.loads(line)))
66 | return ratings_list
67 |
68 |
69 | def WriteRatingFile(ratings_list: list[Rating], filename):
70 | """Write a list of Ratings to file."""
71 | with open(filename, 'w') as f:
72 | for rating in ratings_list:
73 | f.write(f'{json.dumps(rating.ToDict())}\n')
74 |
75 |
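# Illustrative sketch (editor addition): write a single standalone Rating to a
# JSONL file and read it back. All field values and the path are placeholders.
def _example_standalone_round_trip(path='/tmp/example_ratings.jsonl'):
  rating = Rating(
      source='Guten Morgen.',
      hypothesis='Good morning.',
      errors=[ratings.Error(start=0, end=4, severity='minor')],
      document_id='doc1',
      segment_id=0,
      system_id='sysA',
      rater_id='rater1',
      src_lang='de',
      tgt_lang='en',
  )
  WriteRatingFile([rating], path)
  return ReadRatingFile(path)

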
76 | def _RenameRaters(
77 | ratings_list: list[Rating], anonymize: bool
78 | ) -> dict[Any, str]:
79 | """Rename original rater names."""
80 | rater_ids = sorted(set(rating.rater_id for rating in ratings_list))
81 | # If all raters start with 'rater', don't anonymize to avoid a confusing
82 | # renaming due to sorting rater10 before rater2.
83 | if anonymize and not all(r.startswith('rater') for r in rater_ids if r):
84 | return {rater: f'rater{i + 1}' for i, rater in enumerate(rater_ids)}
85 | else:
86 | if None in rater_ids and 'rater' in rater_ids:
87 | raise ValueError(
88 | 'Attempt to rename rater "None" to "rater" failed because "rater"'
89 | ' already exists.'
90 | )
91 | return {
92 | rater: (rater if rater is not None else 'rater') for rater in rater_ids
93 | }
94 |
95 |
96 | def _CheckRating(
97 | rating: Rating, evs: data.EvalSet, rating_id: int, strict: bool = True
98 | ):
99 | """Check rating for compatibility with evs, with text match if strict."""
100 | if rating.segment_id is None or rating.system_id is None:
101 | raise ValueError(
102 | f'Rating {rating_id}: conversion requires non-null segment and system '
103 | 'ids.'
104 | )
105 | rating_id += 1 # 1-based for human consumption
106 | seg = rating.segment_id
107 | if seg >= len(evs.src):
108 | raise ValueError(f'Segment offset is too big in rating {rating_id}: {seg}')
109 | if rating.document_id is not None:
110 | if rating.document_id not in evs.doc_names:
111 | raise ValueError(
112 | f'Unknown doc in rating {rating_id}: {rating.document_id}')
113 | doc_beg, doc_end = evs.docs[rating.document_id]
114 | if seg < doc_beg or seg >= doc_end:
115 | raise ValueError(
116 | f'Bad segment offset for doc {rating.document_id} in rating '
117 |           f'{rating_id}: {seg}')
118 | if rating.system_id not in evs.sys_names:
119 | raise ValueError(f'Unknown sys in rating {rating_id}: {rating.system_id}')
120 | if rating.src_lang is not None and rating.src_lang != evs.src_lang:
121 | raise ValueError(
122 | f'Bad source language in rating {rating_id}: {rating.src_lang}')
123 | if rating.tgt_lang is not None and rating.tgt_lang != evs.tgt_lang:
124 | raise ValueError(
125 | f'Bad target language in rating {rating_id}: {rating.tgt_lang}')
126 | if strict:
127 | # We assume that the rating is internally consistent, so if the source and
128 | # hypothesis match evs, all error spans will be in range.
129 | if rating.source != evs.src[seg]:
130 | raise ValueError(f'Source segment mismatch in rating {rating_id}')
131 | if rating.hypothesis != evs.sys_outputs[rating.system_id][seg]:
132 | raise ValueError(f'Hypothesis segment mismatch in rating {rating_id}')
133 |
134 |
135 | def RatingsListToEvalSetRatings(
136 | ratings_list: list[Rating],
137 | evs: data.EvalSet,
138 | anonymize_raters: bool = False,
139 | strict: bool = True,
140 | ) -> tuple[
141 | dict[str, dict[str, list[ratings.Rating | None]]],
142 | dict[str, str],
143 | dict[str, dict[str, list[str | None]]],
144 | ]:
145 | """Convert Ratings list to EvalSet-style ratings dict and rater rename map."""
146 | new_rater_names = _RenameRaters(ratings_list, anonymize_raters)
147 | ratings_dict = {} # rating_name -> {sys: [rating]}
148 | rater_ids_dict = {} # rating_name -> {sys: [rater_id]}
149 | for rating_id, rating in enumerate(ratings_list):
150 | _CheckRating(rating, evs, rating_id, strict)
151 | rater = new_rater_names[rating.rater_id]
152 | if rater not in ratings_dict:
153 | ratings_dict[rater] = {s: [None] * len(evs.src) for s in evs.sys_names}
154 | rater_ids_dict[rater] = {s: [None] * len(evs.src) for s in evs.sys_names}
155 | if ratings_dict[rater][rating.system_id][rating.segment_id] is not None:
156 | # Nothing in the Rating spec precludes this, but it's probably something
157 | # we want to enforce.
158 | raise ValueError(
159 | f'Rating already exists for system/rater/segment: {rating_id}'
160 | )
161 | evs_rating = ratings.Rating(rating.errors)
162 | ratings_dict[rater][rating.system_id][rating.segment_id] = evs_rating
163 | rater_ids_dict[rater][rating.system_id][rating.segment_id] = rater
164 | return ratings_dict, new_rater_names, rater_ids_dict
165 |
166 |
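# Illustrative sketch (editor addition): convert standalone Ratings read from a
# JSONL file into EvalSet-style ratings dicts. Assumes the mt-metrics-eval
# database has been downloaded; the test set, language pair, and path below are
# placeholders, and the ratings in the file must refer to valid segments and
# systems for that EvalSet.
def _example_to_evalset_ratings(path='/tmp/example_ratings.jsonl'):
  evs = data.EvalSet('wmt23', 'en-de')
  ratings_list = ReadRatingFile(path)
  return RatingsListToEvalSetRatings(
      ratings_list, evs, anonymize_raters=True, strict=True)

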
167 | def MergeEvalSetRaters(
168 | evs_ratings: dict[str, dict[str, list[ratings.Rating | None]]],
169 | evs: data.EvalSet,
170 | evs_rater_ids: dict[str, dict[str, list[str | None]]],
171 | ) -> tuple[dict[str, list[ratings.Rating | None]], dict[str, list[str | None]]]:
172 | """Merge disjoint ratings from multiple raters into single-rater dict."""
173 | new_ratings = {s: [None] * len(evs.src) for s in evs.sys_names}
174 | new_rater_ids = {s: [None] * len(evs.src) for s in evs.sys_names}
175 | for rating_name in evs_ratings:
176 | for system_id in evs_ratings[rating_name]:
177 | rater_ids = evs_rater_ids[rating_name][system_id]
178 | for seg, evs_rating in enumerate(evs_ratings[rating_name][system_id]):
179 | if evs_rating is not None:
180 | if new_ratings[system_id][seg] is not None:
181 | raise ValueError(
182 | f'Found duplicate rating for system/segment: {system_id}/{seg}'
183 | )
184 | new_ratings[system_id][seg] = evs_rating
185 | new_rater_ids[system_id][seg] = rater_ids[seg]
186 | return new_ratings, new_rater_ids
187 |
188 |
189 | def EvalSetRatingsToRatingsList(
190 | evs_ratings: dict[str, dict[str, list[ratings.Rating | None]]],
191 | evs: data.EvalSet,
192 | evs_rater_ids: dict[str, dict[str, list[str | None]]],
193 | rename_raters: dict[str, str] | None = None,
194 | ) -> list[Rating]:
195 | """Convert an EvalSet-style ratings dict to a list of Ratings."""
196 | docs_per_seg = evs.DocsPerSeg()
197 | ratings_list = []
198 | for rating_name in evs_ratings:
199 | for system_id in evs_ratings[rating_name]:
200 | rater_ids = evs_rater_ids[rating_name][system_id]
201 | for seg, evs_rating in enumerate(evs_ratings[rating_name][system_id]):
202 | if evs_rating is None:
203 | continue
204 | rater = rater_ids[seg]
205 | rating = Rating(
206 | source=evs.src[seg],
207 | hypothesis=evs.sys_outputs[system_id][seg],
208 | errors=evs_rating.errors,
209 | document_id=docs_per_seg[seg],
210 | segment_id=seg,
211 | system_id=system_id,
212 | rater_id=rename_raters[rater] if rename_raters else rater,
213 | )
214 | ratings_list.append(rating)
215 | return ratings_list
216 |
--------------------------------------------------------------------------------
/mt_metrics_eval/tasks_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tests for tasks."""
16 |
17 | from mt_metrics_eval import tasks
18 | import numpy as np
19 | import unittest
20 |
21 |
22 | class TaskTest(unittest.TestCase):
23 |
24 | def testPostInitPearsonCorr(self):
25 | task = tasks.Task()
26 | ref_task = tasks.Task(gold='mqm', refs={'refA'}, close_refs=set())
27 | self.assertEqual(task, ref_task)
28 |
29 | task = tasks.Task(test_set='wmt21.news')
30 | ref_task = tasks.Task(
31 | test_set='wmt21.news', lang='en-de',
32 | gold='mqm', refs={'refC'}, close_refs=set())
33 | self.assertEqual(task, ref_task)
34 |
35 | def testPostInitAccuracyCorr(self):
36 | task = tasks.Task(
37 | test_set='wmt21.news', lang='en-de,en-ru,zh-en', corr_fcn='accuracy')
38 | ref_task = tasks.Task(
39 | test_set='wmt21.news', lang='en-de,en-ru,zh-en', corr_fcn='accuracy',
40 | gold=['mqm', 'mqm', 'mqm'],
41 | refs=[{'refC'}, {'refA'}, {'refB'}],
42 | close_refs=[set(), set(), set()])
43 | self.assertEqual(task, ref_task)
44 |
45 | task = tasks.Task(
46 | test_set='wmt21.news', lang='en-de', corr_fcn='accuracy')
47 | ref_task = tasks.Task(
48 | test_set='wmt21.news', lang='en-de', corr_fcn='accuracy',
49 | gold=['mqm'], refs=[{'refC'}], close_refs=[set()])
50 | self.assertEqual(task, ref_task)
51 |
52 | def testRunDefault(self):
53 | ref_metrics = [
54 | 'metricx_xxl_MQM_2020', 'COMET-20', 'COMET-22', 'BLEURT-20', 'MATESE',
55 | 'UniTE', 'MS-COMET-22', 'COMETKiwi[noref]', 'SEScore',
56 | 'UniTE-src[noref]', 'YiSi-1', 'COMET-QE[noref]',
57 | 'MS-COMET-QE-22[noref]', 'MEE4', 'MATESE-QE[noref]', 'BERTScore',
58 | 'HWTSC-Teacher-Sim[noref]', 'HuaweiTSC_EE_BERTScore_0.3_With_Human',
59 | 'f200spBLEU', 'chrF', 'BLEU', 'REUSE[noref]']
60 | results = tasks.Task(k=1).Run()
61 | self.assertEqual(results.metrics, ref_metrics)
62 | self.assertAlmostEqual(results.Corr('metricx_xxl_MQM_2020'), 0.8619197)
63 | self.assertEqual(results.Rank('metricx_xxl_MQM_2020'), 1)
64 | self.assertAlmostEqual(results.Corr('REUSE[noref]'), -0.5138621)
65 |
66 | def testRunAccuracy(self):
67 | ref_metrics = [
68 | 'metricx_xxl_MQM_2020', 'COMET-20', 'COMET-22', 'BLEURT-20', 'UniTE',
69 | 'MATESE', 'COMET-QE[noref]', 'MS-COMET-22', 'YiSi-1', 'MEE4',
70 | 'BERTScore', 'UniTE-src[noref]', 'COMETKiwi[noref]', 'SEScore',
71 | 'MS-COMET-QE-22[noref]', 'BLEU', 'chrF', 'f200spBLEU',
72 | 'HuaweiTSC_EE_BERTScore_0.3_With_Human', 'MATESE-QE[noref]',
73 | 'HWTSC-Teacher-Sim[noref]', 'REUSE[noref]']
74 | results = tasks.Task(corr_fcn='accuracy', k=1).Run()
75 | self.assertEqual(results.metrics, ref_metrics)
76 | self.assertAlmostEqual(results.Corr('metricx_xxl_MQM_2020'), 0.8021978)
77 | self.assertEqual(results.Rank('metricx_xxl_MQM_2020'), 1)
78 | self.assertAlmostEqual(results.Corr('REUSE[noref]'), 0.3296703)
79 |
80 | def testNoDraws(self):
81 | results = tasks.Task(k=0).Run()
82 | n = len(results.metrics)
83 | self.assertEqual(results.matrix.tolist(), np.zeros([n, n]).tolist())
84 | self.assertEqual(results.draws_index.tolist(), np.zeros([n, n]).tolist())
85 | self.assertEqual(results.draws_list.tolist(), [])
86 | self.assertEqual(results.Draws(0, 1).tolist(), [])
87 |
88 | def testOneDraw(self):
89 | k = 1
90 | results = tasks.Task(k=k).Run()
91 | n = len(results.metrics)
92 | for i in range(n):
93 | for j in range(i + 1, n):
94 | draws = results.Draws(i, j)
95 | self.assertEqual(len(draws), k) # pylint: disable=g-generic-assert
96 | corr_diff = results.Corr(i) - results.Corr(j)
97 | self.assertGreaterEqual(corr_diff, 0)
98 | null_prob = sum(a - b >= corr_diff for a, b in draws) / k
99 | self.assertAlmostEqual(null_prob, results.matrix[i, j])
100 |
101 |
102 | class TaskResultsTest(unittest.TestCase):
103 |
104 | # TODO(fosterg): Add test for Save/Load.
105 |
106 | def testAttrVals(self):
107 | task = tasks.Task()
108 | res = tasks.TaskResults(task)
109 | attr_vals = res.attr_vals
110 | for attr in tasks.Attributes():
111 | self.assertEqual(attr_vals[attr], f'{task.StrVal(attr)}')
112 |
113 | def testResultsString(self):
114 | results = ({'m1': (0.111111111, 1), 'metric2': (0.222222222, 2)},
115 | np.array([[0, 0.01], [0, 0]]), None, None)
116 | res = tasks.TaskResults(tasks.Task(), results)
117 | self.assertEqual(
118 | res.Str(), 'm1 1 0.1111111 . >\nmetric2 2 0.2222222 . . \n')
119 |
120 | def testRange(self):
121 | results = tasks.Task(k=0, corr_fcn='pearson').Run()
122 | self.assertEqual(results.range, (-1, 1))
123 |
124 | results = tasks.Task(k=0, corr_fcn='accuracy').Run()
125 | self.assertEqual(results.range, (0, 1))
126 |
127 | results = tasks.Task(k=0, corr_fcn='KendallWithTiesOpt').Run()
128 | self.assertEqual(results.range, (0, 1))
129 |
130 | results = tasks.Task(k=0, corr_fcn='KendallWithTiesOpt',
131 | corr_fcn_args={'variant': '23'}).Run()
132 | self.assertEqual(results.range, (-1, 1))
133 |
134 |
135 | class TaskSetTest(unittest.TestCase):
136 |
137 | def testConstruction(self):
138 | attr_combs = {
139 | 'lang': ['en-de', 'en-ru', 'zh-en'],
140 | 'domain': [None, 'conversation', 'ecommerce', 'news', 'social'],
141 | 'level': ['sys', 'seg']
142 | }
143 | taskset = tasks.TaskSet(attr_combs, k=10)
144 | # pylint: disable=g-generic-assert
145 | self.assertEqual(len(taskset), 3 * 5 * 2)
146 |
147 | en_de_count = sum(t.lang == 'en-de' for t in taskset)
148 | self.assertEqual(en_de_count, 10)
149 |
150 | k10_count = sum(t.k == 10 for t in taskset)
151 | self.assertEqual(k10_count, 30)
152 |
153 | taskset = tasks.TaskSet()
154 | self.assertEqual(len(taskset), 0) # pylint: disable=g-generic-assert
155 |
156 | def testAdd(self):
157 | tasks1 = tasks.TaskSet({'lang': ['en-de']})
158 | tasks2 = tasks.TaskSet({'lang': ['en-ru']})
159 | tasks3 = tasks.TaskSet({'lang': ['zh-en']})
160 | sum_tasks = tasks1 + tasks2 + tasks3
161 | all_tasks = tasks.TaskSet({'lang': ['en-de', 'en-ru', 'zh-en']})
162 | self.assertEqual(sum_tasks.tasks, all_tasks.tasks)
163 |
164 | def testRun(self):
165 | taskset = tasks.TaskSet({'corr_fcn': ['pearson', 'accuracy']}, k=1)
166 | res = taskset.Run()
167 | self.assertEqual(len(res), 2) # pylint: disable=g-generic-assert
168 |
169 | ref_pearson = tasks.Task(corr_fcn='pearson', k=1).Run()
170 | self.assertEqual(res.results[0].metrics, ref_pearson.metrics)
171 |
172 | ref_acc = tasks.Task(corr_fcn='accuracy', k=1).Run()
173 | self.assertEqual(res.results[1].metrics, ref_acc.metrics)
174 |
175 |
176 | class TaskSetResultsTest(unittest.TestCase):
177 |
178 | def Results(self, k=1):
179 | taskset = tasks.TaskSet(
180 | {'lang': ['en-de,en-ru,zh-en']}, corr_fcn='accuracy', k=k)
181 | taskset += tasks.TaskSet(
182 | {'lang': ['en-de', 'en-ru', 'zh-en'], 'corr_fcn': ['pearson']}, k=k)
183 | taskset += tasks.TaskSet(
184 | {'lang': ['en-de', 'en-ru'], 'corr_fcn': ['kendall']}, k=k)
185 | return taskset.Run()
186 |
187 | def testSplitByAttr(self):
188 | results = self.Results()
189 | splits = results.SplitByAttr('lang')
190 | self.assertEqual(len(splits), 4) # pylint: disable=g-generic-assert
191 | self.assertEqual(
192 | list(splits.keys()), ['en-de,en-ru,zh-en', 'en-de', 'en-ru', 'zh-en'])
193 | # pylint: disable=g-generic-assert
194 | self.assertEqual(len(splits['en-de,en-ru,zh-en']), 1)
195 | self.assertEqual(len(splits['en-de']), 2)
196 | self.assertEqual(len(splits['en-ru']), 2)
197 | self.assertEqual(len(splits['zh-en']), 1)
198 |
199 | def testAssignWeights(self):
200 | results = self.Results()
201 |
202 | weights = results.AssignWeights(tasks.Attributes())
203 | self.assertEqual(weights, [1/4, 1/8, 1/8, 1/4, 1/8, 1/8])
204 |
205 | weights = results.AssignWeights(['corr_fcn'])
206 | self.assertEqual(weights, [1/3, 1/9, 1/9, 1/9, 1/6, 1/6])
207 |
208 | weights = results.AssignWeights(['test_set'])
209 | self.assertEqual(weights, [1/6] * 6)
210 |
211 | weights = results.AssignWeights([])
212 | self.assertEqual(weights, [1/6] * 6)
213 |
214 | def testAverageRanks(self):
215 | results = self.Results()
216 | ranks = results.AverageRanks()
217 | self.assertEqual(len(ranks), 21) # pylint: disable=g-generic-assert
218 | self.assertEqual(list(ranks.values()), sorted(ranks.values()))
219 | self.assertTrue(all(r >= 1 for r in ranks.values()))
220 |
221 | def testAverageCorrs(self):
222 | results = self.Results()
223 | corrs = results.AverageCorrs()
224 | self.assertEqual(len(corrs), 21) # pylint: disable=g-generic-assert
225 | self.assertEqual(list(corrs.values()), sorted(corrs.values(), reverse=True))
226 | self.assertTrue(all(c >= 0 and c <= 1 for c in corrs.values()))
227 |
228 | def testAverageCorrMatrix(self):
229 | # TODO(fosterg): More explicit test, including handling of variable-length
230 | # draws.
231 | results = self.Results(k=2)
232 | corrs_ranks, sig_matrix = results.AverageCorrMatrix()
233 | self.assertEqual(len(corrs_ranks), 21) # pylint: disable=g-generic-assert
234 | self.assertEqual(sig_matrix.shape, (21, 21))
235 |
236 |
237 | class MetricsTableTest(unittest.TestCase):
238 |
239 | def Columns(self):
240 | return [
241 | {'m1': (0.9, 1), 'm2': (0.8, 1), 'm3': (0.5, 2)},
242 | {'m1': (0.5, 1), 'm2': (-0.2, 2)},
243 | {'m1': (0.2, 2), 'm3': (0.5, 1)},
244 | ]
245 |
246 | def testSmoke(self):
247 | for metrics in ['m1', 'm2', 'm3'], ['m3', 'm1']:
248 | for headers in [], [['m', 'a', 'b', 'c']]:
249 | for fmt in 'tsv', 'text', 'latex':
250 | for which in 'listed', 'union', 'intersection':
251 | for rerank in None, [True, False, True]:
252 | tasks.MetricsTable(
253 | metrics=metrics,
254 | columns=self.Columns(),
255 | column_headers=headers,
256 | fmt=fmt,
257 | which_metrics=which,
258 | rerank=rerank)
259 |
260 |
261 | if __name__ == '__main__':
262 | unittest.main()
263 |
--------------------------------------------------------------------------------
/mt_metrics_eval/tau_optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """An implementation of the tau optimization procedure.
16 |
17 | See https://arxiv.org/abs/2305.14324 for more details on the optimization
18 | routine.
19 | """
20 |
21 | import dataclasses
22 | from typing import Callable
23 |
24 | import numpy as np
25 | import numpy.typing
26 |
27 |
28 | class TauSufficientStats:
29 | """Represents the sufficient statistics for calculating Kendall's tau.
30 |
31 | The two vectors of scores that are correlated are assumed to represent
32 | human and metric scores. Some taus are asymmetric, so we keep the semantics
33 | of the vectors to avoid confusion, which could happen if generic names
34 | were used. If you are calculating the correlation between two metrics, make
35 | sure to understand whether the tau you are computing is symmetric or not.
36 |
37 | Attributes:
38 | con: The number of concordant pairs.
39 | dis: The number of discordant pairs.
40 | ties_human: The number of pairs tied only in the human scores.
41 | ties_metric: The number of pairs tied only in the metric scores.
42 | ties_both: The number of pairs tied in both the human and metric scores.
43 | num_pairs: The total number of pairs.
44 | """
45 |
46 | def __init__(
47 | self,
48 | con: int = 0,
49 | dis: int = 0,
50 | ties_human: int = 0,
51 | ties_metric: int = 0,
52 | ties_both: int = 0,
53 | ):
54 | self.con = con
55 | self.dis = dis
56 | self.ties_human = ties_human
57 | self.ties_metric = ties_metric
58 | self.ties_both = ties_both
59 | self.num_pairs = con + dis + ties_human + ties_metric + ties_both
60 |
61 | def tau_23(self) -> float:
62 | return (
63 | self.con
64 | + self.ties_both
65 | - self.dis
66 | - self.ties_human
67 | - self.ties_metric
68 | ) / self.num_pairs
69 |
70 | def acc_23(self) -> float:
71 | return (self.con + self.ties_both) / self.num_pairs
72 |
73 | def __eq__(self, other: 'TauSufficientStats') -> bool:
74 | return (
75 | self.con,
76 | self.dis,
77 | self.ties_human,
78 | self.ties_metric,
79 | self.ties_both,
80 | ) == (
81 | other.con,
82 | other.dis,
83 | other.ties_human,
84 | other.ties_metric,
85 | other.ties_both,
86 | )
87 |
88 | def __iadd__(self, other: 'TauSufficientStats') -> 'TauSufficientStats':
89 | self.con += other.con
90 | self.dis += other.dis
91 | self.ties_human += other.ties_human
92 | self.ties_metric += other.ties_metric
93 | self.ties_both += other.ties_both
94 | self.num_pairs += other.num_pairs
95 | return self
96 |
97 | def __isub__(self, other: 'TauSufficientStats') -> 'TauSufficientStats':
98 | self.con -= other.con
99 | self.dis -= other.dis
100 | self.ties_human -= other.ties_human
101 | self.ties_metric -= other.ties_metric
102 | self.ties_both -= other.ties_both
103 | self.num_pairs -= other.num_pairs
104 | return self
105 |
106 | def __str__(self) -> str:
107 | return (
108 | '('
109 | + ','.join([
110 | f'C={self.con}',
111 | f'D={self.dis}',
112 | f'T_h={self.ties_human}',
113 | f'T_m={self.ties_metric}',
114 | f'T_hm={self.ties_both}',
115 | ])
116 | + ')'
117 | )
118 |
119 | def __repr__(self):
120 | return str(self)
121 |
122 |
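# Worked example for the formulas above (illustrative comment): with con=2,
# dis=0, ties_human=1, ties_metric=0 and ties_both=0 there are 3 pairs, so
#   tau_23 = (2 + 0 - 0 - 1 - 0) / 3 = 1/3
#   acc_23 = (2 + 0) / 3 = 2/3
# i.e. TauSufficientStats(con=2, ties_human=1).acc_23() == 2/3.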
123 | @dataclasses.dataclass
124 | class TauOptimizationResult:
125 | thresholds: list[float]
126 | taus: list[float]
127 | best_threshold: float
128 | best_tau: float
129 |
130 |
131 | class _RankedPair:
132 | """Maintains the metadata for a ranked pair for calculating Kendall's tau.
133 |
134 | Attributes:
135 | row: The index of the row in the N x M matrix of scores that this pair
136 | belongs to.
137 | diff: The absolute difference between metric scores.
138 | stats: The tau sufficient statistics that this pair represents.
139 | tie_stats: The tau sufficient statistics that this pair represents when a
140 | tie is introduced in the metric score.
141 | """
142 |
143 | def __init__(self, h1: float, h2: float, m1: float, m2: float, row: int):
144 | self.row = row
145 | self.diff = abs(m1 - m2)
146 |
147 | # Determine the sufficient stats for the example when treated normally.
148 | if h1 == h2 and m1 == m2:
149 | self.stats = TauSufficientStats(ties_both=1)
150 | elif h1 == h2:
151 | self.stats = TauSufficientStats(ties_human=1)
152 | elif m1 == m2:
153 | self.stats = TauSufficientStats(ties_metric=1)
154 | elif (h1 > h2 and m1 > m2) or (h1 < h2 and m1 < m2):
155 | self.stats = TauSufficientStats(con=1)
156 | else:
157 | self.stats = TauSufficientStats(dis=1)
158 |
159 | # Determine the sufficient stats for the example when a tie is introduced
160 | # in the metric score.
161 | if h1 == h2:
162 | self.tie_stats = TauSufficientStats(ties_both=1)
163 | else:
164 | self.tie_stats = TauSufficientStats(ties_metric=1)
165 |
166 |
167 | def _enumerate_pairs(
168 | human_scores: np.ndarray,
169 | metric_scores: np.ndarray,
170 | sample_rate: float,
171 | filter_nones: bool = True,
172 | ) -> tuple[list[_RankedPair], set[int]]:
173 | """Enumerates pairs for Kendall's tau."""
174 | mat1 = human_scores
175 | mat2 = metric_scores
176 | pairs = []
177 | rows = set()
178 | for row, (r1, r2) in enumerate(zip(mat1, mat2)):
179 | # Filter Nones
180 | if filter_nones:
181 | filt = [
182 | (v1, v2)
183 | for v1, v2 in zip(r1, r2)
184 | if v1 is not None and v2 is not None
185 | ]
186 | if not filt:
187 | continue
188 | r1, r2 = zip(*filt)
189 |
190 | for i in range(len(r1)):
191 | for j in range(i + 1, len(r1)):
192 | if sample_rate == 1.0 or np.random.random() <= sample_rate:
193 | pairs.append(_RankedPair(r1[i], r1[j], r2[i], r2[j], row))
194 | rows.add(row)
195 | return pairs, rows
196 |
197 |
198 | def tau_optimization(
199 | metric_scores: numpy.typing.ArrayLike,
200 | human_scores: numpy.typing.ArrayLike,
201 | tau_fn: Callable[[TauSufficientStats], float],
202 | sample_rate: float = 1.0,
203 | ) -> TauOptimizationResult:
204 | """Runs tau optimization on the metric scores.
205 |
206 | Tau optimization automatically introduces ties into the metric scores to
207 | optimize a tau function. For more details, see
208 | https://arxiv.org/abs/2305.14324.
209 |
210 | The tau value that is calculated and optimized for is the average correlation
211 | (defined by tau_fn) calculated over paired rows in `metric_scores` and
212 | `human_scores`.
213 |
214 | If either `metric_scores` or `human_scores` are missing values, the
215 | corresponding entries should be `None`. In such cases, the input type should
216 | be a Python list or a NumPy array with dtype=object. If `np.nan` is used
217 | instead, the missing values will not be properly removed.
218 |
219 | Args:
220 | metric_scores: An N x M matrix of metric scores.
221 | human_scores: An N x M matrix of human scores.
222 | tau_fn: The tau function to optimize for. This can be a function like
223 |       `TauSufficientStats.acc_23`.
224 | sample_rate: The proportion of all possible pairs to consider when searching
225 | for epsilon and calculating tau. Must be in the range (0, 1]. Any value
226 | less than 1 will mean the search and optimal tau will be approximations.
227 | The sampling is random and uses `np.random`, so it can be made
228 | deterministic by fixing the NumPy random seed.
229 |
230 | Returns:
231 | The optimization result.
232 | """
233 | if sample_rate <= 0 or sample_rate > 1:
234 | raise ValueError(
235 | f'`sample_rate` must be in the range (0, 1]. Found {sample_rate}'
236 | )
237 |
238 | # Convert the data to a numpy array in case it isn't already.
239 | metric_scores = np.array(metric_scores)
240 | human_scores = np.array(human_scores)
241 |
242 | # The optimization routine expects 2 dimensional matrices. If we are only
243 | # given vectors, create a dummy dimension.
244 | if metric_scores.ndim == 1:
245 | metric_scores = np.expand_dims(metric_scores, 0)
246 | if human_scores.ndim == 1:
247 | human_scores = np.expand_dims(human_scores, 0)
248 |
249 | if human_scores.shape != metric_scores.shape:
250 | raise ValueError('Human and metric scores must have the same shape.')
251 |
252 | pairs, rows = _enumerate_pairs(human_scores, metric_scores, sample_rate)
253 | num_rows = len(rows)
254 |
255 | # Initialize the sufficient stats per row
256 | row_to_stats = {row: TauSufficientStats() for row in rows}
257 | for pair in pairs:
258 | row_to_stats[pair.row] += pair.stats
259 |
260 |   # Initialize the optimization. We start with a threshold of 0.0,
261 |   # representing no new ties introduced. This is needed because if the metric
262 |   # scores contain no ties at all, epsilon=0 would otherwise never be a
263 |   # candidate, and introducing any ties may only hurt.
264 | thresholds = [0.0]
265 | total_tau = sum(tau_fn(stats) for stats in row_to_stats.values())
266 | taus = [total_tau / num_rows]
267 |
268 | # Search all pairs to find the best tau value.
269 | pairs.sort(key=lambda p: p.diff)
270 | for pair in pairs:
271 | # Remove the old tau from the overall sum
272 | total_tau -= tau_fn(row_to_stats[pair.row])
273 |
274 | # Remove this pair from the overall counts, then reintroduce it as a tie.
275 | row_to_stats[pair.row] -= pair.stats
276 | row_to_stats[pair.row] += pair.tie_stats
277 |
278 | # Add the tau back to the overall average
279 | total_tau += tau_fn(row_to_stats[pair.row])
280 |
281 |     # Save the new overall tau for this threshold. If a tau has already been
282 |     # recorded for this threshold, overwrite it: a given threshold should flip
283 |     # every pair whose diff equals it, and the previously recorded value did
284 |     # not yet include this pair.
285 | overall_tau = total_tau / num_rows
286 | if thresholds[-1] == pair.diff:
287 | taus[-1] = overall_tau
288 | else:
289 | thresholds.append(pair.diff)
290 | taus.append(overall_tau)
291 |
292 | # Identify the maximum value and return.
293 | max_index = np.nanargmax(taus)
294 | max_threshold = thresholds[max_index]
295 | max_tau = taus[max_index]
296 | return TauOptimizationResult(thresholds, taus, max_threshold, max_tau)
297 |
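A minimal usage sketch of the module above; the inputs and expected outputs
mirror the test_zero_threshold_is_best case in tau_optimization_test.py below:

  from mt_metrics_eval import tau_optimization

  metric = [0, 2, 6, 8]
  human = [1, 2, 3, 4]
  result = tau_optimization.tau_optimization(
      metric, human, tau_optimization.TauSufficientStats.acc_23)

  # The metric and human scores rank the items identically, so introducing
  # ties cannot help: the best threshold is 0 with a perfect acc_23.
  print(result.thresholds)      # [0, 2, 4, 6, 8]
  print(result.best_threshold)  # 0
  print(result.best_tau)        # 1.0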
--------------------------------------------------------------------------------
/mt_metrics_eval/tau_optimization_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tests for the tau optimization module."""
16 |
17 | from mt_metrics_eval import tau_optimization
18 | import numpy as np
19 | import unittest
20 |
21 |
22 | class TauOptimizationTest(unittest.TestCase):
23 |
24 | def test_zero_threshold_is_best(self):
25 | # There are 6 pairs:
26 | # Metric: (0, 2) (0, 6) (0, 8) (2, 6) (2, 8) (6, 8)
27 | # Human: (1, 2) (1, 3) (1, 4) (2, 3) (2, 4) (3, 4)
28 | metric = [0, 2, 6, 8]
29 | human = [1, 2, 3, 4]
30 |
31 | expected_thresholds = [0, 2, 4, 6, 8]
32 | expected_taus = [6 / 6, 4 / 6, 3 / 6, 1 / 6, 0 / 6]
33 | expected_best_threshold = 0
34 | expected_best_tau = 1.0
35 |
36 | actual = tau_optimization.tau_optimization(
37 | metric, human, tau_optimization.TauSufficientStats.acc_23
38 | )
39 | self.assertEqual(expected_thresholds, actual.thresholds)
40 | self.assertEqual(expected_taus, actual.taus)
41 | self.assertEqual(expected_best_threshold, actual.best_threshold)
42 | self.assertEqual(expected_best_tau, actual.best_tau)
43 |
44 | def test_nonzero_threshold_is_best(self):
45 | # There are 6 pairs:
46 | # Metric: (0, 2) (0, 6) (0, 8) (2, 6) (2, 8) (6, 8)
47 | # Human: (1, 1) (1, 1) (1, 4) (1, 1) (1, 4) (1, 4)
48 | metric = [0, 2, 6, 8]
49 | human = [1, 1, 1, 4]
50 |
51 | expected_thresholds = [0, 2, 4, 6, 8]
52 | expected_taus = [3 / 6, 3 / 6, 4 / 6, 4 / 6, 3 / 6]
53 | expected_best_threshold = 4
54 | expected_best_tau = 4 / 6
55 |
56 | actual = tau_optimization.tau_optimization(
57 | metric, human, tau_optimization.TauSufficientStats.acc_23
58 | )
59 | self.assertEqual(expected_thresholds, actual.thresholds)
60 | self.assertEqual(expected_taus, actual.taus)
61 | self.assertEqual(expected_best_threshold, actual.best_threshold)
62 | self.assertEqual(expected_best_tau, actual.best_tau)
63 |
64 | def test_invalid_sample_rate(self):
65 | metric = [0, 2, 6, 8]
66 | human = [1, 1, 1, 4]
67 |
68 | with self.assertRaises(ValueError):
69 | tau_optimization.tau_optimization(
70 | metric,
71 | human,
72 | tau_optimization.TauSufficientStats.acc_23,
73 | sample_rate=0.0,
74 | )
75 |
76 | with self.assertRaises(ValueError):
77 | tau_optimization.tau_optimization(
78 | metric,
79 | human,
80 | tau_optimization.TauSufficientStats.acc_23,
81 | sample_rate=1.1,
82 | )
83 |
84 | def test_sample_rate_samples_pairs(self):
85 | # Ensures that sample_rate < 1 actually downsamples pairs. The sampling
86 | # is random, so the random seed is fixed.
87 | np.random.seed(123)
88 |
89 | # There are (4 choose 2) = 6 pairs and each pair has a unique difference
90 | # in metric score. We should expect approximately half of the pairs to be
91 | # randomly sampled and use their diffs as thresholds (plus 0, which is
92 | # always considered) when sample_rate=0.5.
93 | metric = [0, 2, 6, 14]
94 | human = [1, 2, 3, 4]
95 | result = tau_optimization.tau_optimization(
96 | metric,
97 | human,
98 | tau_optimization.TauSufficientStats.acc_23,
99 | sample_rate=0.5,
100 | )
101 | self.assertEqual(result.thresholds, [0, 6, 8, 14])
102 |
103 | def test_regression_example(self):
104 | # Tests an example with input matrices. This result has not been manually
105 | # verified, but the test might catch if something changes in the
106 | # optimization routine.
107 | metric = [
108 | [0, 5, 2, 3, 2],
109 | [None, 2, 1, 5, 3],
110 | [9, 1, 5, 3, 8],
111 | [9, 3, 4, 4, 1],
112 | [None, None, None, None, None],
113 | ]
114 | human = [
115 | [1, 5, 2, 1, 1],
116 | [4, 2, 2, 1, 4],
117 | [5, 9, 2, 8, 7],
118 | [7, 6, None, 3, 2],
119 | [None, None, None, None, None],
120 | ]
121 | result = tau_optimization.tau_optimization(
122 | metric,
123 | human,
124 | tau_optimization.TauSufficientStats.acc_23,
125 | )
126 |
127 | expected_thresholds = [0, 1, 2, 3, 4, 5, 6, 7, 8]
128 | expected_taus = [
129 | 0.4666666666666667,
130 | 0.4916666666666668,
131 | 0.38333333333333347,
132 | 0.29166666666666674,
133 | 0.2666666666666667,
134 | 0.2,
135 | 0.15833333333333335,
136 | 0.15833333333333335,
137 | 0.1166666666666667,
138 | ]
139 | expected_best_threshold = 1
140 | expected_best_tau = 0.4916666666666668
141 |
142 | self.assertEqual(expected_thresholds, result.thresholds)
143 | self.assertEqual(expected_taus, result.taus)
144 | self.assertEqual(expected_best_threshold, result.best_threshold)
145 | self.assertEqual(expected_best_tau, result.best_tau)
146 |
147 |
148 | if __name__ == "__main__":
149 | unittest.main()
150 |
--------------------------------------------------------------------------------
/mt_metrics_eval/wmt23_metrics.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "p30R5QSLqO0R"
7 | },
8 | "source": [
9 | "Colab to reproduce results from the WMT23 metrics shared task"
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {
15 | "id": "O13agE8mqIJ1"
16 | },
17 | "source": [
18 | "## Dependencies"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "metadata": {
25 | "id": "m23yLnHgEtAA"
26 | },
27 | "outputs": [],
28 | "source": [
29 | "\n",
30 | "# @title Install MTME\n",
31 | "\n",
32 | "!git clone https://github.com/google-research/mt-metrics-eval.git \u0026\u0026 cd mt-metrics-eval \u0026\u0026 pip install ."
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": null,
38 | "metadata": {
39 | "id": "YvjDTSiEp8zn"
40 | },
41 | "outputs": [],
42 | "source": [
43 | "# @title Imports\n",
44 | "\n",
45 | "from mt_metrics_eval import meta_info\n",
46 | "from mt_metrics_eval import data\n",
47 | "from mt_metrics_eval import tasks"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": null,
53 | "metadata": {
54 | "id": "6TEt4P13H2Iz"
55 | },
56 | "outputs": [],
57 | "source": [
58 | "# @title Download data\n",
59 | "\n",
60 | "data.Download() # Copies about 2G onto local machine."
61 | ]
62 | },
63 | {
64 | "cell_type": "markdown",
65 | "metadata": {
66 | "id": "7hXLFpZ9uiph"
67 | },
68 | "source": [
69 | "## Reproduce official results"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": null,
75 | "metadata": {
76 | "id": "Le9SHpFHqtmG"
77 | },
78 | "outputs": [],
79 | "source": [
80 | "# @title Generate main results\n",
81 | "\n",
82 | "# Generate main results for primary metrics.\n",
83 | "\n",
84 | "# Setting k=0 suppresses significance testing. Results in the paper were\n",
85 | "# generated with k=1000, which is too slow to run sequentially in a colab.\n",
86 | "main_tasks, main_task_weights = tasks.WMT23(k=0)\n",
87 | "\n",
88 | "# Task names show attributes that define each task.\n",
89 | "for i, task in enumerate(main_tasks):\n",
90 | " print(f'task{i + 1}: {task.name}')\n",
91 | "\n",
92 | "# Takes about 3 minutes.\n",
93 | "main_results = main_tasks.Run()"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": null,
99 | "metadata": {
100 | "id": "qBrVhfNlufDI"
101 | },
102 | "outputs": [],
103 | "source": [
104 | "# @title Display main results\n",
105 | "\n",
106 |     "# This reproduces Tables 8 and 9 from the shared task paper, modulo significance\n",
107 | "# results.\n",
108 | "\n",
109 | "# AverageCorrMatrix produces significance clusters and pairwise p-values for the\n",
110 | "# overall average correlation, but requires that the tasks be run with k \u003e 0.\n",
111 | "# AverageCorrs computes the same averages as AverageCorrMatrix but without\n",
112 | "# significance.\n",
113 | "avg_corrs = main_results.AverageCorrs(main_task_weights)\n",
114 | "# avg_corrs, matrix = main_results.AverageCorrMatrix(main_task_weights)\n",
115 | "\n",
116 | "# Use fmt='tsv' to generate tsv format for spreadsheets. This function has\n",
117 | "# many other options to customize output.\n",
118 | "table = main_results.Table(\n",
119 | " metrics=list(avg_corrs),\n",
120 | " initial_column=avg_corrs,\n",
121 | " initial_column_header='avg-corr',\n",
122 | " attr_list=['lang', 'level', 'corr_fcn'],\n",
123 | " nicknames={'KendallWithTiesOpt': 'acc-t'},\n",
124 | " fmt='text',\n",
125 | " baselines_metainfo=meta_info.WMT23)\n",
126 | "print(table)\n"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": null,
132 | "metadata": {
133 | "id": "EqR-lsYuI7tT"
134 | },
135 | "outputs": [],
136 | "source": [
137 | "# @title Generate full results\n",
138 | "\n",
139 | "# Identical to main results except we include contrastive metric submissions.\n",
140 | "\n",
141 | "main_tasks_full, _ = tasks.WMT23(k=0, primary=False)\n",
142 | "\n",
143 | "# Takes about 5 minutes.\n",
144 | "main_results_full = main_tasks_full.Run()"
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": null,
150 | "metadata": {
151 | "id": "mtLazjtdJwvk"
152 | },
153 | "outputs": [],
154 | "source": [
155 | "# @title Display full results.\n",
156 | "\n",
157 | "# This reproduces results from Tables 16 and 17 in the paper.\n",
158 | "\n",
159 | "avg_corrs = main_results_full.AverageCorrs(main_task_weights)\n",
160 | "\n",
161 | "# Leading *s indicate contrastive submissions, leading _s indicate baselines.\n",
162 | "table = main_results_full.Table(\n",
163 | " metrics=list(avg_corrs),\n",
164 | " initial_column=avg_corrs,\n",
165 | " initial_column_header='avg-corr',\n",
166 | " attr_list=['lang', 'level', 'corr_fcn'],\n",
167 | " nicknames={'KendallWithTiesOpt': 'acc-t'},\n",
168 | " fmt='text',\n",
169 | " which_metrics='union',\n",
170 | " baselines_metainfo=meta_info.WMT23)\n",
171 | "print(table)\n"
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": null,
177 | "metadata": {
178 | "id": "wHx-r4Is1xKn"
179 | },
180 | "outputs": [],
181 | "source": [
182 | "# @title Generate DA results\n",
183 | "\n",
184 | "# Results for all metrics using DA-SQM instead of MQM as gold scores.\n",
185 | "\n",
186 | "# DA scores are available for a wider set of languages than the ones used for\n",
187 | "# the main evaluation. Only en-de and zh-en are common to both.\n",
188 |     "da_lps = ['cs-uk', 'de-en', 'en-cs', 'en-de', 'en-ja', 'en-zh', 'ja-en', 'zh-en']\n",
189 | "da_tasks, da_wts = tasks.WMT23(k=0, primary=False, lps=da_lps, gold='da-sqm')\n",
190 | "\n",
191 | "for task in da_tasks:\n",
192 | " print(task.name)\n",
193 | "\n",
194 | "# Takes about 15 minutes.\n",
195 | "da_results = da_tasks.Run()"
196 | ]
197 | },
198 | {
199 | "cell_type": "code",
200 | "execution_count": null,
201 | "metadata": {
202 | "id": "L4DUuFd33Zfv"
203 | },
204 | "outputs": [],
205 | "source": [
206 | "# @title Display DA results\n",
207 | "\n",
208 | "# This reproduces results from tables 19 to 27 in the paper.\n",
209 | "\n",
210 | "avg_corrs = da_results.AverageCorrs(da_wts)\n",
211 | "all_da_lps = ','.join(sorted(da_lps))\n",
212 | "\n",
213 | "table = da_results.Table(\n",
214 | " metrics=list(avg_corrs),\n",
215 | " initial_column=avg_corrs,\n",
216 | " initial_column_header='avg-corr',\n",
217 | " attr_list=['lang', 'level', 'corr_fcn'],\n",
218 | " nicknames={'KendallWithTiesOpt': 'acc-t', all_da_lps: 'all'},\n",
219 | " fmt='text',\n",
220 | " which_metrics='union',\n",
221 | " baselines_metainfo=meta_info.WMT23)\n",
222 | "print(table)\n"
223 | ]
224 | },
225 | {
226 | "cell_type": "code",
227 | "execution_count": null,
228 | "metadata": {
229 | "id": "NRuoJY8f8gfd"
230 | },
231 | "outputs": [],
232 | "source": [
233 | "# @title Accuracy results, MQM vs DA\n",
234 | "\n",
235 | "# This reproduces results from table 14 in the paper. Note that the two columns\n",
236 | "# are not comparable because they are computed on different sets of languages\n",
237 | "# (in addition to using different gold scores).\n",
238 | "\n",
239 | "acc_mqm = main_results.SplitByAttr('corr_fcn')['accuracy']\n",
240 | "acc_da = da_results.SplitByAttr('corr_fcn')['accuracy']\n",
241 | "acc_mqm_vs_da = acc_mqm + acc_da\n",
242 | "\n",
243 | "table = acc_mqm_vs_da.Table(\n",
244 | " attr_list=['lang'],\n",
245 | " nicknames={all_da_lps: 'all-DA-lps'},\n",
246 | " rerank=[True, True],\n",
247 | " which_metrics='intersection',\n",
248 | " baselines_metainfo=meta_info.WMT23)\n",
249 | "print(table)"
250 | ]
251 | },
252 | {
253 | "cell_type": "markdown",
254 | "metadata": {
255 | "id": "4FAfC_WDg4sy"
256 | },
257 | "source": [
258 | "# Evaluate a new metric\n",
259 | "\n",
260 | "This section shows a worked example of evaluating a new metric online. Another\n",
261 | "possibility is to generate scores offline, write score files to disk, and use\n",
262 | "EvalSet.AddMetricsFromDir() to read them in."
263 | ]
264 | },
265 | {
266 | "cell_type": "code",
267 | "execution_count": null,
268 | "metadata": {
269 | "id": "xHA97g1hjKR9"
270 | },
271 | "outputs": [],
272 | "source": [
273 | "# @title Define the metric\n",
274 | "\n",
275 | "import numpy as np\n",
276 | "\n",
277 | "# Replace this function with your own metric.\n",
278 | "\n",
279 | "def NewMetric(\n",
280 | " level: str,\n",
281 | " lp: str,\n",
282 | " domains: dict[str, list[list[int]]],\n",
283 | " docs: dict[str, list[int]],\n",
284 | " src: list[str],\n",
285 | " ref: list[str],\n",
286 |     "    hyps: dict[str, list[str]]\n",
287 | ") -\u003e dict[str, list[float]]:\n",
288 | " \"\"\"\n",
289 | " Generate metric scores.\n",
290 | "\n",
291 | " Args:\n",
292 | " level: Level for which to produce scores, 'sys' or 'seg'.\n",
293 |     "    lp: Language pair, e.g. 'en-de'.\n",
294 | " domains: Map from domain name to [[beg, end+1], ...] segment position lists.\n",
295 | " docs: Map from doc name to [beg, end+1] segment positions.\n",
296 | " src: List of source segments.\n",
297 | " ref: List of reference segments.\n",
298 | " hyps: Map from MT system name to output segments for that system.\n",
299 | "\n",
300 | " Returns:\n",
301 | " Map from system name to scores, a list of segment-level scores if level is\n",
302 | " 'seg', or a list containing a single score if level is 'sys'.\n",
303 | " \"\"\"\n",
304 | " # Sample metric just computes a length match between each hypothesis and the\n",
305 | " # reference. It ignores lp, domains, docs, and source.\n",
306 | "\n",
307 | " del lp, domains, docs, src\n",
308 | "\n",
309 | " ref_lens = np.array([len(r) for r in ref])\n",
310 | " scores = {}\n",
311 | " for sysname, hyp in hyps.items():\n",
312 | " hyp_lens = np.array([len(h) for h in hyp])\n",
313 | " deltas = np.abs(ref_lens - hyp_lens) / (ref_lens + 1)\n",
314 | " scores[sysname] = -deltas if level == 'seg' else [-deltas.mean()]\n",
315 | "\n",
316 | " return scores"
317 | ]
318 | },
319 | {
320 | "cell_type": "code",
321 | "execution_count": null,
322 | "metadata": {
323 | "id": "cm1F1I3YVCGI"
324 | },
325 | "outputs": [],
326 | "source": [
327 | "# @title Load EvalSets\n",
328 | "\n",
329 | "wmt23_lps = ['en-de', 'he-en', 'zh-en']\n",
330 | "evs_dict = {('wmt23', lp): data.EvalSet('wmt23', lp, True) for lp in wmt23_lps}"
331 | ]
332 | },
333 | {
334 | "cell_type": "code",
335 | "execution_count": null,
336 | "metadata": {
337 | "id": "5qRH8Y-iMFH5"
338 | },
339 | "outputs": [],
340 | "source": [
341 | "# @title Add metric scores to EvalSets\n",
342 | "\n",
343 | "# Compute scores for each language pair, and add to the appropriate EvalSet.\n",
344 | "# Setting replace=True makes this work if we want to iterate over different\n",
345 | "# versions of the metric.\n",
346 | "\n",
347 | "metric_name = 'lendiff'\n",
348 | "\n",
349 | "for lp in wmt23_lps:\n",
350 | " evs = evs_dict[('wmt23', lp)]\n",
351 | " for refname, ref in evs.all_refs.items():\n",
352 | " sys_scores = NewMetric(\n",
353 | " 'sys', evs.lp, evs.domains, evs.docs, evs.src, ref, evs.sys_outputs)\n",
354 | " seg_scores = NewMetric(\n",
355 | " 'seg', evs.lp, evs.domains, evs.docs, evs.src, ref, evs.sys_outputs)\n",
356 | " evs.AddMetric(metric_name, {refname}, 'sys', sys_scores, replace=True)\n",
357 | " evs.AddMetric(metric_name, {refname}, 'seg', seg_scores, replace=True)\n",
358 | "\n",
359 | "# Add new metric to the primary lists, so it will get picked up when tasks get\n",
360 | "# run with primary=True (avoiding having to evaluate all contrastive\n",
361 | "# submissions as well).\n",
362 | "\n",
363 | "for evs in evs_dict.values():\n",
364 | " evs.SetPrimaryMetrics(evs.primary_metrics | {metric_name})"
365 | ]
366 | },
367 | {
368 | "cell_type": "code",
369 | "execution_count": null,
370 | "metadata": {
371 | "id": "mHtzjRQgXcs2"
372 | },
373 | "outputs": [],
374 | "source": [
375 | "# @title Generate results with new metric\n",
376 | "\n",
377 | "# For a first pass we turn off significance testing.\n",
378 | "\n",
379 | "wmt23_tasks, wts = tasks.WMT23(wmt23_lps, k=0)\n",
380 | "\n",
381 | "# Takes about 3 minutes.\n",
382 | "new_results = wmt23_tasks.Run(eval_set_dict=evs_dict)"
383 | ]
384 | },
385 | {
386 | "cell_type": "code",
387 | "execution_count": null,
388 | "metadata": {
389 | "id": "6FDMKPU4d97V"
390 | },
391 | "outputs": [],
392 | "source": [
393 | "# @title Print results\n",
394 | "\n",
395 | "# Results show all primary metrics, along with the new 'lendiff' metric.\n",
396 | "\n",
397 | "avg_corrs = new_results.AverageCorrs(wts)\n",
398 | "\n",
399 | "table = new_results.Table(\n",
400 | " metrics=list(avg_corrs),\n",
401 | " initial_column=avg_corrs,\n",
402 | " initial_column_header='avg-corr',\n",
403 | " attr_list=['lang', 'level', 'corr_fcn'],\n",
404 | " nicknames={'KendallWithTiesOpt': 'acc-t'},\n",
405 | " fmt='text',\n",
406 | " baselines_metainfo=meta_info.WMT23)\n",
407 | "\n",
408 | "print(table)\n"
409 | ]
410 | },
411 | {
412 | "cell_type": "code",
413 | "execution_count": null,
414 | "metadata": {
415 | "id": "q0elPs7kuXFO"
416 | },
417 | "outputs": [],
418 | "source": [
419 | "# @title Compare with significance\n",
420 | "\n",
421 | "# For speed reasons, limit comparison to the two metrics that bracket lendiff\n",
422 | "# in the average-correlation ranking.\n",
423 | "for evs in evs_dict.values():\n",
424 | " evs.SetPrimaryMetrics({'Random-sysname', 'lendiff', 'eBLEU'})\n",
425 | "\n",
426 | "# Run the significance test. Set k=1000 for a more realistic comparison. This\n",
427 | "# takes about 2 minutes with k=50.\n",
428 | "wmt23_tasks, wts = tasks.WMT23(wmt23_lps, k=50)\n",
429 | "new_results = wmt23_tasks.Run(eval_set_dict=evs_dict)\n"
430 | ]
431 | },
432 | {
433 | "cell_type": "code",
434 | "execution_count": null,
435 | "metadata": {
436 | "id": "OkFH9_xJwjOF"
437 | },
438 | "outputs": [],
439 | "source": [
440 | "# @title Print significance results\n",
441 | "\n",
442 |     "avg_corrs, matrix = new_results.AverageCorrMatrix(wts)\n",
443 | "\n",
444 | "table = new_results.Table(\n",
445 | " metrics=list(avg_corrs),\n",
446 | " initial_column=avg_corrs,\n",
447 | " initial_column_header='avg-corr',\n",
448 | " attr_list=['lang', 'level', 'corr_fcn'],\n",
449 | " nicknames={'KendallWithTiesOpt': 'acc-t'},\n",
450 | " fmt='text',\n",
451 | " baselines_metainfo=meta_info.WMT23)\n",
452 | "\n",
453 | "# The table indicates that lendiff and eBLEU are in the same significance\n",
454 | "# cluster ahead of Random-sysname.\n",
455 | "print(table)\n",
456 | "print()\n",
457 | "\n",
458 | "# Print the p-value matrix for the three pairwise comparisons used to assign\n",
459 | "# significance clusters.\n",
460 | "print(tasks.MatrixString(avg_corrs, matrix, probs=True))\n"
461 | ]
462 | }
463 | ],
464 | "metadata": {
465 | "colab": {
466 | "last_runtime": {
467 | "build_target": "//learning/grp/tools/ml_python:ml_notebook",
468 | "kind": "private"
469 | },
470 | "private_outputs": true,
471 | "provenance": [
472 | {
473 | "file_id": "1UgUZ35EdmwwuDljJMtlz5vAaOT4blX8J",
474 | "timestamp": 1699484321090
475 | }
476 | ],
477 | "toc_visible": true
478 | },
479 | "kernelspec": {
480 | "display_name": "Python 3",
481 | "name": "python3"
482 | },
483 | "language_info": {
484 | "name": "python"
485 | }
486 | },
487 | "nbformat": 4,
488 | "nbformat_minor": 0
489 | }
490 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
29 | """Setup script for mt-metrics-eval.
30 |
31 | This script will allow pip-installing as a Python module.
32 | """
33 |
34 | import setuptools
35 |
36 | with open("README.md", "r") as fh:
37 | long_description = fh.read()
38 |
39 | install_requires = [
40 | "apache_beam",
41 | "dacite",
42 | "numpy",
43 | "scipy",
44 | "absl-py",
45 | ]
46 |
47 | setuptools.setup(
48 | name="mt-metrics-eval",
49 | version="0.0.3",
50 | author="George Foster",
51 | description="Toolkit for evaluating Machine Translation metrics.",
52 | long_description=long_description,
53 | long_description_content_type="text/markdown",
54 | url="https://github.com/google-research/mt-metrics-eval",
55 | packages=setuptools.find_packages(),
56 | classifiers=[
57 | "Programming Language :: Python :: 3",
58 | "Operating System :: OS Independent",
59 | "Intended Audience :: Developers",
60 | "Intended Audience :: Education",
61 | "Intended Audience :: Science/Research",
62 | "License :: OSI Approved :: Apache Software License",
63 | "Programming Language :: Python :: 3",
64 | ],
65 | license="Apache 2.0",
66 | python_requires=">=3",
67 | install_requires=install_requires)
68 |
--------------------------------------------------------------------------------