├── .github
│   └── workflows
│       └── unittest.yml
├── .gitignore
├── LICENSE
├── README.md
├── experiments
│   ├── README.md
│   ├── bertsch-et-al-2023-mbr
│   │   ├── README.md
│   │   ├── results
│   │   │   └── figures
│   │   │       ├── CNN DM original.png
│   │   │       └── CNN DM reproduction.png
│   │   └── run_experiment.py
│   ├── chrf-vs-fastchrf
│   │   ├── README.md
│   │   └── run_experiment.py
│   ├── freitag-et-al-2023-epsilon
│   │   ├── README.md
│   │   ├── results
│   │   │   └── figures
│   │   │       ├── All Results DE–EN (original).png
│   │   │       ├── All Results DE–EN (reproduction).png
│   │   │       ├── All Results EN–DE (original).png
│   │   │       ├── All Results EN–DE (reproduction).png
│   │   │       ├── Main Comparison DE–EN (original).png
│   │   │       ├── Main Comparison DE–EN (reproduction).png
│   │   │       ├── Main Comparison EN–DE (original).png
│   │   │       └── Main Comparison EN–DE (reproduction).png
│   │   └── run_experiment.py
│   ├── müller-sennrich-2021-understanding
│   │   ├── README.md
│   │   ├── results
│   │   │   └── figures
│   │   │       ├── AZE–ENG (original).png
│   │   │       ├── AZE–ENG (reproduction).png
│   │   │       ├── BEL–RUS (original).png
│   │   │       ├── BEL–RUS (reproduction).png
│   │   │       ├── DAN–EPO (original).png
│   │   │       ├── DAN–EPO (reproduction).png
│   │   │       ├── DEU–FRA (original).png
│   │   │       └── DEU–FRA (reproduction).png
│   │   └── run_experiment.py
│   ├── reference_aggregation
│   │   ├── README.md
│   │   ├── baseline_beam_search.py
│   │   ├── baseline_epsilon_sampling.py
│   │   ├── experiment_utils.py
│   │   ├── fairseq_utils.py
│   │   ├── generate_samples.py
│   │   ├── mbr_utils.py
│   │   ├── plot_accuracy.py
│   │   ├── requirements.txt
│   │   ├── results
│   │   │   ├── chrf.log
│   │   │   ├── comet-xl.log
│   │   │   ├── comet22.log
│   │   │   └── cometinho.log
│   │   ├── run_mbr.py
│   │   ├── scripts
│   │   │   ├── benchmark_time.sh
│   │   │   ├── evaluate-bleurt.ipynb
│   │   │   ├── evaluate-chrf.sh
│   │   │   ├── evaluate-comet.sh
│   │   │   ├── plot_accuracy_reverse.py
│   │   │   ├── print_data_stats_table.py
│   │   │   └── save_src_and_ref.py
│   │   ├── tests
│   │   │   ├── __init__.py
│   │   │   ├── test_beam_search.py
│   │   │   ├── test_chrf.py
│   │   │   ├── test_comet.py
│   │   │   ├── test_epsilon_sampling.py
│   │   │   ├── test_generate_samples.py
│   │   │   ├── test_run_mbr.py
│   │   │   ├── test_save_src_and_ref.py
│   │   │   ├── test_testset.py
│   │   │   └── test_validation.py
│   │   └── validation.py
│   └── requirements.txt
├── minimum-bayes-risk-decoding.png
├── pyproject.toml
├── requirements-dev.txt
├── requirements-test.txt
├── src
│   └── mbr
│       ├── __init__.py
│       ├── generation
│       │   ├── __init__.py
│       │   ├── configuration_utils.py
│       │   └── utils.py
│       ├── metrics
│       │   ├── __init__.py
│       │   ├── base.py
│       │   ├── comet.py
│       │   └── fastchrf.py
│       └── modeling.py
└── tests
    ├── __init__.py
    ├── test_config.py
    ├── test_generate.py
    ├── test_metrics.py
    └── test_pipelines.py
/.github/workflows/unittest.yml:
--------------------------------------------------------------------------------
1 | name: unittest
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 | pull_request:
7 | branches: [ main ]
8 |
9 | jobs:
10 | build:
11 |
12 | runs-on: ubuntu-latest
13 | strategy:
14 | matrix:
15 | python-version: ["3.11"]
16 |
17 | steps:
18 | - uses: actions/checkout@v3
19 | - name: Set up Python ${{ matrix.python-version }}
20 | uses: actions/setup-python@v3
21 | with:
22 | python-version: ${{ matrix.python-version }}
23 | - name: Install dependencies
24 | run: |
25 | python -m pip install --upgrade pip
26 | pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
27 | pip install .
28 | pip install -r requirements-test.txt
29 | - name: Lint with flake8
30 | run: |
31 | pip install flake8
32 | # stop the build if there are Python syntax errors or undefined names
33 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude experiments
34 | - name: Test
35 | run: SKIP_SLOW_TESTS=True python -m unittest discover -s tests
36 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | .idea/
161 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # mbr 🔥
2 | [](https://github.com/ZurichNLP/mbr/actions/workflows/unittest.yml)
3 | [](https://pypi.python.org/pypi/mbr/)
4 |
5 | **mbr** adds Sampling-based Minimum Bayes Risk decoding to [Hugging Face transformers](https://github.com/huggingface/transformers). Originally proposed by [Eikema & Aziz (2022)](https://aclanthology.org/2022.emnlp-main.754/), MBR decoding is a risk-minimizing algorithm for generating text with a language model. This repository implements several optimizations for MBR decoding. Most notably, **mbr** introduces reference aggregation ([Vamvas & Sennrich, 2024](https://arxiv.org/abs/2402.04251)).
6 |
7 | Pronounce: _ember_ /ˈɛm.bɚ/
8 |
9 | ## Installation
10 |
11 | ```bash
12 | pip install mbr
13 | ```
14 |
15 | Requirements:
16 | - Python >= 3.9
17 | - PyTorch
18 | - Hugging Face transformers < 4.39
19 |
20 | ## Usage
21 | The main components of **mbr** are:
22 | - `mbr.MBRGenerationMixin`: overrides a model's `generate` method to add MBR decoding.
23 | - `mbr.MBRConfig`: specifies the parameters of MBR decoding, e.g., the number of samples to generate and the metric to optimize.
24 |
25 | ### 1. Load a Hugging Face transformers model
26 | Models need to inherit from `MBRGenerationMixin` for MBR decoding to work. Here are two ways to achieve this, using the Llama model as an example:
27 |
28 | **Variant A:**
29 |
30 | ```python
31 | from transformers import LlamaForCausalLM
32 |
33 | from mbr import MBRGenerationMixin
34 |
35 | class MBRLlamaForCausalLM(MBRGenerationMixin, LlamaForCausalLM):
36 | pass
37 | ```
38 |
39 | Then, you can use `MBRLlamaForCausalLM` as you would use `LlamaForCausalLM`:
40 |
41 | ```python
42 | model = MBRLlamaForCausalLM.from_pretrained(...)
43 | ```
44 |
45 | **Variant B:**
46 |
47 | ```python
48 | from mbr import MBR
49 | model = MBR(LlamaForCausalLM).from_pretrained(...)
50 | ```
51 |
52 | ### 2. Configure MBR decoding
53 |
54 | Create an `MBRConfig` object to pass to the model's `generate` method:
55 |
56 | ```python
57 | from mbr import MBRConfig
58 |
59 | mbr_config = MBRConfig(
60 | num_samples=10,
61 | metric="chrf",
62 | )
63 | ```
64 |
65 | ### 3. Generate text as usual
66 | Call the model's `generate` method directly, or use the Pipeline API. Make sure to pass the `mbr_config`, as well as the model's tokenizer.
67 |
68 | ```python
69 | from transformers import pipeline
70 |
71 | generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
72 | output = generator("Hello,", mbr_config=mbr_config, tokenizer=tokenizer)
73 | ```
74 |
75 | ## How MBR decoding works
76 | The following research papers, among many others, provide a description of Sampling-based Minimum Bayes Risk decoding:
77 | - [Sampling-Based Approximations to Minimum Bayes Risk Decoding for Neural Machine Translation](https://aclanthology.org/2022.emnlp-main.754) (Eikema & Aziz, EMNLP 2022)
78 | - [Understanding the Properties of Minimum Bayes Risk Decoding in Neural Machine Translation](https://aclanthology.org/2021.acl-long.22) (Müller & Sennrich, ACL-IJCNLP 2021)
79 |
80 | In practice, MBR decoding is most commonly implemented as follows (on the example of machine translation):
81 |
82 | - Instead of searching for the single most probable output sequence (e.g., using beam search), generate a number of samples.
83 | - Score each sample against the other samples using a metric (e.g., BLEU).
84 | - Return the sample with the highest score. Intuitively, this can be seen as returning the median of all samples.
85 |
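Expressed as code, the selection step looks roughly like this (a minimal, library-agnostic sketch for illustration; `utility` stands in for any pairwise metric such as ChrF or COMET and is not part of the **mbr** API):

```python
from typing import Callable, List

def mbr_select(samples: List[str], references: List[str],
               utility: Callable[[str, str], float]) -> str:
    """Return the sample with the highest expected utility over the references."""
    best_sample, best_score = samples[0], float("-inf")
    for sample in samples:
        # Expected utility: average score of this sample against all (pseudo-)references.
        score = sum(utility(sample, ref) for ref in references) / len(references)
        if score > best_score:
            best_sample, best_score = sample, score
    return best_sample
```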
86 |
87 |
88 |
89 | The terminology around MBR decoding varies:
90 |
91 | | Term used in this codebase | Alternative terms |
92 | |:---------------------------|:-----------------------------------------------------|
93 | | samples | candidates, hypotheses |
94 | | references | pseudo-references, evidence |
95 | | metric score               | expected utility<br>(negative) expected risk, error  |
96 |
97 | ## Details
98 |
99 | ### Configuring the sampling
100 | The generation of the samples can be customized by passing a `generation_config` to the `generate` method or to the pipeline call:
101 |
102 | ```python
103 | from transformers import GenerationConfig
104 |
105 | generation_config = GenerationConfig.from_pretrained("mymodel",
106 | do_sample=True,
107 | num_beams=1,
108 | epsilon_cutoff=0.02,
109 | )
110 | model.generate(..., generation_config=generation_config)
111 | ```
112 |
113 | ### Separate set of references
114 | By default, the samples themselves are used as references (or a subset of the samples if `num_references` is smaller than `num_samples`).
115 |
116 | You could also sample the reference set independently, using a custom generation config for the references:
117 |
118 | ```python
119 | from transformers import GenerationConfig
120 |
121 | references_config = GenerationConfig.from_pretrained("mymodel",
122 | do_sample=True,
123 | num_beams=1,
124 | top_p=0.9,
125 | )
126 | model.generate(..., references_config=references_config)
127 | ```
128 |
129 | ### Choosing a metric
130 | By default, **mbr** uses [fastChrF](https://github.com/jvamvas/fastChrF), which is optimized for efficient comparison of many samples to many references.
131 |
132 | You can also plug in metrics from the [**Hugging Face Evaluate**](https://github.com/huggingface/evaluate) library.
133 |
134 | A full list of metrics is found [here](https://huggingface.co/metrics). Some typical choices are:
135 | - [COMET](https://huggingface.co/spaces/evaluate-metric/comet) ([Rei et al., 2020](https://aclanthology.org/2020.emnlp-main.213/))
136 | - [BLEURT](https://huggingface.co/spaces/evaluate-metric/bleurt) ([Sellam et al., 2020](https://aclanthology.org/2020.acl-main.704))
137 |
138 | To use a metric from Hugging Face, either specify the metric's name (e.g., `"comet"`, `"bleurt"`) or pass an `evaluate.Metric` object directly.
139 |
140 | Since different metrics return differently structured dicts, use `metric_output_field` to specify which field of the output should be used as the metric score.
141 |
142 | ```python
143 | from evaluate import load
144 |
145 | metric = load('bleu')
146 | mbr_config = MBRConfig(
147 | metric=metric,
148 | metric_output_field="bleu", # the BLEU metric returns a dict with a "bleu" field
149 | ...
150 | )
151 | ```
152 |
153 | ### Customizing the metric computation
154 | Internally, **mbr** will call the metric's `compute` method to calculate the metric score for each sample.
155 |
156 | By default, **mbr** will call `compute` separately for each sample–reference pair.
157 | Since this requires many `compute` calls, it can make sense to optimize the metric computation. Different metrics will require different optimization strategies.
158 | To override the default way of calling the metric, define a `MetricRunner` class and pass it to the `generate` method:
159 |
160 | ```python
161 | from typing import Tuple
162 | import torch
163 | from mbr import MetricRunner
164 | class MyMetricRunner(MetricRunner):
165 | def __call__(self,
166 | input_ids: torch.LongTensor,
167 | sample_ids: Tuple[torch.LongTensor],
168 | reference_ids: Tuple[torch.LongTensor],
169 | ) -> torch.FloatTensor:
170 | ... # TODO: implement your efficient metric computation here
171 |
172 | model.generate(..., metric_runner=MyMetricRunner())
173 | ```
174 |
175 | For **COMET**, an optimized implementation is already provided in `CometMetricRunner`:
176 |
177 | ```python
178 | from mbr.metrics.comet import CometMetricRunner
179 |
180 | mbr_config = MBRConfig(
181 | ...,
182 | metric="comet",
183 | metric_output_field="mean_score",
184 | )
185 |
186 | metric_runner = CometMetricRunner(mbr_config, tokenizer)
187 | model.generate(..., metric_runner=metric_runner)
188 | ```
189 |
190 | ### Optimizations
191 | MBR decoding is notoriously slow. **mbr** implements some optimizations:
192 | - Cached encoder outputs: For encoder-decoder models, the encoder outputs are computed only once and reused during sampling.
193 | - Optimized ChrF metric: [fastChrF](https://github.com/jvamvas/fastChrF) is used by default, which is a streamlined ChrF variant for MBR, implemented in Rust.
194 | - Cached metrics: Most metrics are computed only once for each unique sample–reference pair (since there will be duplicate samples and references).
195 | - Optimized COMET metric: Inspired by [Amrhein & Sennrich (2022)](https://aclanthology.org/2022.aacl-main.83/), `CometMetricRunner` caches sequence embeddings and reuses them for all pairwise comparisons.
196 | - Reference aggregation for COMET ([Vamvas & Sennrich, 2024](https://arxiv.org/abs/2402.04251)): Consider using `mbr.metrics.comet.AggregateCometMetricRunner` instead of the default `CometMetricRunner` if you have many references.
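
A usage sketch for reference aggregation with COMET (this assumes `AggregateCometMetricRunner` takes the same constructor arguments as `CometMetricRunner` above, and that `model` and `tokenizer` are already loaded):

```python
from mbr import MBRConfig
from mbr.metrics.comet import AggregateCometMetricRunner

mbr_config = MBRConfig(
    num_samples=256,
    num_references=256,
    metric="comet",
    metric_output_field="mean_score",
)
# Assumed to mirror the CometMetricRunner constructor shown above.
metric_runner = AggregateCometMetricRunner(mbr_config, tokenizer)
model.generate(..., mbr_config=mbr_config, metric_runner=metric_runner)
```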
197 |
198 | ## Example scripts
199 |
200 | The [experiments](experiments) directory contains the code for reproductions of experiments from the following papers:
201 |
202 | - [MBR for (low-resource) machine translation](experiments/müller-sennrich-2021-understanding) ([Müller & Sennrich, 2021](https://aclanthology.org/2021.acl-long.22/))
203 | - [MBR with neural metrics and epsilon sampling for machine translation](experiments/freitag-et-al-2023-epsilon) ([Freitag et al., 2023](https://arxiv.org/abs/2305.09860))
204 | - [MBR for summarization](experiments/bertsch-et-al-2023-mbr) ([Bertsch et al., 2023](https://arxiv.org/abs/2310.01387))
205 |
206 | ### Code for research papers
207 | - [Code for the paper "Linear-time Minimum Bayes Risk Decoding with Reference Aggregation" (Vamvas & Sennrich, 2024)](experiments/reference_aggregation)
208 |
209 | ## Related projects
210 | - https://github.com/roxot/mbr-nmt: Original implementation ([demo](https://colab.research.google.com/github/probabll/demo-mbr-nmt/blob/main/German-English.ipynb))
211 | - https://github.com/ZurichNLP/understanding-mbr: MBR with Sockeye
212 | - https://github.com/ZurichNLP/mbr-sensitivity and https://github.com/Unbabel/COMET#minimum-bayes-risk-decoding: COMET metric for MBR
213 | - https://github.com/rainavyas/mbr_gec: MBR for Grammatical Error Correction
214 |
215 | ## Changelog
216 |
217 | - v0.3.0 (draft)
218 | - New feature: Reference Aggregation ([Vamvas & Sennrich, 2024](https://arxiv.org/abs/2402.04251)):
219 | - Set [fastChrF](https://github.com/jvamvas/fastChrF) with reference aggregation as default metric
220 | - Add `AggregateCometMetricRunner` to allow for reference aggregation with COMET
221 | - **Bugfix**: Disable dropout for COMET metric
222 |
223 | - v0.2.0
224 | - **Breaking change:** Rename `MBRGenerationConfig` to `MBRConfig`
225 | - **Breaking change:** `MetricRunner` now returns a `MetricOutput` dict instead of the raw tensor of scores.
226 | - Make the size of the metric cache configurable via `MBRConfig.metric_cache_size`
227 | - Allow the number of references to be larger than the number of samples (if the references are generated separately from the samples).
228 | - Remove `GenerationConfig` as parent class of `MBRConfig`
229 |
230 | ## Citation
231 | When using this code for research, please cite the following paper:
232 |
233 | ```bibtex
234 | @misc{vamvas-sennrich-2024-linear,
235 | title={Linear-time Minimum Bayes Risk Decoding with Reference Aggregation},
236 | author={Jannis Vamvas and Rico Sennrich},
237 | year={2024},
238 | eprint={2402.04251},
239 | archivePrefix={arXiv},
240 | primaryClass={cs.CL}
241 | }
242 | ```
243 |
--------------------------------------------------------------------------------
/experiments/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Experiments using the [**mbr**](https://github.com/ZurichNLP/mbr) package
3 |
4 | **Code for research papers:**
5 | - Linear-time Minimum Bayes Risk Decoding with Reference Aggregation (Vamvas & Sennrich, 2024)
6 |
7 | **Reproductions of experiments from related research papers:**
8 |
9 | - It's MBR All the Way Down: Modern Generation Techniques Through the Lens of Minimum Bayes Risk (Bertsch et al., 2023)
10 | - Epsilon Sampling Rocks: Investigating Sampling Strategies for Minimum Bayes Risk Decoding for Machine Translation (Freitag et al., 2023)
11 | - Understanding the Properties of Minimum Bayes Risk Decoding in Neural Machine Translation (Müller & Sennrich, ACL-IJCNLP 2021)
12 |
13 | **Other experiments:**
14 | - Comparison of [fastChrF](https://github.com/jvamvas/fastChrF) to standard sentence-level ChrF ([Popović, 2015](https://aclanthology.org/W15-3049/)) as a metric for MBR.
15 |
--------------------------------------------------------------------------------
/experiments/bertsch-et-al-2023-mbr/README.md:
--------------------------------------------------------------------------------
1 | This directory uses the [**mbr**](https://github.com/ZurichNLP/mbr) package to reproduce an experiment from the paper [It's MBR All the Way Down: Modern Generation Techniques Through the Lens of Minimum Bayes Risk](https://arxiv.org/abs/2310.01387) (Bertsch et al., 2023).
2 |
3 | ## Setup
4 | * Task: Summarization
5 | * Language: English
6 | * Model: facebook/bart-large-cnn ([Lewis et al., 2020](https://aclanthology.org/2020.acl-main.703/))
7 | * MBR metric: ROUGE-1 ([Lin, 2004](https://aclanthology.org/W04-1013/))
8 | * Number of samples: 30
9 | * Number of references: 30
10 | * Sampling approach: sampling with temperature 0.5
11 | * Reference sampling approach: sampling with temperature 1.0
12 | * Test set: CNN/DailyMail ([Nallapati et al., 2016](https://aclanthology.org/K16-1028/))
13 | * Evaluation metric: ROUGE-1
14 | * Baselines: greedy decoding, beam search
15 |
16 | ## Results
17 |
18 | | Paper | Reproduction |
19 | |:---------------------------------------------------------------:|:---:|
20 | |  |  |
21 |
--------------------------------------------------------------------------------
/experiments/bertsch-et-al-2023-mbr/results/figures/CNN DM original.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/bertsch-et-al-2023-mbr/results/figures/CNN DM original.png
--------------------------------------------------------------------------------
/experiments/bertsch-et-al-2023-mbr/results/figures/CNN DM reproduction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/bertsch-et-al-2023-mbr/results/figures/CNN DM reproduction.png
--------------------------------------------------------------------------------
/experiments/bertsch-et-al-2023-mbr/run_experiment.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | from pathlib import Path
3 |
4 | import evaluate
5 | import jsonlines
6 | import torch
7 | from datasets import load_dataset
8 | from tqdm import tqdm
9 | from transformers import BartForConditionalGeneration, AutoTokenizer, pipeline, GenerationConfig
10 |
11 | from mbr import MBR, MBRConfig
12 |
13 | results_file = jsonlines.open(Path(__file__).parent / "results.jsonl", "w")
14 |
15 | model_name = "facebook/bart-large-cnn"
16 | model = MBR(BartForConditionalGeneration).from_pretrained(model_name)
17 | tokenizer = AutoTokenizer.from_pretrained(model_name)
18 | summarization_pipeline = pipeline(
19 | "summarization",
20 | model=model,
21 | tokenizer=tokenizer,
22 | device=(0 if torch.cuda.is_available() else -1),
23 | )
24 | evaluation_metric_rouge = evaluate.load("rouge")
25 |
26 | dataset = load_dataset("cnn_dailymail", "3.0.0", split="test")
27 |
28 | # MBR
29 | mbr_config = MBRConfig()
30 | mbr_config.num_samples = 30
31 | mbr_config.num_references = 30
32 | mbr_config.metric = "rouge"
33 | mbr_config.metric_output_field = "rouge1"
34 | # efficiency settings
35 | mbr_config.metric_kwargs = {"rouge_types": ("rouge1",), "use_aggregator": False}
36 |
37 | generation_config = GenerationConfig.from_pretrained(model_name)
38 | generation_config.do_sample = True
39 | generation_config.num_beams = 1
40 | generation_config.temperature = 0.5
41 | generation_config.early_stopping = False
42 | generation_config.length_penalty = 1.0
43 |
44 | references_config = GenerationConfig.from_pretrained(model_name)
45 | references_config.do_sample = True
46 | references_config.num_beams = 1
47 | references_config.temperature = 1.0
48 | references_config.early_stopping = False
49 | references_config.length_penalty = 1.0
50 |
51 | summaries = []
52 | outputs = summarization_pipeline(
53 | dataset["article"],
54 | mbr_config=mbr_config,
55 | generation_config=generation_config,
56 | references_config=references_config,
57 | tokenizer=tokenizer,
58 | truncation=True,
59 | progress_bar=True,
60 | batch_size=32,
61 | )
62 | for output in outputs:
63 | summaries.append(output["summary_text"])
64 | rouge_score = evaluation_metric_rouge.compute(predictions=summaries, references=dataset["highlights"])
65 | results_file.write({
66 | "method": "mbr rouge-1",
67 | "rouge": rouge_score,
68 | "summaries": summaries,
69 | })
70 |
71 | # Baselines
72 | model = BartForConditionalGeneration.from_pretrained(model_name).to(summarization_pipeline.device)
73 | summarization_pipeline.model = model
74 | base_generation_config = GenerationConfig.from_pretrained(model_name)
75 | generation_configs = {}
76 |
77 | # greedy
78 | generation_config = deepcopy(base_generation_config)
79 | generation_config.do_sample = False
80 | generation_config.num_beams = 1
81 | generation_config.early_stopping = False
82 | generation_config.length_penalty = 1.0
83 | generation_configs["greedy"] = generation_config
84 |
85 | # beam search k=5
86 | generation_config = deepcopy(base_generation_config)
87 | generation_config.do_sample = False
88 | generation_config.num_beams = 5
89 | generation_configs["beam search k=5"] = generation_config
90 |
91 | # beam search k=10
92 | generation_config = deepcopy(base_generation_config)
93 | generation_config.do_sample = False
94 | generation_config.num_beams = 10
95 | generation_configs["beam search k=10"] = generation_config
96 |
97 | for method, generation_config in generation_configs.items():
98 | print(method, flush=True)
99 | summaries = []
100 | outputs = summarization_pipeline(
101 | dataset["article"],
102 | generation_config=generation_config,
103 | truncation=True,
104 | batch_size=1,
105 | )
106 | for output in tqdm(outputs):
107 | summaries.append(output["summary_text"])
108 | rouge_score = evaluation_metric_rouge.compute(predictions=summaries, references=dataset["highlights"])
109 | results_file.write({
110 | "method": method,
111 | "rouge": rouge_score,
112 | "summaries": summaries,
113 | })
114 |
115 | results_file.close()
116 |
--------------------------------------------------------------------------------
/experiments/chrf-vs-fastchrf/README.md:
--------------------------------------------------------------------------------
1 | Comparison of [fastChrF](https://github.com/jvamvas/fastChrF) to standard sentence-level ChrF ([Popović, 2015](https://aclanthology.org/W15-3049/)) as a metric for MBR.
2 |
3 | ## Setup
4 | * Task: Machine translation
5 | * Translation directions: en–de, de–en, en–ru, ru–en
6 | * Model: [facebook/wmt19-*](https://huggingface.co/facebook/wmt19-en-de) ([Ng et al., 2019](https://aclanthology.org/W19-5333/)).
7 | * MBR metrics: `fastchrf.pairwise_chrf` (a fast implementation of standard ChrF) and `fastchrf.aggregate_chrf` (a streamlined ChrF variant for MBR)
8 | * Number of samples: 256
9 | * Sampling approach: epsilon sampling with ε=0.02
10 | * Samples and references are the same
11 | * Test set: newstest2019
12 | * Evaluation metrics: chrF ([sacreBLEU](https://github.com/mjpost/sacrebleu)) and COMET-22 ([Rei et al., 2022](https://aclanthology.org/2022.wmt-1.52/))
13 | * Baseline: beam search with beam size 4
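
In terms of the **mbr** API, the two MBR variants only differ in the `metric` field of the `MBRConfig`. A minimal sketch of the configuration used in `run_experiment.py` below:

```python
from mbr import MBRConfig

mbr_config = MBRConfig(num_samples=256, num_references=256)
# "fastchrf-pairwise": standard sentence-level ChrF computed for every sample–reference pair.
# "fastchrf-aggregate": streamlined ChrF variant that scores each sample against aggregated references.
mbr_config.metric = "fastchrf-pairwise"  # or "fastchrf-aggregate"
```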
14 |
15 | ## Results
16 | | Language Pair | Method | ChrF | COMET | duration (s) |
17 | |---------------|--------------------------------------|---------:|----------:|-------------:|
18 | | en-de | MBR with `fastchrf.pairwise_chrf` | 67.7 | 0.867 | 7798 |
19 | | en-de | MBR with `fastchrf.aggregate_chrf` | 67.7 | 0.867 | 7480 |
20 | | en-de | Beam search | 67.7 | 0.868 | 62 |
21 | | de-en | MBR with `fastchrf.pairwise_chrf` | 65.4 | 0.851 | 6894 |
22 | | de-en | MBR with `fastchrf.aggregate_chrf` | 65.6 | 0.850 | 6849 |
23 | | de-en | Beam search | 65.1 | 0.851 | 53 |
24 | | en-ru | MBR with `fastchrf.pairwise_chrf` | 57.5 | 0.862 | 7802 |
25 | | en-ru | MBR with `fastchrf.aggregate_chrf` | 57.5 | 0.862 | 7465 |
26 | | en-ru | Beam search | 56.9 | 0.863 | 64 |
27 | | ru-en | MBR with `fastchrf.pairwise_chrf` | 64.2 | 0.847 | 7541 |
28 | | ru-en | MBR with `fastchrf.aggregate_chrf` | 64.3 | 0.848 | 6689 |
29 | | ru-en | Beam search | 63.5 | 0.847 | 61 |
30 | | **Average** | **MBR with `fastchrf.pairwise_chrf`** | **63.7** | **0.857** | **7509** |
31 | | **Average** | **MBR with `fastchrf.aggregate_chrf`** | **63.7** | **0.857** | **7121** |
32 | | **Average** | **Beam search** | **63.3** | **0.857** | **60** |
--------------------------------------------------------------------------------
/experiments/chrf-vs-fastchrf/run_experiment.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import time
3 | from copy import deepcopy
4 | from pathlib import Path
5 |
6 | import evaluate
7 | import jsonlines
8 | import sacrebleu
9 | import torch
10 | from datasets import load_dataset
11 | from tqdm import tqdm
12 | from transformers import FSMTForConditionalGeneration, AutoTokenizer, pipeline, set_seed, GenerationConfig
13 |
14 | from mbr import MBR, MBRConfig
15 |
16 | language_pair = sys.argv[1]
17 | assert language_pair in ["de-en", "en-de", "en-ru", "ru-en"]
18 |
19 | batch_size = 32
20 |
21 | results_file = jsonlines.open(Path(__file__).parent / f"results_{language_pair}.jsonl", "w")
22 |
23 | model_name = f"facebook/wmt19-{language_pair}"
24 | model = MBR(FSMTForConditionalGeneration).from_pretrained(model_name)
25 | tokenizer = AutoTokenizer.from_pretrained(model_name)
26 | mt_pipeline = pipeline(
27 | "translation_" + language_pair.split("-")[0] + "_to_" + language_pair.split("-")[1],
28 | model=model,
29 | tokenizer=tokenizer,
30 | device=(0 if torch.cuda.is_available() else -1),
31 | )
32 | evaluation_metric_chrf = evaluate.load("chrf")
33 | evaluation_metric_comet = evaluate.load("comet", "Unbabel/wmt22-comet-da")
34 |
35 | src_path = sacrebleu.get_source_file("wmt19", language_pair)
36 | ref_path = sacrebleu.get_reference_files("wmt19", language_pair)[0]
37 | dataset = load_dataset("text", data_files={"test": src_path})
38 | references = Path(ref_path).read_text().splitlines()
39 | assert len(dataset["test"]) == len(references)
40 |
41 | # MBR
42 | generation_config = GenerationConfig.from_pretrained(model_name)
43 | generation_config.do_sample = True
44 | generation_config.num_beams = 1
45 | generation_config.early_stopping = False
46 | generation_config.epsilon_cutoff = 0.02
47 |
48 | base_mbr_config = MBRConfig(
49 | num_samples=256,
50 | num_references=256,
51 | )
52 | base_mbr_config.metric_cache_size = batch_size * base_mbr_config.num_samples * base_mbr_config.num_references
53 | mbr_configs = {}
54 |
55 | # MBR with fastchrf.pairwise_chrf
56 | mbr_config = deepcopy(base_mbr_config)
57 | mbr_config.metric = "fastchrf-pairwise"
58 | mbr_configs["MBR with fastchrf.pairwise_chrf"] = mbr_config
59 |
60 | # MBR with fastchrf.aggregate_chrf
61 | mbr_config = deepcopy(base_mbr_config)
62 | mbr_config.metric = "fastchrf-aggregate"
63 | mbr_configs["MBR with fastchrf.aggregate_chrf"] = mbr_config
64 |
65 | for method, mbr_config in mbr_configs.items():
66 |
67 | set_seed(42)
68 | time_start = time.time()
69 | outputs = mt_pipeline(
70 | dataset["test"]["text"],
71 | mbr_config=mbr_config,
72 | generation_config=generation_config,
73 | tokenizer=tokenizer,
74 | batch_size=batch_size,
75 | progress_bar=True
76 | )
77 | translations = []
78 | for batch in tqdm(outputs):
79 | if isinstance(batch, dict):
80 | batch = [batch]
81 | translations += [translation["translation_text"] for translation in batch]
82 | time_end = time.time()
83 |
84 | chrf_score = evaluation_metric_chrf.compute(
85 | predictions=translations,
86 | references=references,
87 | )
88 | comet_score = evaluation_metric_comet.compute(
89 | predictions=translations,
90 | references=references,
91 | sources=dataset["test"]["text"],
92 | gpus=0,
93 | )
94 | results_file.write({
95 | "language_pair": language_pair,
96 | "method": method,
97 | "chrf": chrf_score["score"],
98 | "comet22": comet_score["mean_score"],
99 | "duration": time_end - time_start,
100 | "translations": translations,
101 | })
102 |
103 | # Beam search
104 | model = FSMTForConditionalGeneration.from_pretrained(model_name).half().to(mt_pipeline.device)
105 | mt_pipeline.model = model
106 | generation_config = GenerationConfig.from_pretrained(model_name)
107 | generation_config.num_beams = 4
108 |
109 | set_seed(42)
110 | time_start = time.time()
111 | outputs = mt_pipeline(
112 | dataset["test"]["text"],
113 | generation_config=generation_config,
114 | batch_size=batch_size,
115 | )
116 | translations = []
117 | for batch in tqdm(outputs):
118 | if isinstance(batch, dict):
119 | batch = [batch]
120 | translations += [translation["translation_text"] for translation in batch]
121 | time_end = time.time()
122 |
123 | chrf_score = evaluation_metric_chrf.compute(
124 | predictions=translations,
125 | references=references,
126 | )
127 | comet_score = evaluation_metric_comet.compute(
128 | predictions=translations,
129 | references=references,
130 | sources=dataset["test"]["text"],
131 | gpus=0,
132 | )
133 | results_file.write({
134 | "language_pair": language_pair,
135 | "method": f"beam search (beam size {generation_config.num_beams})",
136 | "chrf": chrf_score["score"],
137 | "comet22": comet_score["mean_score"],
138 | "duration": time_end - time_start,
139 | "translations": translations,
140 | })
141 |
142 | results_file.close()
143 |
--------------------------------------------------------------------------------
/experiments/freitag-et-al-2023-epsilon/README.md:
--------------------------------------------------------------------------------
1 | This directory uses the [**mbr**](https://github.com/ZurichNLP/mbr) package to reproduce an experiment from the paper [Epsilon Sampling Rocks: Investigating Sampling Strategies for Minimum Bayes Risk Decoding for Machine Translation](https://arxiv.org/abs/2305.09860) (Freitag et al., 2023).
2 |
3 | ## Setup
4 | * Task: Machine translation
5 | * Translation directions: en–de, de–en
6 | * MBR metric: Neural metric (paper: BLEURT, this reproduction: Cometinho)
7 | * Number of samples: 1024
8 | * Various sampling approaches
9 | * Samples and references are the same
10 | * Test set: newstest2021
11 | * Evaluation metric: COMET ([Rei et al., 2020](https://aclanthology.org/2020.emnlp-main.213/))
12 | * Baseline: beam search with beam size 4
13 |
14 | ## Differences to the paper
15 | * The paper used custom models trained without label smoothing; this reproduction uses an open-source model ([Ng et al., WMT 2019](https://aclanthology.org/W19-5333/)).
16 | * The paper used BLEURT ([Sellam et al., 2020](https://aclanthology.org/2020.acl-main.704/)) as the MBR metric; this reproduction uses Cometinho ([Rei et al., 2022](https://aclanthology.org/2022.eamt-1.9/)).
17 |
18 | ## Results
19 |
20 | Comparison between ancestral sampling and epsilon sampling:
21 |
22 | | Paper | Reproduction |
23 | |:---------------------------------------------------------------:|:---:|
24 | | .png) | .png) |
25 | | .png) | .png) |
26 |
27 | Comparison between beam search and various sampling approaches:
28 |
29 | | Paper | Reproduction |
30 | |:---------------------------------------------------------------:|:---:|
31 | | .png) | .png) |
32 | | .png) | .png) |
33 |
34 | Although the models used in this reproduction seem less suitable for sampling, especially at higher temperatures, the main comparison between ancestral sampling and epsilon sampling shows the same trend as in the paper.
35 |
--------------------------------------------------------------------------------
/experiments/freitag-et-al-2023-epsilon/results/figures/All Results DE–EN (original).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/freitag-et-al-2023-epsilon/results/figures/All Results DE–EN (original).png
--------------------------------------------------------------------------------
/experiments/freitag-et-al-2023-epsilon/results/figures/All Results DE–EN (reproduction).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/freitag-et-al-2023-epsilon/results/figures/All Results DE–EN (reproduction).png
--------------------------------------------------------------------------------
/experiments/freitag-et-al-2023-epsilon/results/figures/All Results EN–DE (original).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/freitag-et-al-2023-epsilon/results/figures/All Results EN–DE (original).png
--------------------------------------------------------------------------------
/experiments/freitag-et-al-2023-epsilon/results/figures/All Results EN–DE (reproduction).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/freitag-et-al-2023-epsilon/results/figures/All Results EN–DE (reproduction).png
--------------------------------------------------------------------------------
/experiments/freitag-et-al-2023-epsilon/results/figures/Main Comparison DE–EN (original).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/freitag-et-al-2023-epsilon/results/figures/Main Comparison DE–EN (original).png
--------------------------------------------------------------------------------
/experiments/freitag-et-al-2023-epsilon/results/figures/Main Comparison DE–EN (reproduction).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/freitag-et-al-2023-epsilon/results/figures/Main Comparison DE–EN (reproduction).png
--------------------------------------------------------------------------------
/experiments/freitag-et-al-2023-epsilon/results/figures/Main Comparison EN–DE (original).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/freitag-et-al-2023-epsilon/results/figures/Main Comparison EN–DE (original).png
--------------------------------------------------------------------------------
/experiments/freitag-et-al-2023-epsilon/results/figures/Main Comparison EN–DE (reproduction).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/freitag-et-al-2023-epsilon/results/figures/Main Comparison EN–DE (reproduction).png
--------------------------------------------------------------------------------
/experiments/freitag-et-al-2023-epsilon/run_experiment.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from copy import deepcopy
3 | from pathlib import Path
4 |
5 | import evaluate
6 | import jsonlines
7 | import sacrebleu
8 | import torch
9 | from datasets import load_dataset
10 | from tqdm import tqdm
11 | from transformers import FSMTForConditionalGeneration, AutoTokenizer, pipeline, set_seed, GenerationConfig
12 |
13 | from mbr import MBR, MBRConfig
14 | from mbr.metrics.comet import CometMetricRunner
15 |
16 | set_seed(42)
17 |
18 | language_pair = sys.argv[1]
19 |
20 | models = {
21 | "en-de": "facebook/wmt19-en-de",
22 | "en-zh": None,
23 | "de-en": "facebook/wmt19-de-en",
24 | "zh-en": None,
25 | }
26 |
27 | testsets = {
28 | "en-de": "wmt21/C",
29 | "en-zh": "wmt21/B",
30 | "de-en": "wmt21",
31 | "zh-en": "wmt21",
32 | }
33 |
34 | results_file = jsonlines.open(Path(__file__).parent / f"results_{language_pair}3.jsonl", "w")
35 |
36 | model_name = models[language_pair]
37 | model = MBR(FSMTForConditionalGeneration).from_pretrained(model_name)
38 | tokenizer = AutoTokenizer.from_pretrained(model_name)
39 | mt_pipeline = pipeline(
40 | "translation_" + language_pair.split("-")[0] + "_to_" + language_pair.split("-")[1],
41 | model=model,
42 | tokenizer=tokenizer,
43 | device=(0 if torch.cuda.is_available() else -1),
44 | )
45 | evaluation_metric_comet = evaluate.load("comet", "wmt20-comet-da")
46 |
47 | src_path = sacrebleu.get_source_file(testsets[language_pair], language_pair)
48 | ref_path = sacrebleu.get_reference_files(testsets[language_pair], language_pair)[0]
49 | dataset = load_dataset("text", data_files={"test": src_path})
50 | references = Path(ref_path).read_text().splitlines()
51 | assert len(dataset["test"]) == len(references)
52 |
53 | # MBR
54 | mbr_config = MBRConfig()
55 | mbr_config.num_samples = 1024
56 | mbr_config.metric = "comet"
57 | mbr_config.metric_config_name = "eamt22-cometinho-da"
58 | mbr_config.metric_output_field = "mean_score"
59 |
60 | metric_runner = CometMetricRunner(
61 | mbr_config,
62 | tokenizer,
63 | device=mt_pipeline.device,
64 | batch_size_embed=64,
65 | batch_size_estimate=64,
66 | progress_bar=True,
67 | )
68 |
69 | base_generation_config = GenerationConfig.from_pretrained(model_name)
70 | base_generation_config.do_sample = True
71 | base_generation_config.num_beams = 1
72 | base_generation_config.early_stopping = False
73 | generation_configs = {}
74 |
75 | # # MBR – Ancestral (τ=1.0)
76 | # generation_config = deepcopy(base_generation_config)
77 | # generation_config.temperature = 1.0
78 | # generation_configs["mbr ancestral (τ=1.0)"] = generation_config
79 | #
80 | # # MBR – Top-k (k=10, τ=1.0)
81 | # generation_config = deepcopy(base_generation_config)
82 | # generation_config.top_k = 10
83 | # generation_config.temperature = 1.0
84 | # generation_configs["mbr top-k (k=10, τ=1.0)"] = generation_config
85 |
86 | # # MBR – Top-k (k=50, τ=1.0)
87 | # generation_config = deepcopy(base_generation_config)
88 | # generation_config.top_k = 50
89 | # generation_config.temperature = 1.0
90 | # generation_configs["mbr top-k (k=50, τ=1.0)"] = generation_config
91 | #
92 | # # MBR – Nucleus (p=0.9, τ=1.5)
93 | # generation_config = deepcopy(base_generation_config)
94 | # generation_config.top_p = 0.9
95 | # generation_config.temperature = 1.5
96 | # generation_configs["mbr nucleus (p=0.9, τ=1.5)"] = generation_config
97 | #
98 | # MBR – Epsilon (ε=0.02, τ=1.0)
99 | generation_config = deepcopy(base_generation_config)
100 | generation_config.epsilon_cutoff = 0.02
101 | generation_config.temperature = 1.0
102 | generation_configs["mbr epsilon (ε=0.02, τ=1.0)"] = generation_config
103 |
104 | # MBR – Epsilon (ε=0.02, τ=2.0)
105 | generation_config = deepcopy(base_generation_config)
106 | generation_config.epsilon_cutoff = 0.02
107 | generation_config.temperature = 2.0
108 | generation_configs["mbr epsilon (ε=0.02, τ=2.0)"] = generation_config
109 |
110 | for method, generation_config in generation_configs.items():
111 | outputs = mt_pipeline(
112 | dataset["test"]["text"],
113 | mbr_config=mbr_config,
114 | generation_config=generation_config,
115 | tokenizer=tokenizer,
116 | metric_runner=metric_runner,
117 | batch_size=32,
118 | progress_bar=True
119 | )
120 | translations = []
121 | for batch in tqdm(outputs):
122 | if isinstance(batch, dict):
123 | batch = [batch]
124 | translations += [translation["translation_text"] for translation in batch]
125 | comet_score = evaluation_metric_comet.compute(
126 | predictions=translations,
127 | references=references,
128 | sources=dataset["test"]["text"],
129 | )
130 | results_file.write({
131 | "language_pair": language_pair,
132 | "method": method,
133 | "comet20": comet_score["mean_score"],
134 | "translations": translations,
135 | })
136 |
137 | # Beam search
138 | model = FSMTForConditionalGeneration.from_pretrained(model_name).half().to(mt_pipeline.device)
139 | mt_pipeline.model = model
140 | generation_config = GenerationConfig.from_pretrained(model_name)
141 | generation_config.num_beams = 4
142 |
143 | outputs = mt_pipeline(
144 | dataset["test"]["text"],
145 | generation_config=generation_config,
146 | batch_size=32,
147 | )
148 | translations = []
149 | for batch in tqdm(outputs):
150 | if isinstance(batch, dict):
151 | batch = [batch]
152 | translations += [translation["translation_text"] for translation in batch]
153 | comet_score = evaluation_metric_comet.compute(
154 | predictions=translations,
155 | references=references,
156 | sources=dataset["test"]["text"],
157 | )
158 | results_file.write({
159 | "language_pair": language_pair,
160 | "method": "beam search",
161 | "comet20": comet_score["mean_score"],
162 | "translations": translations,
163 | })
164 |
165 | results_file.close()
166 |
--------------------------------------------------------------------------------
/experiments/müller-sennrich-2021-understanding/README.md:
--------------------------------------------------------------------------------
1 | This directory uses the [**mbr**](https://github.com/ZurichNLP/mbr) package to reproduce an experiment from the paper [Understanding the Properties of Minimum Bayes Risk Decoding in Neural Machine Translation](https://aclanthology.org/2021.acl-long.22) (Müller & Sennrich, ACL-IJCNLP 2021).
2 |
3 | ## Setup
4 | * Task: Machine translation
5 | * Translation directions: dan–epo, aze–eng, bel–rus, deu–fra
6 | * MBR metric: ChrF2 ([Popović, 2015](https://aclanthology.org/W15-3049/))
7 | * Number of samples: 5–100
8 | * Sampling approach: ancestral sampling
9 | * Samples and references are the same
10 | * Test set: Tatoeba ([Tiedemann, 2020](https://aclanthology.org/2020.wmt-1.139/))
11 | * Evaluation metric: ChrF2
12 | * Baseline: beam search with beam size 5
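In terms of the **mbr** package, the setup above corresponds roughly to the configuration sketched below. This is a minimal sketch for the dan–epo direction; `run_experiment.py` in this directory is the authoritative script, and the example source sentence is arbitrary.
```python
import torch
from transformers import MarianMTModel, AutoTokenizer, pipeline
from mbr import MBR, MBRConfig

model_name = "Helsinki-NLP/opus-mt-da-eo"  # dan-epo; the other directions are analogous
model = MBR(MarianMTModel).from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# ChrF as the MBR utility metric; the samples double as references
mbr_config = MBRConfig()
mbr_config.metric = "chrf"
mbr_config.metric_output_field = "score"
mbr_config.num_samples = 100  # varied from 5 to 100 in the experiment
mbr_config.num_references = mbr_config.num_samples

mt_pipeline = pipeline(
    "translation_da_to_eo",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
)
# Ancestral sampling: do_sample=True, num_beams=1, no truncation of the distribution
print(mt_pipeline("Hunden løber i parken.", mbr_config=mbr_config, tokenizer=tokenizer,
                  do_sample=True, num_beams=1))
```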
13 |
14 | ## Differences to the paper
15 | * The paper used custom models trained without label smoothing, whereas this reproduction uses open-source models from Opus-MT ([Tiedemann & Thottingal, 2020](https://aclanthology.org/2020.eamt-1.61)).
16 | * The paper reports averages over 2 runs, whereas this reproduction uses a single run.
17 |
18 | ## Results
19 |
20 | | Paper | Reproduction |
21 | |:---------------------------------------------------------------:|:---:|
22 | | ![AZE–ENG (original)](results/figures/AZE–ENG%20(original).png) | ![AZE–ENG (reproduction)](results/figures/AZE–ENG%20(reproduction).png) |
23 | | ![BEL–RUS (original)](results/figures/BEL–RUS%20(original).png) | ![BEL–RUS (reproduction)](results/figures/BEL–RUS%20(reproduction).png) |
24 | | ![DAN–EPO (original)](results/figures/DAN–EPO%20(original).png) | ![DAN–EPO (reproduction)](results/figures/DAN–EPO%20(reproduction).png) |
25 | | ![DEU–FRA (original)](results/figures/DEU–FRA%20(original).png) | ![DEU–FRA (reproduction)](results/figures/DEU–FRA%20(reproduction).png) |
--------------------------------------------------------------------------------
/experiments/müller-sennrich-2021-understanding/results/figures/AZE–ENG (original).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/müller-sennrich-2021-understanding/results/figures/AZE–ENG (original).png
--------------------------------------------------------------------------------
/experiments/müller-sennrich-2021-understanding/results/figures/AZE–ENG (reproduction).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/müller-sennrich-2021-understanding/results/figures/AZE–ENG (reproduction).png
--------------------------------------------------------------------------------
/experiments/müller-sennrich-2021-understanding/results/figures/BEL–RUS (original).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/müller-sennrich-2021-understanding/results/figures/BEL–RUS (original).png
--------------------------------------------------------------------------------
/experiments/müller-sennrich-2021-understanding/results/figures/BEL–RUS (reproduction).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/müller-sennrich-2021-understanding/results/figures/BEL–RUS (reproduction).png
--------------------------------------------------------------------------------
/experiments/müller-sennrich-2021-understanding/results/figures/DAN–EPO (original).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/müller-sennrich-2021-understanding/results/figures/DAN–EPO (original).png
--------------------------------------------------------------------------------
/experiments/müller-sennrich-2021-understanding/results/figures/DAN–EPO (reproduction).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/müller-sennrich-2021-understanding/results/figures/DAN–EPO (reproduction).png
--------------------------------------------------------------------------------
/experiments/müller-sennrich-2021-understanding/results/figures/DEU–FRA (original).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/müller-sennrich-2021-understanding/results/figures/DEU–FRA (original).png
--------------------------------------------------------------------------------
/experiments/müller-sennrich-2021-understanding/results/figures/DEU–FRA (reproduction).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/müller-sennrich-2021-understanding/results/figures/DEU–FRA (reproduction).png
--------------------------------------------------------------------------------
/experiments/müller-sennrich-2021-understanding/run_experiment.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from pathlib import Path
3 |
4 | import jsonlines
5 | import sacrebleu
6 | import torch
7 | from datasets import load_dataset
8 | from tqdm import tqdm
9 | from transformers import MarianMTModel, AutoTokenizer, pipeline, set_seed
10 | from transformers.pipelines.base import KeyDataset
11 |
12 | from mbr import MBR, MBRConfig
13 |
14 | set_seed(42)
15 |
16 | language_pair = sys.argv[1]
17 |
18 | opus_mt_models = {
19 | "dan-epo": "Helsinki-NLP/opus-mt-da-eo",
20 | "aze-eng": "Helsinki-NLP/opus-mt-az-en",
21 | "bel-rus": "Helsinki-NLP/opus-mt-tc-big-zle-zle",
22 | "deu-fra": "Helsinki-NLP/opus-mt-de-fr",
23 | }
24 |
25 | language_codes = {
26 | "dan": "da",
27 | "epo": "eo",
28 | "aze": "az",
29 | "eng": "en",
30 | "bel": "be",
31 | "rus": "ru",
32 | "deu": "de",
33 | "fra": "fr",
34 | }
35 |
36 | results_file = jsonlines.open(Path(__file__).parent / f"results_{language_pair}.jsonl", "w")
37 |
38 | model_name = opus_mt_models[language_pair]
39 | model = MBR(MarianMTModel).from_pretrained(model_name)
40 | model = model.half()
41 | tokenizer = AutoTokenizer.from_pretrained(model_name)
42 | src_code = language_codes[language_pair.split("-")[0]]
43 | tgt_code = language_codes[language_pair.split("-")[1]]
44 | mt_pipeline = pipeline(
45 | "translation_" + src_code + "_to_" + tgt_code,
46 | model=model,
47 | tokenizer=tokenizer,
48 | device=0 if torch.cuda.is_available() else -1,
49 | )
50 |
51 | dataset = load_dataset("Helsinki-NLP/tatoeba_mt", language_pair=language_pair)
52 | references = dataset["test"]["targetString"]
53 |
54 | # MBR
55 | mbr_config = MBRConfig()
56 | mbr_config.metric = "chrf"
57 | mbr_config.metric_output_field = "score"
58 | batch_size = 64 if "-big-" in model_name else 256
59 |
60 | for num_samples in range(5, 101, 5):
61 | mbr_config.num_samples = num_samples
62 | mbr_config.num_references = num_samples
63 |
64 | outputs = mt_pipeline(
65 | KeyDataset(dataset["test"], "sourceString"),
66 | mbr_config=mbr_config,
67 | tokenizer=tokenizer,
68 | do_sample=True,
69 | num_beams=1,
70 | batch_size=batch_size,
71 | )
72 | translations = []
73 | for batch in tqdm(outputs, total=len(dataset["test"]) // batch_size):
74 | translations += [translation["translation_text"] for translation in batch]
75 | chrf_score = sacrebleu.corpus_chrf(translations, [references])
76 | results_file.write({
77 | "language_pair": language_pair,
78 | "method": "mbr",
79 | "num_samples": num_samples,
80 | "chrf": chrf_score.score,
81 | "translations": translations,
82 | })
83 |
84 | # Beam search
85 | model = MarianMTModel.from_pretrained(model_name).to(mt_pipeline.device)
86 | mt_pipeline.model = model
87 |
88 | outputs = mt_pipeline(
89 | KeyDataset(dataset["test"], "sourceString"),
90 | num_beams=5,
91 | batch_size=32,
92 | )
93 | translations = []
94 | for batch in tqdm(outputs):
95 | translations += [translation["translation_text"] for translation in batch]
96 | chrf_score = sacrebleu.corpus_chrf(translations, [references])
97 | results_file.write({
98 | "language_pair": language_pair,
99 | "method": "beam_search",
100 | "chrf": chrf_score.score,
101 | "translations": translations,
102 | })
103 |
104 | results_file.close()
105 |
--------------------------------------------------------------------------------
/experiments/reference_aggregation/README.md:
--------------------------------------------------------------------------------
1 | ## Code for the paper ["Linear-time Minimum Bayes Risk Decoding with Reference Aggregation"](https://arxiv.org/abs/2402.04251)
2 |
3 | - The research code in this directory implements reference aggregation, an efficiency method for MBR that uses aggregate reference representations for faster utility estimation.
4 | - We apply reference aggregation to two metrics: ChrF and COMET.
5 | - Unlike the **mbr** package, the code in this directory is purely research-oriented (i.e., it reproduces the tables and figures in our paper) and is not optimized for usability.
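- As a rough, self-contained illustration of the idea (not the actual implementation in `mbr_utils.py`): instead of scoring every sample against every reference, the references are first averaged into a small number of aggregate representations, and each sample is scored against those. The toy sketch below uses plain NumPy vectors and a dot product as stand-ins for metric-specific representations and utilities (e.g., COMET sentence embeddings).
```python
import numpy as np

def pick_best_sample(sample_reprs: np.ndarray, reference_reprs: np.ndarray,
                     num_aggregates: int = 1) -> int:
    """Toy reference aggregation.
    sample_reprs: (num_samples, dim), reference_reprs: (num_references, dim).
    Returns the index of the sample with the highest average score against
    the aggregated references."""
    num_references, dim = reference_reprs.shape
    partition_size = num_references // num_aggregates
    # Average the references within each partition -> (num_aggregates, dim)
    aggregates = (reference_reprs[:num_aggregates * partition_size]
                  .reshape(num_aggregates, partition_size, dim)
                  .mean(axis=1))
    # One utility call per (sample, aggregate) instead of per (sample, reference)
    scores = sample_reprs @ aggregates.T    # (num_samples, num_aggregates)
    expected_utility = scores.mean(axis=1)  # (num_samples,)
    return int(expected_utility.argmax())

# Example: 8 samples, 8 references, 2 aggregate references
rng = np.random.default_rng(0)
print(pick_best_sample(rng.normal(size=(8, 16)), rng.normal(size=(8, 16)), num_aggregates=2))
```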
6 |
7 | ## Installation
8 | - Requires Python >= 3.9 and PyTorch.
9 | - `pip install -r requirements.txt`
10 |
11 | ## Reproducing the experiments
12 |
13 | ### Creating the samples
14 | - Warning: The following code downloads a large translation model from PyTorch Hub (if not already present) and generates 1024 samples per segment, which will take some time.
15 | - Samples will be stored in a JSON lines file in the directory `samples/`.
16 | ```bash
17 | python generate_samples.py --testset wmt21 --language-pair en-de --seed 0
18 | ```
19 |
20 | ### Figure 1: Top-20 accuracy
21 | #### Generating the translations
22 | - Performing this analysis is computationally heavy because we run it for many different values of _s_ (x-axis of Figure 1).
23 | - We run N-by-N MBR, N-by-S MBR and Reference Aggregation, for all values of _s_, in a single script, so that the embedding part of COMET only needs to run once.
24 | - The results are stored in a JSON lines file in the directory `validation_output/`. Each line describes the output for one method and one value of _s_.
25 | - In addition, the top translations will be stored in text files (one translation per line) in the `translations/` directory, to allow for easy evaluation.
26 | - The utility metric is either `"chrf"`, `"cometinho"` or `"comet22"`.
27 | ```bash
28 | python validation.py --testset wmt21 --language-pair en-de --seed 0 --utility comet22 --topk 20
29 | ```
30 | #### Calculating accuracy
31 | - After the script has run, the series for Figure 1 (top-20 accuracy) can be printed as follows.
32 | - The method can be either `"n_by_s"` or `"aggregate"`.
33 | ```bash
34 | python plot_accuracy.py --testset wmt21 --language-pair en-de --seed 0 --utility comet22 --topk 20 --method aggregate
35 | ```
36 | - To calculate top-1 accuracy instead:
37 | ```bash
38 | python plot_accuracy.py --testset wmt21 --language-pair en-de --seed 0 --utility comet22 --topk 20 --method aggregate --accuracy-topk 1
39 | ```
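- Each invocation prints the accuracy series as `(s,accuracy)` pairs, one pair per value of _s_ in descending order, e.g. `(1024,0.98000)(512,0.97500)…` (the numbers shown here are only illustrative).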
40 |
41 | ### Table 1: Test results
42 |
43 | #### Generating the translations
44 | - In the test results table, we compare the translation quality of beam search, epsilon sampling, standard (pairwise) MBR, and reference aggregation. We also experiment with aggregate-to-fine MBR.
45 | - The following scripts create the translations and store them in the `translations/` directory.
46 | ```bash
47 | # Beam search
48 | python baseline_beam_search.py --language-pair en-de --testset wmt22
49 |
50 | # MBR with ChrF metric – standard MBR
51 | python run_mbr.py --method pairwise --testset wmt22 --language-pair en-de --seed 0 --utility chrf
52 | # MBR with ChrF metric – reference aggregation
53 | python run_mbr.py --method aggregate --testset wmt22 --language-pair en-de --seed 0 --utility chrf
54 | # MBR with ChrF metric – aggregate-to-fine MBR
55 | python run_mbr.py --method aggregate_to_fine --topk 20 --testset wmt22 --language-pair en-de --seed 0 --utility chrf
56 |
57 | # MBR with Cometinho metric – standard MBR
58 | python run_mbr.py --method pairwise --testset wmt22 --language-pair en-de --seed 0 --utility cometinho
59 | # MBR with Cometinho metric – reference aggregation
60 | python run_mbr.py --method aggregate --testset wmt22 --language-pair en-de --seed 0 --utility cometinho
61 | # MBR with Cometinho metric – aggregate-to-fine MBR
62 | python run_mbr.py --method aggregate_to_fine --topk 20 --testset wmt22 --language-pair en-de --seed 0 --utility cometinho
63 |
64 | # MBR with COMET-22 metric – standard MBR
65 | python run_mbr.py --method pairwise --testset wmt22 --language-pair en-de --seed 0 --utility comet22
66 | # MBR with COMET-22 metric – reference aggregation
67 | python run_mbr.py --method aggregate --testset wmt22 --language-pair en-de --seed 0 --utility comet22
68 | # MBR with COMET-22 metric – aggregate-to-fine MBR
69 | python run_mbr.py --method aggregate_to_fine --topk 20 --testset wmt22 --language-pair en-de --seed 0 --utility comet22
70 |
71 | # Coarse-to-fine MBR: ChrF to COMET-22
72 | python run_mbr.py --method coarse_to_fine --topk 20 --testset wmt22 --language-pair en-de --seed 0 --coarse-utility chrf --utility comet22
73 | # Aggregate-to-fine MBR: Aggregate ChrF to COMET-22
74 | python run_mbr.py --method aggregate_to_fine --topk 20 --testset wmt22 --language-pair en-de --seed 0 --coarse-utility chrf --utility comet22
75 | ```
76 | - For epsilon sampling, we simply read the JSON lines file created by `generate_samples.py` and extract the first sample for each segment.
77 | ```bash
78 | python baseline_epsilon_sampling.py --testset wmt22 --language-pair en-de --seed 0
79 | ```
80 |
81 | #### Saving the source sequences and references in a text file
82 | - The sequences will be stored in text files in the `translations/` directory.
83 | ```bash
84 | python scripts/save_src_and_ref.py --testset wmt22 --language-pair en-de
85 | ```
86 |
87 | #### Evaluating the translations
88 | - Use a tool of your choice (e.g., https://github.com/mjpost/sacrebleu) to perform the evaluation.
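- For example, corpus-level ChrF can be computed with sacreBLEU's Python API. The file names below are only illustrative and depend on the options used above.
```python
from pathlib import Path

import sacrebleu

# Hypothetical file names; adjust to the outputs of the scripts above.
hypotheses = Path("translations/mbr.wmt22.en-de.aggregate.n1024.epsilon0.02.seed0.chrf.de").read_text().splitlines()
references = Path("translations/wmt22.en-de.ref.de").read_text().splitlines()

print(sacrebleu.corpus_chrf(hypotheses, [references]))
```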
89 |
90 |
91 | ## Citation
92 | ```bibtex
93 | @misc{vamvas-sennrich-2024-linear,
94 | title={Linear-time Minimum Bayes Risk Decoding with Reference Aggregation},
95 | author={Jannis Vamvas and Rico Sennrich},
96 | year={2024},
97 | eprint={2402.04251},
98 | archivePrefix={arXiv},
99 | primaryClass={cs.CL}
100 | }
101 | ```
102 |
--------------------------------------------------------------------------------
/experiments/reference_aggregation/baseline_beam_search.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 |
4 | from experiments.reference_aggregation.experiment_utils import Testset
5 | from experiments.reference_aggregation.fairseq_utils import load_model
6 |
7 |
8 | def main(testset: str, language_pair: str, beam_size: int = 4, limit_segments: int = None,
9 | out_dir: Path = None) -> Path:
10 | if out_dir is None:
11 | out_dir = Path(__file__).parent
12 |
13 | dataset = Testset.from_wmt(testset, language_pair, limit_segments=limit_segments)
14 |
15 | model = load_model(language_pair)
16 |
17 | translations_dir = out_dir / "translations"
18 | translations_dir.mkdir(exist_ok=True)
19 | out_path = translations_dir / f"{dataset}.beam{beam_size}.{dataset.tgt_lang}"
20 |
21 | translations = model.translate(dataset.source_sentences, beam=beam_size)
22 | assert len(translations) == len(dataset.source_sentences)
23 |
24 | with open(out_path, "w") as f:
25 | for translation in translations:
26 | f.write(translation + "\n")
27 |
28 | return out_path
29 |
30 |
31 | if __name__ == "__main__":
32 | parser = argparse.ArgumentParser()
33 | parser.add_argument('--testset', choices=['wmt21', 'wmt22'], required=True)
34 | parser.add_argument('--language-pair', choices=['de-en', 'en-de', 'en-ru', 'ru-en'], required=True)
35 | parser.add_argument("--beam-size", type=int, default=4)
36 | parser.add_argument('--limit-segments', type=int, default=None,
37 | help='Limit number of segments that are processed (used for testing)')
38 | args = parser.parse_args()
39 |
40 | out_path = main(testset=args.testset, language_pair=args.language_pair, beam_size=args.beam_size,
41 | limit_segments=args.limit_segments, )
42 | assert out_path.exists()
43 | print(f"Saved translations to {out_path}")
44 |
--------------------------------------------------------------------------------
/experiments/reference_aggregation/baseline_epsilon_sampling.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 |
4 | import jsonlines
5 | from tqdm import tqdm
6 |
7 |
8 | def main(testset: str, language_pair: str, num_samples: int, epsilon_cutoff: float, seed_no: int,
9 | out_dir: Path = None) -> Path:
10 | if out_dir is None:
11 | out_dir = Path(__file__).parent
12 |
13 | samples_dir = out_dir / "samples"
14 | assert samples_dir.exists()
15 | samples_path = samples_dir / f"samples.{testset}.{language_pair}.n{num_samples}.epsilon{epsilon_cutoff}.seed{seed_no}.jsonl"
16 | assert samples_path.exists()
17 |
18 | translations_dir = out_dir / "translations"
19 | translations_dir.mkdir(exist_ok=True)
20 | out_path = translations_dir / f"{testset}.{language_pair}.epsilon{epsilon_cutoff}.seed{seed_no}.{language_pair.split('-')[1]}"
21 |
22 | with jsonlines.open(samples_path) as f_in, open(out_path, "w") as f_out:
23 | for line in tqdm(f_in):
24 | samples = line["samples"]
25 | assert len(samples) == num_samples
26 | f_out.write(samples[0] + "\n")
27 |
28 | return out_path
29 |
30 |
31 | if __name__ == '__main__':
32 | parser = argparse.ArgumentParser()
33 | parser.add_argument('--testset', choices=['wmt21', 'wmt22'], required=True)
34 | parser.add_argument('--language-pair', choices=['de-en', 'en-de', 'en-ru', 'ru-en'], required=True)
35 | parser.add_argument('--seed', type=int, choices=range(10), required=True,
36 | help='Index of the random seed in the list of random seeds')
37 | parser.add_argument('--num-samples', type=int, default=1024)
38 | parser.add_argument('--epsilon-cutoff', type=float, default=0.02)
39 | args = parser.parse_args()
40 |
41 | out_path = main(testset=args.testset, language_pair=args.language_pair, num_samples=args.num_samples,
42 | epsilon_cutoff=args.epsilon_cutoff, seed_no=args.seed, )
43 | assert out_path.exists()
44 | print(f"Saved translations to {out_path}")
45 |
--------------------------------------------------------------------------------
/experiments/reference_aggregation/experiment_utils.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from pathlib import Path
3 | from typing import List
4 |
5 | import sacrebleu
6 |
7 |
8 | SEEDS = [
9 | 553589,
10 | 456178,
11 | 817304,
12 | 6277,
13 | 792418,
14 | 707983,
15 | 249859,
16 | 272618,
17 | 760402,
18 | 472974,
19 | ]
20 |
21 |
22 | @dataclass
23 | class Testset:
24 | testset: str
25 | language_pair: str
26 | source_sentences: List[str]
27 | references: List[str]
28 |
29 | @property
30 | def src_lang(self):
31 | return self.language_pair.split("-")[0]
32 |
33 | @property
34 | def tgt_lang(self):
35 | return self.language_pair.split("-")[1]
36 |
37 | @classmethod
38 | def from_wmt(cls, wmt: str, language_pair: str, limit_segments: int = None):
39 | assert wmt in {"wmt21", "wmt22"}
40 | src_path = sacrebleu.get_source_file(wmt, language_pair)
41 | ref_path = sacrebleu.get_reference_files(wmt, language_pair)[0]
42 | source_sequences = Path(src_path).read_text().splitlines()
43 | references = Path(ref_path).read_text().splitlines()
44 | assert len(source_sequences) == len(references)
45 | if limit_segments is not None:
46 | source_sequences = source_sequences[:limit_segments]
47 | references = references[:limit_segments]
48 | return cls(testset=wmt, language_pair=language_pair, source_sentences=source_sequences, references=references, )
49 |
50 | def __str__(self):
51 | return f"{self.testset}.{self.language_pair}"
52 |
--------------------------------------------------------------------------------
/experiments/reference_aggregation/fairseq_utils.py:
--------------------------------------------------------------------------------
1 | import urllib.request
2 | from collections import namedtuple
3 | from pathlib import Path
4 | from typing import Union, List
5 |
6 | import torch
7 | from fairseq import hub_utils
8 | from fairseq.data.encoders.fastbpe import fastBPE
9 | from fairseq.hub_utils import GeneratorHubInterface
10 |
11 |
12 | class FairseqTranslationModel:
13 | """
14 | Adapted from https://github.com/ZurichNLP/contrastive-conditioning/blob/master/translation_models/fairseq_models.py
15 | """
16 |
17 | def __init__(self,
18 | name: str,
19 | model: GeneratorHubInterface = None,
20 | model_name_or_path: Union[Path, str] = None,
21 | checkpoint_file: str = "checkpoint_best.pt",
22 | src_bpe_codes: Union[Path, str] = None,
23 | tgt_bpe_codes: Union[Path, str] = None,
24 | **kwargs,
25 | ):
26 | self.name = name
27 | self.model = model or hub_utils.GeneratorHubInterface(**hub_utils.from_pretrained(
28 | model_name_or_path=str(model_name_or_path),
29 | checkpoint_file=checkpoint_file,
30 | **kwargs,
31 | ))
32 | # self.model.args.max_tokens = max_tokens
33 | self.model.eval()
34 | if torch.cuda.is_available():
35 | self.model.cuda()
36 |
37 | # EN-RU systems use separate vocabularies, which is not yet supported by torch hub
38 | bpe_args = namedtuple("bpe_args", ["bpe_codes"])
39 | if src_bpe_codes is not None:
40 | bpe_args_src = bpe_args(bpe_codes=str(src_bpe_codes))
41 | self.src_bpe = fastBPE(bpe_args_src)
42 | else:
43 | self.src_bpe = None
44 | if tgt_bpe_codes is not None:
45 | bpe_args_tgt = bpe_args(bpe_codes=str(tgt_bpe_codes))
46 | self.tgt_bpe = fastBPE(bpe_args_tgt)
47 | else:
48 | self.tgt_bpe = None
49 |
50 | def translate(self, sentences: List[str], beam: int = 5, **kwargs) -> List[str]:
51 | return self.model.translate(sentences, beam, **kwargs)
52 |
53 | def sample(self, sentences: List[str], seed=None, **kwargs) -> List[str]:
54 | return self.model.sample(sentences, sampling=True, seed=seed, **kwargs)
55 |
56 | def __str__(self):
57 | return self.name
58 |
59 |
60 | def load_model(language_pair: str) -> FairseqTranslationModel:
61 | if language_pair in ["de-en", "en-de"]:
62 | hub_interface = torch.hub.load(
63 | repo_or_dir="jvamvas/fairseq:epsilon",
64 | model=f'transformer.wmt19.{language_pair}.single_model',
65 | tokenizer='moses',
66 | bpe='fastbpe',
67 | )
68 | model_name = f"transformer.wmt19.{language_pair}.single_model"
69 | model = FairseqTranslationModel(
70 | name=model_name,
71 | model=hub_interface,
72 | )
73 | elif language_pair in ["en-ru", "ru-en"]:
74 | hub_interface = torch.hub.load(
75 | repo_or_dir="jvamvas/fairseq:epsilon",
76 | model=f'transformer.wmt19.{language_pair}.single_model',
77 | tokenizer='moses',
78 | bpe='fastbpe',
79 | )
80 |
81 | # Need to download correct vocab separately (https://github.com/pytorch/fairseq/issues/2928)
82 | hub_base_dir = Path(torch.hub.get_dir())
83 | correct_en_vocab_path = hub_base_dir / "en24k.fastbpe.code"
84 | correct_ru_vocab_path = hub_base_dir / "ru24k.fastbpe.code"
85 | if not correct_en_vocab_path.exists():
86 | with urllib.request.urlopen("https://dl.fbaipublicfiles.com/fairseq/en24k.fastbpe.code") as response, \
87 | open(correct_en_vocab_path, 'wb') as out_file:
88 | data = response.read()
89 | out_file.write(data)
90 | if not correct_ru_vocab_path.exists():
91 | with urllib.request.urlopen("https://dl.fbaipublicfiles.com/fairseq/ru24k.fastbpe.code") as response, \
92 | open(correct_ru_vocab_path, 'wb') as out_file:
93 | data = response.read()
94 | out_file.write(data)
95 |
96 | evaluator_name = f"transformer.wmt19.{language_pair}.single_model"
97 | model = FairseqTranslationModel(
98 | name=evaluator_name,
99 | model=hub_interface,
100 | src_bpe_codes=correct_en_vocab_path if language_pair == "en-ru" else correct_ru_vocab_path,
101 | tgt_bpe_codes=correct_ru_vocab_path if language_pair == "en-ru" else correct_en_vocab_path,
102 | )
103 | else:
104 | raise NotImplementedError
105 | return model
106 |
--------------------------------------------------------------------------------
/experiments/reference_aggregation/generate_samples.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 |
4 | import jsonlines
5 | from tqdm import tqdm
6 |
7 | from experiments.reference_aggregation.experiment_utils import SEEDS, Testset
8 | from experiments.reference_aggregation.fairseq_utils import load_model
9 |
10 |
11 | def main(testset: str, language_pair: str, seed_no: int, num_samples: int, epsilon_cutoff: float,
12 | limit_segments: int = None, out_dir: Path = None) -> Path:
13 | if out_dir is None:
14 | out_dir = Path(__file__).parent
15 |
16 | seed = SEEDS[seed_no]
17 | dataset = Testset.from_wmt(testset, language_pair, limit_segments=limit_segments)
18 |
19 | model = load_model(language_pair)
20 |
21 | samples_dir = out_dir / "samples"
22 | samples_dir.mkdir(exist_ok=True)
23 | out_path = samples_dir / f"samples.{dataset}.n{num_samples}.epsilon{epsilon_cutoff}.seed{seed_no}.jsonl"
24 |
25 | with jsonlines.open(out_path, "w") as f:
26 | for source_sentence in tqdm(dataset.source_sentences):
27 | f.write({"samples": model.sample(num_samples * [source_sentence], seed=seed,
28 | sampling_epsilon_cutoff=epsilon_cutoff), })
29 |
30 | return out_path
31 |
32 |
33 | if __name__ == '__main__':
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument('--testset', choices=['wmt21', 'wmt22'], required=True)
36 | parser.add_argument('--language-pair', choices=['de-en', 'en-de', 'en-ru', 'ru-en'], required=True)
37 | parser.add_argument('--seed', type=int, choices=range(10), required=True,
38 | help='Index of the random seed in the list of random seeds')
39 | parser.add_argument('--num-samples', type=int, default=1024)
40 | parser.add_argument('--epsilon-cutoff', type=float, default=0.02)
41 | parser.add_argument('--limit-segments', type=int, default=None,
42 | help='Limit number of segments that are processed (used for testing)')
43 | args = parser.parse_args()
44 |
45 | out_path = main(testset=args.testset, language_pair=args.language_pair, seed_no=args.seed,
46 | num_samples=args.num_samples, epsilon_cutoff=args.epsilon_cutoff, limit_segments=args.limit_segments, )
47 | assert out_path.exists()
48 | print(f"Saved samples to {out_path}")
49 |
--------------------------------------------------------------------------------
/experiments/reference_aggregation/mbr_utils.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import logging
3 | from collections import namedtuple
4 | from typing import Set, Dict, List, Tuple
5 |
6 | import evaluate
7 | import numpy as np
8 | import torch
9 | from fastchrf import pairwise_chrf, aggregate_chrf
10 |
11 |
12 | class ChrfUtility:
13 |
14 | def __init__(self, char_order: int = 6, beta: float = 2.0, remove_whitespace: bool = True, eps_smoothing: bool = False):
15 | self.char_order = char_order
16 | self.beta = beta
17 | self.remove_whitespace = remove_whitespace
18 | self.eps_smoothing = eps_smoothing
19 |
20 | def rank_samples_n_by_s(self, source_sequence: str, samples: List[str], references: List[str], s: int = None) -> np.ndarray:
21 | """
22 | Returns the indices of the samples sorted by their utility score, in descending order.
23 | :param s: The number of references to subsample from the list of references (default: all references)
24 | """
25 | if s is None:
26 | s = len(references)
27 | assert s <= len(references)
28 | references = references[:s]
29 |
30 | metric_scores = pairwise_chrf(
31 | [samples], [references],
32 | char_order=self.char_order,
33 | beta=self.beta,
34 | remove_whitespace=self.remove_whitespace,
35 | eps_smoothing=self.eps_smoothing,
36 | )[0]
37 | metric_scores = np.array(metric_scores) # num_samples x s
38 |
39 | # Sort the samples by their average score
40 | sample_scores = metric_scores.mean(axis=1)
41 | sample_indices = sample_scores.argsort()[::-1]
42 | return sample_indices
43 |
44 | def rank_samples_aggregate(self, source_sequence: str, samples: List[str], references: List[str], s: int) -> np.ndarray:
45 | """
46 | Returns the indices of the samples sorted by their utility score, in descending order.
47 | :param s: The number of aggregate references
48 | """
49 | assert s <= len(references)
50 |
51 | num_partitions = s
52 | partition_size = len(references) // num_partitions
53 | reference_partitions = [references[i * partition_size:(i + 1) * partition_size] for i in range(num_partitions)]
54 |
55 | metric_scores = aggregate_chrf(
56 | num_partitions * [samples], reference_partitions,
57 | char_order=self.char_order,
58 | beta=self.beta,
59 | remove_whitespace=self.remove_whitespace,
60 | eps_smoothing=self.eps_smoothing,
61 | )
62 | metric_scores = np.array(metric_scores).transpose() # num_samples x s
63 |
64 | # Sort the samples by their average score
65 | sample_scores = metric_scores.mean(axis=1)
66 | sample_indices = sample_scores.argsort()[::-1]
67 | return sample_indices
68 |
69 |
70 | CometInputTriple = namedtuple("CometInputTriple", ["src", "hyp", "ref"])
71 |
72 |
73 | class CometUtility:
74 |
75 | def __init__(self,
76 | model_name: str,
77 | batch_size_embed: int = 1,
78 | batch_size_estimate: int = 1,
79 | ):
80 | self.model_name = model_name
81 | self.batch_size_embed = batch_size_embed
82 | self.batch_size_estimate = batch_size_estimate
83 | self.scorer = evaluate.load("comet", model_name).scorer
84 | if torch.cuda.is_available():
85 | self.scorer = self.scorer.to("cuda:0")
86 | else:
87 | logging.warning("CUDA not available, using CPU")
88 | self.scorer.eval()
89 | self.device = self.scorer.device
90 | self.embeddings: Dict[str, torch.FloatTensor] = {}
91 |
92 | @torch.no_grad()
93 | def compute_features(self, input_sequences: Set[str]):
94 | assert not self.scorer.training
95 | input_sequences = list(input_sequences)
96 | encodings = self.scorer.encoder.prepare_sample(input_sequences).to(self.device)
97 | batches = itertools.zip_longest(range(0, len(input_sequences), self.batch_size_embed),
98 | range(self.batch_size_embed, len(input_sequences), self.batch_size_embed))
99 | for start_idx, end_idx in batches:
100 | embeddings = self.scorer.get_sentence_embedding(
101 | input_ids=encodings["input_ids"][start_idx:end_idx],
102 | attention_mask=encodings["attention_mask"][start_idx:end_idx],
103 | )
104 | for j in range(start_idx, end_idx if end_idx is not None else len(input_sequences)):
105 | embedding = embeddings[j - start_idx]
106 | self.embeddings[input_sequences[j]] = embedding
107 |
108 | def clear_features(self):
109 | self.embeddings = {}
110 |
111 | @torch.no_grad()
112 | def rank_samples_n_by_s(self, source_sequence: str, samples: List[str], references: List[str], s: int = None) -> np.ndarray:
113 | """
114 | Returns the indices of the samples sorted by their utility score, in descending order.
115 | :param s: The number of references to subsample from the list of references (default: all references)
116 | """
117 | if s is None:
118 | s = len(references)
119 | assert s <= len(references)
120 | references = references[:s]
121 | assert not self.scorer.training
122 |
123 | # Collect all unique input triples
124 | input_triples: Set[Tuple[str, str, str]] = set()
125 | for sample in samples:
126 | for reference in references:
127 | input_triples.add(CometInputTriple(src=source_sequence, hyp=sample, ref=reference))
128 | input_triples: List = list(input_triples)
129 |
130 | # Compute scores for all input triples
131 | triple_scores: Dict[CometInputTriple, torch.tensor] = {}
132 | batches = itertools.zip_longest(range(0, len(input_triples), self.batch_size_estimate),
133 | range(self.batch_size_estimate, len(input_triples), self.batch_size_estimate))
134 | for start_idx, end_idx in batches:
135 | batch = input_triples[start_idx:end_idx]
136 | batch_scores = self.scorer.estimate(
137 | src_sentemb=torch.stack([self.embeddings[input.src] for input in batch]),
138 | mt_sentemb=torch.stack([self.embeddings[input.hyp] for input in batch]),
139 | ref_sentemb=torch.stack([self.embeddings[input.ref] for input in batch]),
140 | )
141 | for i in range(start_idx, end_idx if end_idx is not None else len(input_triples)):
142 | triple = batch[i - start_idx]
143 | score = batch_scores.score[i - start_idx]
144 | triple_scores[triple] = score
145 |
146 | # Fill in the metric scores matrix
147 | metric_scores = torch.zeros((len(samples), len(references)))
148 | for i, sample in enumerate(samples):
149 | for j, reference in enumerate(references):
150 | metric_scores[i, j] = triple_scores[CometInputTriple(src=source_sequence, hyp=sample, ref=reference)]
151 |
152 | # Sort the samples by their average score
153 | sample_scores = metric_scores.mean(dim=1)
154 | sample_indices = sample_scores.argsort(descending=True)
155 | return sample_indices.cpu().numpy()
156 |
157 | @torch.no_grad()
158 | def rank_samples_aggregate(self, source_sequence: str, samples: List[str], references: List[str], s: int) -> np.ndarray:
159 | """
160 | Returns the indices of the samples sorted by their utility score, in descending order.
161 |         :param s: The number of aggregate references
162 | """
163 | assert s <= len(references)
164 | assert not self.scorer.training
165 |
166 | num_partitions = s
167 | partition_size = len(references) // num_partitions
168 |
169 | # Add aggregate reference embeddings to the embeddings cache
170 | reference_embeddings = torch.stack([self.embeddings[reference] for reference in references])
171 | avg_reference_embeddings = reference_embeddings.view(num_partitions, partition_size, -1).mean(dim=1)
172 | for partition_id in range(num_partitions):
173 | self.embeddings[f"aggregate_{partition_id}"] = avg_reference_embeddings[partition_id]
174 |
175 | # Collect all unique input triples
176 | input_triples: Set[Tuple[str, str, str]] = set()
177 | for sample in samples:
178 | for partition_id in range(s):
179 | input_triples.add(CometInputTriple(src=source_sequence, hyp=sample, ref=f"aggregate_{partition_id}"))
180 | input_triples: List = list(input_triples)
181 |
182 | # Compute scores for all input triples
183 | triple_scores: Dict[CometInputTriple, torch.tensor] = {}
184 | batches = itertools.zip_longest(range(0, len(input_triples), self.batch_size_estimate),
185 | range(self.batch_size_estimate, len(input_triples), self.batch_size_estimate))
186 | for start_idx, end_idx in batches:
187 | batch = input_triples[start_idx:end_idx]
188 | batch_scores = self.scorer.estimate(
189 | src_sentemb=torch.stack([self.embeddings[input.src] for input in batch]),
190 | mt_sentemb=torch.stack([self.embeddings[input.hyp] for input in batch]),
191 | ref_sentemb=torch.stack([self.embeddings[input.ref] for input in batch]),
192 | )
193 | for i in range(start_idx, end_idx if end_idx is not None else len(input_triples)):
194 | triple = batch[i - start_idx]
195 | score = batch_scores.score[i - start_idx]
196 | triple_scores[triple] = score
197 |
198 | # Fill in the metric scores matrix
199 | metric_scores = torch.zeros((len(samples), num_partitions))
200 | for i, sample in enumerate(samples):
201 | for partition_id in range(s):
202 | metric_scores[i, partition_id] = triple_scores[CometInputTriple(src=source_sequence, hyp=sample, ref=f"aggregate_{partition_id}")]
203 |
204 | # Sort the samples by their average score
205 | sample_scores = metric_scores.mean(dim=1)
206 | sample_indices = sample_scores.argsort(descending=True)
207 | return sample_indices.cpu().numpy()
208 |
209 |
210 | def load_utility(utility_name: str):
211 | if utility_name == "chrf":
212 | return ChrfUtility()
213 | elif utility_name.startswith("comet22"):
214 | return CometUtility("Unbabel/wmt22-comet-da", batch_size_embed=128, batch_size_estimate=128)
215 | elif utility_name.startswith("cometinho"):
216 | return CometUtility("eamt22-cometinho-da", batch_size_embed=512, batch_size_estimate=512)
217 | else:
218 | raise ValueError(f"Unknown utility {utility_name}")
219 |
--------------------------------------------------------------------------------
/experiments/reference_aggregation/plot_accuracy.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | from typing import List, Tuple
4 |
5 | import jsonlines
6 |
7 | from experiments.reference_aggregation.experiment_utils import Testset
8 |
9 |
10 | def main(testset: str, language_pair: str, seed_no: int, fine_utility_name: str, topk: int, accuracy_topk: int,
11 | method: str, num_samples: int = 1024, epsilon_cutoff: float = 0.02, coarse_utility_name: str = None,
12 | limit_segments: int = None, out_dir: Path = None) -> List[Tuple[int, float]]:
13 | """
14 | Returns a series of (s, accuracy) tuples, starting with the highest s
15 | """
16 | if out_dir is None:
17 | out_dir = Path(__file__).parent
18 |
19 | if coarse_utility_name is None:
20 | coarse_utility_name = fine_utility_name
21 |
22 | assert topk <= num_samples
23 | assert accuracy_topk <= topk
24 |
25 | dataset = Testset.from_wmt(testset, language_pair, limit_segments=limit_segments)
26 |
27 | samples_dir = out_dir / "samples"
28 | assert samples_dir.exists()
29 | samples_path = samples_dir / f"samples.{dataset}.n{num_samples}.epsilon{epsilon_cutoff}.seed{seed_no}.jsonl"
30 | assert samples_path.exists()
31 | with jsonlines.open(samples_path) as f:
32 | samples = [line["samples"] for line in f]
33 | samples = [sample[:num_samples] for sample in samples]
34 | if limit_segments is not None:
35 | samples = samples[:limit_segments]
36 |
37 | output_dir = out_dir / "validation_output"
38 | assert output_dir.exists()
39 | fine_output_path = output_dir / f"validation.{dataset}.n{num_samples}.epsilon{epsilon_cutoff}.seed{seed_no}.{fine_utility_name}.top{topk}.jsonl"
40 | with jsonlines.open(fine_output_path) as f:
41 | fine_data = list(f)
42 | coarse_output_path = output_dir / f"validation.{dataset}.n{num_samples}.epsilon{epsilon_cutoff}.seed{seed_no}.{coarse_utility_name}.top{topk}.jsonl"
43 | with jsonlines.open(coarse_output_path) as f:
44 | coarse_data = list(f)
45 |
46 | # Get n-by-n top-1 samples – should not matter which method
47 | n_by_n_lines = [line for line in fine_data if line["s"] == num_samples]
48 | assert len(n_by_n_lines) == 2
49 | for ranking in zip(n_by_n_lines[0]["rankings"], n_by_n_lines[1]["rankings"]):
50 | assert ranking[0] == ranking[1]
51 | n_by_n_rankings = n_by_n_lines[0]["rankings"]
52 | n_by_n_top1_samples = [samples[i][n_by_n_rankings[i][0]].strip() for i in range(len(samples))]
53 |
54 | # Get top-k accuracies for efficiency method
55 | method_lines = [line for line in coarse_data if line["method"] == method]
56 | assert len(method_lines) == len(coarse_data) / 2
57 | s_values = list(sorted([line["s"] for line in method_lines], reverse=True))
58 | accuracies = [] # for each s
59 | for s in s_values:
60 | s_lines = [line for line in method_lines if line["s"] == s]
61 | assert len(s_lines) == 1
62 | s_rankings = s_lines[0]["rankings"]
63 | s_topk_samples = [{samples[i][ranking].strip() for ranking in s_rankings[i][:accuracy_topk]} for i in
64 | range(len(samples))]
65 | s_num_correct = sum([1 if n_by_n_top1_samples[i] in s_topk_samples[i] else 0 for i in range(len(samples))])
66 | s_accuracy = s_num_correct / len(samples)
67 | accuracies.append(s_accuracy)
68 |
69 | # Format: (1,-0.4)(2,-0.6)(4,-0.5)(8,0.1)(16,0.1)(32,0.2)(64,0.1)(128,-0.0)(256,-0.0)
70 | series = [(s, accuracy) for s, accuracy in zip(s_values, accuracies)]
71 | return series
72 |
73 |
74 | if __name__ == '__main__':
75 | parser = argparse.ArgumentParser()
76 | parser.add_argument('--testset', choices=['wmt21', 'wmt22'], required=True)
77 | parser.add_argument('--language-pair', choices=['de-en', 'en-de', 'en-ru', 'ru-en'], required=True)
78 | parser.add_argument('--seed', type=int, choices=range(10), required=True,
79 | help='Index of the random seed in the list of random seeds')
80 | parser.add_argument('--utility', choices=['chrf', 'cometinho', 'comet22'], required=True)
81 | parser.add_argument('--coarse-utility', choices=['chrf', 'cometinho', 'comet22'], default=None,
82 | help='Utility used for coarse-grained method (default: same as fine-grained)')
83 | parser.add_argument('--topk', type=int, default=20,
84 | help='Number of top translations that have been saved in the jsonl file')
85 | parser.add_argument('--method', choices=['n_by_s', 'aggregate'], required=True)
86 | parser.add_argument('--num-samples', type=int, default=1024)
87 | parser.add_argument('--epsilon-cutoff', type=float, default=0.02)
88 | parser.add_argument('--accuracy-topk', type=int, default=None,
89 |                         help='Number of top translations that are used to compute the accuracy (default: same as --topk)')
90 | parser.add_argument('--limit-segments', type=int, default=None,
91 | help='Limit number of segments that are processed (used for testing)')
92 | args = parser.parse_args()
93 |
94 | if args.coarse_utility is None:
95 | args.coarse_utility = args.utility
96 | if args.accuracy_topk is None:
97 | args.accuracy_topk = args.topk
98 |
99 | series = main(testset=args.testset, language_pair=args.language_pair, seed_no=args.seed,
100 | fine_utility_name=args.utility, coarse_utility_name=args.coarse_utility, topk=args.topk, method=args.method,
101 | num_samples=args.num_samples, epsilon_cutoff=args.epsilon_cutoff, accuracy_topk=args.accuracy_topk,
102 | limit_segments=args.limit_segments, )
103 |
104 | # Format: (1,-0.4)(2,-0.6)(4,-0.5)(8,0.1)(16,0.1)(32,0.2)(64,0.1)(128,-0.0)(256,-0.0)
105 | series_str = "".join([f"({s},{accuracy:.5f})" for s, accuracy in series])
106 | print(
107 | f"Testset: {args.testset}, language pair: {args.language_pair}, seed: {args.seed}, fine utility: {args.utility}, coarse utility: {args.coarse_utility}, topk: {args.topk}, method: {args.method}")
108 | print(f"Top-{args.accuracy_topk} accuracy:")
109 | print(series_str)
110 | print()
111 |
--------------------------------------------------------------------------------
/experiments/reference_aggregation/requirements.txt:
--------------------------------------------------------------------------------
1 | git+https://github.com/jvamvas/fairseq.git@epsilon # Adds epsilon sampling
2 | scikit_learn==1.3.2
3 | sacremoses==0.1.1
4 | fastBPE==0.1.0
5 | requests
6 | jsonlines
7 | unbabel-comet==2.1.1
8 | evaluate==0.4.1
9 |
--------------------------------------------------------------------------------
/experiments/reference_aggregation/results/chrf.log:
--------------------------------------------------------------------------------
1 | en-de
2 | baselines
3 | 61.4
4 | 55.1
5 | chrf
6 | 62.0
7 | 62.0
8 | 62.0
9 | cometinho
10 | 59.8
11 | 60.2
12 | 59.8
13 | comet22
14 | 60.0
15 | 60.3
16 | 60.0
17 | coarse-to-fine
18 | 61.8
19 | 61.8
20 |
21 | de-en
22 | baselines
23 | 56.3
24 | 50.7
25 | chrf
26 | 57.5
27 | 57.6
28 | 57.6
29 | cometinho
30 | 55.5
31 | 55.5
32 | 55.5
33 | comet22
34 | 55.3
35 | 55.6
36 | 55.4
37 | coarse-to-fine
38 | 57.0
39 | 57.1
40 |
41 | en-ru
42 | baselines
43 | 52.3
44 | 46.8
45 | chrf
46 | 54.0
47 | 53.9
48 | 54.0
49 | cometinho
50 | 51.4
51 | 51.9
52 | 51.4
53 | comet22
54 | 51.1
55 | 51.7
56 | 51.3
57 | coarse-to-fine
58 | 53.6
59 | 53.6
60 |
61 | ru-en
62 | baselines
63 | 64.4
64 | 57.6
65 | chrf
66 | 65.5
67 | 65.5
68 | 65.5
69 | cometinho
70 | 63.4
71 | 63.4
72 | 63.4
73 | comet22
74 | 62.7
75 | 63.3
76 | 62.8
77 | coarse-to-fine
78 | 64.9
79 | 64.9
80 |
81 |
--------------------------------------------------------------------------------
/experiments/reference_aggregation/results/comet-xl.log:
--------------------------------------------------------------------------------
1 | en-de
2 | baselines
3 | wmt22.en-de.beam4.de score: 0.9599
4 | wmt22.en-de.epsilon0.02.seed0.de score: 0.9426
5 | chrf
6 | mbr.wmt22.en-de.pairwise.n1024.epsilon0.02.seed0.chrf.de score: 0.9570
7 | mbr.wmt22.en-de.aggregate.n1024.epsilon0.02.seed0.chrf.de score: 0.9569
8 | mbr.wmt22.en-de.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf.de score: 0.9571
9 | cometinho
10 | mbr.wmt22.en-de.pairwise.n1024.epsilon0.02.seed0.cometinho.de score: 0.9599
11 | mbr.wmt22.en-de.aggregate.n1024.epsilon0.02.seed0.cometinho.de score: 0.9592
12 | mbr.wmt22.en-de.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.cometinho.de score: 0.9599
13 | comet22
14 | mbr.wmt22.en-de.pairwise.n1024.epsilon0.02.seed0.comet22.de score: 0.9656
15 | mbr.wmt22.en-de.aggregate.n1024.epsilon0.02.seed0.comet22.de score: 0.9643
16 | mbr.wmt22.en-de.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.comet22.de score: 0.9656
17 | coarse-to-fine
18 | mbr.wmt22.en-de.coarse_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.de score: 0.9626
19 | mbr.wmt22.en-de.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.de score: 0.9625
20 |
21 | de-en
22 | baselines
23 | wmt22.de-en.beam4.en score: 0.9411
24 | wmt22.de-en.epsilon0.02.seed0.en score: 0.9163
25 | chrf
26 | mbr.wmt22.de-en.pairwise.n1024.epsilon0.02.seed0.chrf.en score: 0.9397
27 | mbr.wmt22.de-en.aggregate.n1024.epsilon0.02.seed0.chrf.en score: 0.9397
28 | mbr.wmt22.de-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf.en score: 0.9397
29 | cometinho
30 | mbr.wmt22.de-en.pairwise.n1024.epsilon0.02.seed0.cometinho.en score: 0.9434
31 | mbr.wmt22.de-en.aggregate.n1024.epsilon0.02.seed0.cometinho.en score: 0.9420
32 | mbr.wmt22.de-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.cometinho.en score: 0.9432
33 | comet22
34 | mbr.wmt22.de-en.pairwise.n1024.epsilon0.02.seed0.comet22.en score: 0.9510
35 | mbr.wmt22.de-en.aggregate.n1024.epsilon0.02.seed0.comet22.en score: 0.9491
36 | mbr.wmt22.de-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.comet22.en score: 0.9505
37 | coarse-to-fine
38 | mbr.wmt22.de-en.coarse_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.en score: 0.9471
39 | mbr.wmt22.de-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.en score: 0.9473
40 |
41 | en-ru
42 | baselines
43 | wmt22.en-ru.beam4.ru score: 0.8575
44 | wmt22.en-ru.epsilon0.02.seed0.ru score: 0.8157
45 | chrf
46 | mbr.wmt22.en-ru.pairwise.n1024.epsilon0.02.seed0.chrf.ru score: 0.8474
47 | mbr.wmt22.en-ru.aggregate.n1024.epsilon0.02.seed0.chrf.ru score: 0.8469
48 | mbr.wmt22.en-ru.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf.ru score: 0.8476
49 | cometinho
50 | mbr.wmt22.en-ru.pairwise.n1024.epsilon0.02.seed0.cometinho.ru score: 0.8647
51 | mbr.wmt22.en-ru.aggregate.n1024.epsilon0.02.seed0.cometinho.ru score: 0.8622
52 | mbr.wmt22.en-ru.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.cometinho.ru score: 0.8648
53 | comet22
54 | mbr.wmt22.en-ru.pairwise.n1024.epsilon0.02.seed0.comet22.ru score: 0.8914
55 | mbr.wmt22.en-ru.aggregate.n1024.epsilon0.02.seed0.comet22.ru score: 0.8867
56 | mbr.wmt22.en-ru.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.comet22.ru score: 0.8915
57 | coarse-to-fine
58 | mbr.wmt22.en-ru.coarse_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.ru score: 0.8740
59 | mbr.wmt22.en-ru.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.ru score: 0.8744
60 |
61 | ru-en
62 | baselines
63 | wmt22.ru-en.beam4.en score: 0.9290
64 | wmt22.ru-en.epsilon0.02.seed0.en score: 0.8996
65 | chrf
66 | mbr.wmt22.ru-en.pairwise.n1024.epsilon0.02.seed0.chrf.en score: 0.9269
67 | mbr.wmt22.ru-en.aggregate.n1024.epsilon0.02.seed0.chrf.en score: 0.9263
68 | mbr.wmt22.ru-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf.en score: 0.9270
69 | cometinho
70 | mbr.wmt22.ru-en.pairwise.n1024.epsilon0.02.seed0.cometinho.en score: 0.9318
71 | mbr.wmt22.ru-en.aggregate.n1024.epsilon0.02.seed0.cometinho.en score: 0.9307
72 | mbr.wmt22.ru-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.cometinho.en score: 0.9318
73 | comet22
74 | mbr.wmt22.ru-en.pairwise.n1024.epsilon0.02.seed0.comet22.en score: 0.9393
75 | mbr.wmt22.ru-en.aggregate.n1024.epsilon0.02.seed0.comet22.en score: 0.9374
76 | mbr.wmt22.ru-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.comet22.en score: 0.9391
77 | coarse-to-fine
78 | mbr.wmt22.ru-en.coarse_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.en score: 0.9351
79 | mbr.wmt22.ru-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.en score: 0.9351
80 |
81 |
--------------------------------------------------------------------------------
/experiments/reference_aggregation/results/comet22.log:
--------------------------------------------------------------------------------
1 | en-de
2 | baselines
3 | wmt22.en-de.beam4.de score: 0.8589
4 | wmt22.en-de.epsilon0.02.seed0.de score: 0.8343
5 | chrf
6 | mbr.wmt22.en-de.pairwise.n1024.epsilon0.02.seed0.chrf.de score: 0.8582
7 | mbr.wmt22.en-de.aggregate.n1024.epsilon0.02.seed0.chrf.de score: 0.8586
8 | mbr.wmt22.en-de.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf.de score: 0.8586
9 | cometinho
10 | mbr.wmt22.en-de.pairwise.n1024.epsilon0.02.seed0.cometinho.de score: 0.8636
11 | mbr.wmt22.en-de.aggregate.n1024.epsilon0.02.seed0.cometinho.de score: 0.8632
12 | mbr.wmt22.en-de.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.cometinho.de score: 0.8636
13 | comet22
14 | mbr.wmt22.en-de.pairwise.n1024.epsilon0.02.seed0.comet22.de score: 0.8840
15 | mbr.wmt22.en-de.aggregate.n1024.epsilon0.02.seed0.comet22.de score: 0.8808
16 | mbr.wmt22.en-de.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.comet22.de score: 0.8829
17 | coarse-to-fine
18 | mbr.wmt22.en-de.coarse_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.de score: 0.8704
19 | mbr.wmt22.en-de.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.de score: 0.8709
20 |
21 | de-en
22 | baselines
23 | wmt22.de-en.beam4.en score: 0.8391
24 | wmt22.de-en.epsilon0.02.seed0.en score: 0.8159
25 | chrf
26 | mbr.wmt22.de-en.pairwise.n1024.epsilon0.02.seed0.chrf.en score: 0.8410
27 | mbr.wmt22.de-en.aggregate.n1024.epsilon0.02.seed0.chrf.en score: 0.8413
28 | mbr.wmt22.de-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf.en score: 0.8412
29 | cometinho
30 | mbr.wmt22.de-en.pairwise.n1024.epsilon0.02.seed0.cometinho.en score: 0.8444
31 | mbr.wmt22.de-en.aggregate.n1024.epsilon0.02.seed0.cometinho.en score: 0.8436
32 | mbr.wmt22.de-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.cometinho.en score: 0.8442
33 | comet22
34 | mbr.wmt22.de-en.pairwise.n1024.epsilon0.02.seed0.comet22.en score: 0.8600
35 | mbr.wmt22.de-en.aggregate.n1024.epsilon0.02.seed0.comet22.en score: 0.8572
36 | mbr.wmt22.de-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.comet22.en score: 0.8591
37 | coarse-to-fine
38 | mbr.wmt22.de-en.coarse_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.en score: 0.8497
39 | mbr.wmt22.de-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.en score: 0.8500
40 |
41 | en-ru
42 | baselines
43 | wmt22.en-ru.beam4.ru score: 0.8288
44 | wmt22.en-ru.epsilon0.02.seed0.ru score: 0.8072
45 | chrf
46 | mbr.wmt22.en-ru.pairwise.n1024.epsilon0.02.seed0.chrf.ru score: 0.8354
47 | mbr.wmt22.en-ru.aggregate.n1024.epsilon0.02.seed0.chrf.ru score: 0.8345
48 | mbr.wmt22.en-ru.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf.ru score: 0.8352
49 | cometinho
50 | mbr.wmt22.en-ru.pairwise.n1024.epsilon0.02.seed0.cometinho.ru score: 0.8464
51 | mbr.wmt22.en-ru.aggregate.n1024.epsilon0.02.seed0.cometinho.ru score: 0.8440
52 | mbr.wmt22.en-ru.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.cometinho.ru score: 0.8456
53 | comet22
54 | mbr.wmt22.en-ru.pairwise.n1024.epsilon0.02.seed0.comet22.ru score: 0.8782
55 | mbr.wmt22.en-ru.aggregate.n1024.epsilon0.02.seed0.comet22.ru score: 0.8739
56 | mbr.wmt22.en-ru.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.comet22.ru score: 0.8773
57 | coarse-to-fine
58 | mbr.wmt22.en-ru.coarse_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.ru score: 0.8570
59 | mbr.wmt22.en-ru.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.ru score: 0.8573
60 |
61 | ru-en
62 | baselines
63 | wmt22.ru-en.beam4.en score: 0.8434
64 | wmt22.ru-en.epsilon0.02.seed0.en score: 0.8175
65 | chrf
66 | mbr.wmt22.ru-en.pairwise.n1024.epsilon0.02.seed0.chrf.en score: 0.8440
67 | mbr.wmt22.ru-en.aggregate.n1024.epsilon0.02.seed0.chrf.en score: 0.8437
68 | mbr.wmt22.ru-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf.en score: 0.8437
69 | cometinho
70 | mbr.wmt22.ru-en.pairwise.n1024.epsilon0.02.seed0.cometinho.en score: 0.8488
71 | mbr.wmt22.ru-en.aggregate.n1024.epsilon0.02.seed0.cometinho.en score: 0.8482
72 | mbr.wmt22.ru-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.cometinho.en score: 0.8487
73 | comet22
74 | mbr.wmt22.ru-en.pairwise.n1024.epsilon0.02.seed0.comet22.en score: 0.8623
75 | mbr.wmt22.ru-en.aggregate.n1024.epsilon0.02.seed0.comet22.en score: 0.8601
76 | mbr.wmt22.ru-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.comet22.en score: 0.8617
77 | coarse-to-fine
78 | mbr.wmt22.ru-en.coarse_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.en score: 0.8529
79 | mbr.wmt22.ru-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.en score: 0.8526
80 |
81 |
--------------------------------------------------------------------------------
/experiments/reference_aggregation/results/cometinho.log:
--------------------------------------------------------------------------------
1 | en-de
2 | baselines
3 | wmt22.en-de.beam4.de score: 0.5536
4 | wmt22.en-de.epsilon0.02.seed0.de score: 0.4571
5 | chrf
6 | mbr.wmt22.en-de.pairwise.n1024.epsilon0.02.seed0.chrf.de score: 0.5560
7 | mbr.wmt22.en-de.aggregate.n1024.epsilon0.02.seed0.chrf.de score: 0.5581
8 | mbr.wmt22.en-de.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf.de score: 0.5584
9 | cometinho
10 | mbr.wmt22.en-de.pairwise.n1024.epsilon0.02.seed0.cometinho.de score: 0.6131
11 | mbr.wmt22.en-de.aggregate.n1024.epsilon0.02.seed0.cometinho.de score: 0.6110
12 | mbr.wmt22.en-de.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.cometinho.de score: 0.6138
13 | comet22
14 | mbr.wmt22.en-de.pairwise.n1024.epsilon0.02.seed0.comet22.de score: 0.5791
15 | mbr.wmt22.en-de.aggregate.n1024.epsilon0.02.seed0.comet22.de score: 0.5783
16 | mbr.wmt22.en-de.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.comet22.de score: 0.5796
17 | coarse-to-fine
18 | mbr.wmt22.en-de.coarse_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.de score: 0.5709
19 | mbr.wmt22.en-de.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.de score: 0.5729
20 |
21 | de-en
22 | baselines
23 | wmt22.de-en.beam4.en score: 0.5280
24 | wmt22.de-en.epsilon0.02.seed0.en score: 0.4226
25 | chrf
26 | mbr.wmt22.de-en.pairwise.n1024.epsilon0.02.seed0.chrf.en score: 0.5491
27 | mbr.wmt22.de-en.aggregate.n1024.epsilon0.02.seed0.chrf.en score: 0.5500
28 | mbr.wmt22.de-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf.en score: 0.5491
29 | cometinho
30 | mbr.wmt22.de-en.pairwise.n1024.epsilon0.02.seed0.cometinho.en score: 0.5933
31 | mbr.wmt22.de-en.aggregate.n1024.epsilon0.02.seed0.cometinho.en score: 0.5878
32 | mbr.wmt22.de-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.cometinho.en score: 0.5916
33 | comet22
34 | mbr.wmt22.de-en.pairwise.n1024.epsilon0.02.seed0.comet22.en score: 0.5613
35 | mbr.wmt22.de-en.aggregate.n1024.epsilon0.02.seed0.comet22.en score: 0.5621
36 | mbr.wmt22.de-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.comet22.en score: 0.5624
37 | coarse-to-fine
38 | mbr.wmt22.de-en.coarse_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.en score: 0.5611
39 | mbr.wmt22.de-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.en score: 0.5619
40 |
41 | en-ru
42 | baselines
43 | wmt22.en-ru.beam4.ru score: 0.4989
44 | wmt22.en-ru.epsilon0.02.seed0.ru score: 0.4081
45 | chrf
46 | mbr.wmt22.en-ru.pairwise.n1024.epsilon0.02.seed0.chrf.ru score: 0.5538
47 | mbr.wmt22.en-ru.aggregate.n1024.epsilon0.02.seed0.chrf.ru score: 0.5521
48 | mbr.wmt22.en-ru.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf.ru score: 0.5541
49 | cometinho
50 | mbr.wmt22.en-ru.pairwise.n1024.epsilon0.02.seed0.cometinho.ru score: 0.6693
51 | mbr.wmt22.en-ru.aggregate.n1024.epsilon0.02.seed0.cometinho.ru score: 0.6575
52 | mbr.wmt22.en-ru.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.cometinho.ru score: 0.6669
53 | comet22
54 | mbr.wmt22.en-ru.pairwise.n1024.epsilon0.02.seed0.comet22.ru score: 0.6046
55 | mbr.wmt22.en-ru.aggregate.n1024.epsilon0.02.seed0.comet22.ru score: 0.5991
56 | mbr.wmt22.en-ru.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.comet22.ru score: 0.6030
57 | coarse-to-fine
58 | mbr.wmt22.en-ru.coarse_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.ru score: 0.5861
59 | mbr.wmt22.en-ru.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.ru score: 0.5878
60 |
61 | ru-en
62 | baselines
63 | wmt22.ru-en.beam4.en score: 0.6586
64 | wmt22.ru-en.epsilon0.02.seed0.en score: 0.5255
65 | chrf
66 | mbr.wmt22.ru-en.pairwise.n1024.epsilon0.02.seed0.chrf.en score: 0.6732
67 | mbr.wmt22.ru-en.aggregate.n1024.epsilon0.02.seed0.chrf.en score: 0.6695
68 | mbr.wmt22.ru-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf.en score: 0.6712
69 | cometinho
70 | mbr.wmt22.ru-en.pairwise.n1024.epsilon0.02.seed0.cometinho.en score: 0.7288
71 | mbr.wmt22.ru-en.aggregate.n1024.epsilon0.02.seed0.cometinho.en score: 0.7250
72 | mbr.wmt22.ru-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.cometinho.en score: 0.7286
73 | comet22
74 | mbr.wmt22.ru-en.pairwise.n1024.epsilon0.02.seed0.comet22.en score: 0.6856
75 | mbr.wmt22.ru-en.aggregate.n1024.epsilon0.02.seed0.comet22.en score: 0.6910
76 | mbr.wmt22.ru-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.comet22.en score: 0.6871
77 | coarse-to-fine
78 | mbr.wmt22.ru-en.coarse_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.en score: 0.6870
79 | mbr.wmt22.ru-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.en score: 0.6846
80 |
81 |
--------------------------------------------------------------------------------
/experiments/reference_aggregation/run_mbr.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | from pathlib import Path
4 | from typing import List, Optional
5 |
6 | import jsonlines
7 | from tqdm import tqdm
8 |
9 | from experiments.reference_aggregation.experiment_utils import Testset
10 | from experiments.reference_aggregation.mbr_utils import load_utility
11 |
12 |
13 | def main(method: str, topk: Optional[int], testset: str, language_pair: str, seed_no: int, fine_utility_name: str,
14 |          num_samples: int = 1024, epsilon_cutoff: float = 0.02, coarse_utility_name: Optional[str] = None,
15 |          limit_segments: Optional[int] = None, log_time: bool = False, out_dir: Optional[Path] = None) -> Path:
16 | if out_dir is None:
17 | out_dir = Path(__file__).parent
18 |
19 | if coarse_utility_name is None:
20 | coarse_utility_name = fine_utility_name
21 |
22 | if method in {'aggregate_to_fine', 'coarse_to_fine'}:
23 | assert topk <= num_samples
24 |
25 | dataset = Testset.from_wmt(testset, language_pair, limit_segments=limit_segments)
26 |
27 | samples_dir = out_dir / "samples"
28 | assert samples_dir.exists()
29 | samples_path = samples_dir / f"samples.{dataset}.n{num_samples}.epsilon{epsilon_cutoff}.seed{seed_no}.jsonl"
30 | assert samples_path.exists()
31 | with jsonlines.open(samples_path) as f:
32 | samples = [line["samples"] for line in f]
33 | samples = [sample[:num_samples] for sample in samples]
34 | if limit_segments is not None:
35 | samples = samples[:limit_segments]
36 |
37 | assert len(samples) == len(dataset.source_sentences)
38 | assert all(len(sample) == num_samples for sample in samples)
39 |
40 | references = samples
41 |
42 | utility = load_utility(fine_utility_name)
43 | if coarse_utility_name == fine_utility_name:
44 | coarse_utility = utility
45 | else:
46 | coarse_utility = load_utility(coarse_utility_name)
47 |
48 | translations: List[str] = []
49 |
50 | if log_time:
51 | start_time = time.time()
52 |
53 | for i in tqdm(list(range(len(dataset.source_sentences))), desc="segments"):
54 |
55 | # For COMET: compute embeddings
56 | if hasattr(coarse_utility, "compute_features"):
57 | coarse_utility.clear_features()
58 | input_sequences = {dataset.source_sentences[i]} | set(samples[i]) | set(references[i])
59 | coarse_utility.compute_features(input_sequences)
60 |
61 | if method == 'pairwise':
62 | n_by_n_ranking = utility.rank_samples_n_by_s(dataset.source_sentences[i], samples[i], references[i],
63 | s=num_samples)
64 | translation = samples[i][n_by_n_ranking[0]]
65 | elif method == 'aggregate':
66 | aggregate_ranking = utility.rank_samples_aggregate(dataset.source_sentences[i], samples[i], references[i],
67 | s=1)
68 | translation = samples[i][aggregate_ranking[0]]
69 | elif method == 'aggregate_to_fine':
70 | aggregate_ranking = coarse_utility.rank_samples_aggregate(dataset.source_sentences[i], samples[i],
71 | references[i], s=1)
72 | topk_samples = [samples[i][aggregate_ranking[j]] for j in range(topk)]
73 |
74 | if fine_utility_name != coarse_utility_name and hasattr(utility, "compute_features"):
75 | utility.clear_features()
76 | input_sequences = {dataset.source_sentences[i]} | set(topk_samples) | set(references[i])
77 | utility.compute_features(input_sequences)
78 |
79 | fine_ranking = utility.rank_samples_n_by_s(dataset.source_sentences[i], topk_samples, references[i],
80 | s=num_samples)
81 | translation = topk_samples[fine_ranking[0]]
82 | elif method == 'coarse_to_fine':
83 | coarse_ranking = coarse_utility.rank_samples_n_by_s(dataset.source_sentences[i], samples[i], references[i],
84 | s=num_samples)
85 | topk_samples = [samples[i][coarse_ranking[j]] for j in range(topk)]
86 |
87 | if fine_utility_name != coarse_utility_name and hasattr(utility, "compute_features"):
88 | utility.clear_features()
89 | input_sequences = {dataset.source_sentences[i]} | set(topk_samples) | set(references[i])
90 | utility.compute_features(input_sequences)
91 |
92 | fine_ranking = utility.rank_samples_n_by_s(dataset.source_sentences[i], topk_samples, references[i],
93 | s=num_samples)
94 | translation = topk_samples[fine_ranking[0]]
95 | else:
96 | raise ValueError(f"Unknown method: {method}")
97 | translations.append(translation)
98 |
99 | if log_time:
100 | print(f"Average time per segment: {(time.time() - start_time) / len(dataset.source_sentences):.5f} seconds")
101 |
102 | assert len(translations) == len(dataset.source_sentences)
103 |
104 | translations_dir = out_dir / "translations"
105 | translations_dir.mkdir(exist_ok=True)
106 | out_path = translations_dir / f"mbr.{dataset}.{method}{'.top' + str(topk) if method in {'aggregate_to_fine', 'coarse_to_fine'} else ''}.n{num_samples}.epsilon{epsilon_cutoff}.seed{seed_no}.{coarse_utility_name + '-to-' if coarse_utility_name != fine_utility_name else ''}{fine_utility_name}.{dataset.tgt_lang}"
107 | with open(out_path, "w") as f:
108 | for translation in translations:
109 | f.write(translation + "\n")
110 |
111 | return out_path
112 |
113 |
114 | if __name__ == '__main__':
115 | parser = argparse.ArgumentParser()
116 | parser.add_argument('--method', choices=['pairwise', 'aggregate', 'aggregate_to_fine', 'coarse_to_fine'],
117 | required=True)
118 | parser.add_argument('--topk', type=int, default=20,
119 |                         help='Number of top samples to keep in the aggregate_to_fine and coarse_to_fine methods')
120 | parser.add_argument('--testset', choices=['wmt21', 'wmt22'], required=True)
121 | parser.add_argument('--language-pair', choices=['de-en', 'en-de', 'en-ru', 'ru-en'], required=True)
122 | parser.add_argument('--seed', type=int, choices=range(10), required=True,
123 | help='Index of the random seed in the list of random seeds')
124 | parser.add_argument('--utility', choices=['chrf', 'cometinho', 'comet22'], required=True)
125 | parser.add_argument('--coarse-utility', choices=['chrf', 'cometinho', 'comet22'], default=None,
126 | help='Utility used for coarse-grained method (default: same as fine-grained)')
127 | parser.add_argument('--num-samples', type=int, default=1024)
128 | parser.add_argument('--epsilon-cutoff', type=float, default=0.02)
129 | parser.add_argument('--limit-segments', type=int, default=None,
130 | help='Limit number of segments that are processed (used for testing)')
131 | parser.add_argument('--log-time', action='store_true',
132 | help='Print average wall-clock time per segment (used for benchmarking)')
133 | args = parser.parse_args()
134 |
135 | if args.coarse_utility is None:
136 | args.coarse_utility = args.utility
137 |
138 | out_path = main(method=args.method, topk=args.topk, testset=args.testset, language_pair=args.language_pair,
139 | seed_no=args.seed, fine_utility_name=args.utility, coarse_utility_name=args.coarse_utility,
140 | num_samples=args.num_samples, epsilon_cutoff=args.epsilon_cutoff, limit_segments=args.limit_segments,
141 | log_time=args.log_time, )
142 | assert out_path.exists()
143 | print(f"Saved translations to {out_path}")
144 |
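145 | # Example invocation (a sketch, mirroring the commands in scripts/benchmark_time.sh; assumes the
146 | # corresponding samples file has already been generated in samples/):
147 | #   python -m experiments.reference_aggregation.run_mbr --method aggregate_to_fine --topk 20 \
148 | #       --testset wmt22 --language-pair en-de --seed 0 --coarse-utility chrf --utility comet22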
--------------------------------------------------------------------------------
/experiments/reference_aggregation/scripts/benchmark_time.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | num_segments_per_lp=32
4 |
5 | for lp in en-de de-en en-ru ru-en; do
6 | echo $lp
7 |
8 | # MBR with ChrF metric – standard MBR
9 | taskset --cpu-list 0-63 python -m experiments.reference_aggregation.run_mbr --method pairwise --testset wmt22 --language-pair $lp --seed 0 --utility chrf --limit-segments $num_segments_per_lp --log-time
10 | # MBR with ChrF metric – reference aggregation
11 | taskset --cpu-list 0-63 python -m experiments.reference_aggregation.run_mbr --method aggregate --testset wmt22 --language-pair $lp --seed 0 --utility chrf --limit-segments $num_segments_per_lp --log-time
12 | # MBR with ChrF metric – aggregate-to-fine MBR
13 | taskset --cpu-list 0-63 python -m experiments.reference_aggregation.run_mbr --method aggregate_to_fine --topk 20 --testset wmt22 --language-pair $lp --seed 0 --utility chrf --limit-segments $num_segments_per_lp --log-time
14 |
15 | # MBR with Cometinho metric – standard MBR
16 | taskset --cpu-list 0-63 python -m experiments.reference_aggregation.run_mbr --method pairwise --testset wmt22 --language-pair $lp --seed 0 --utility cometinho --limit-segments $num_segments_per_lp --log-time
17 | # MBR with Cometinho metric – reference aggregation
18 | taskset --cpu-list 0-63 python -m experiments.reference_aggregation.run_mbr --method aggregate --testset wmt22 --language-pair $lp --seed 0 --utility cometinho --limit-segments $num_segments_per_lp --log-time
19 | # MBR with Cometinho metric – aggregate-to-fine MBR
20 | taskset --cpu-list 0-63 python -m experiments.reference_aggregation.run_mbr --method aggregate_to_fine --topk 20 --testset wmt22 --language-pair $lp --seed 0 --utility cometinho --limit-segments $num_segments_per_lp --log-time
21 |
22 | # MBR with COMET-22 metric – standard MBR
23 | taskset --cpu-list 0-63 python -m experiments.reference_aggregation.run_mbr --method pairwise --testset wmt22 --language-pair $lp --seed 0 --utility comet22 --limit-segments $num_segments_per_lp --log-time
24 | # MBR with COMET-22 metric – reference aggregation
25 | taskset --cpu-list 0-63 python -m experiments.reference_aggregation.run_mbr --method aggregate --testset wmt22 --language-pair $lp --seed 0 --utility comet22 --limit-segments $num_segments_per_lp --log-time
26 | # MBR with COMET-22 metric – aggregate-to-fine MBR
27 | taskset --cpu-list 0-63 python -m experiments.reference_aggregation.run_mbr --method aggregate_to_fine --topk 20 --testset wmt22 --language-pair $lp --seed 0 --utility comet22 --limit-segments $num_segments_per_lp --log-time
28 |
29 | # Coarse-to-fine MBR: ChrF to COMET-22
30 | taskset --cpu-list 0-63 python -m experiments.reference_aggregation.run_mbr --method coarse_to_fine --topk 20 --testset wmt22 --language-pair $lp --seed 0 --coarse-utility chrf --utility comet22 --limit-segments $num_segments_per_lp --log-time
31 | # Aggregate-to-fine MBR: Aggregate ChrF to COMET-22
32 | taskset --cpu-list 0-63 python -m experiments.reference_aggregation.run_mbr --method aggregate_to_fine --topk 20 --testset wmt22 --language-pair $lp --seed 0 --coarse-utility chrf --utility comet22 --limit-segments $num_segments_per_lp --log-time
33 |
34 | echo
35 |
36 | done
37 |
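38 | # Note (assumption): --cpu-list 0-63 pins each run to the first 64 CPU cores of the benchmarking
39 | # machine; adjust the range to match the available hardware.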
--------------------------------------------------------------------------------
/experiments/reference_aggregation/scripts/evaluate-chrf.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd translations
4 |
5 | declare -a language_pairs=("en-de" "de-en" "en-ru" "ru-en")
6 |
7 | for lp in "${language_pairs[@]}"
8 | do
9 | echo $lp
10 |
11 | IFS='-' read -ra ADDR <<< "$lp"
12 | src=${ADDR[0]}
13 | tgt=${ADDR[1]}
14 |
15 | echo "baselines"
16 | sacrebleu wmt22.${lp}.ref.${tgt} -i wmt22.${lp}.beam4.${tgt} -m chrf -b
17 | sacrebleu wmt22.${lp}.ref.${tgt} -i wmt22.${lp}.epsilon0.02.seed0.${tgt} -m chrf -b
18 |
19 | echo "chrf"
20 | sacrebleu wmt22.${lp}.ref.${tgt} -i mbr.wmt22.${lp}.pairwise.n1024.epsilon0.02.seed0.chrf.${tgt} -m chrf -b
21 | sacrebleu wmt22.${lp}.ref.${tgt} -i mbr.wmt22.${lp}.aggregate.n1024.epsilon0.02.seed0.chrf.${tgt} -m chrf -b
22 | sacrebleu wmt22.${lp}.ref.${tgt} -i mbr.wmt22.${lp}.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf.${tgt} -m chrf -b
23 |
24 | echo "cometinho"
25 | sacrebleu wmt22.${lp}.ref.${tgt} -i mbr.wmt22.${lp}.pairwise.n1024.epsilon0.02.seed0.cometinho.${tgt} -m chrf -b
26 | sacrebleu wmt22.${lp}.ref.${tgt} -i mbr.wmt22.${lp}.aggregate.n1024.epsilon0.02.seed0.cometinho.${tgt} -m chrf -b
27 | sacrebleu wmt22.${lp}.ref.${tgt} -i mbr.wmt22.${lp}.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.cometinho.${tgt} -m chrf -b
28 |
29 | echo "comet22"
30 | sacrebleu wmt22.${lp}.ref.${tgt} -i mbr.wmt22.${lp}.pairwise.n1024.epsilon0.02.seed0.comet22.${tgt} -m chrf -b
31 | sacrebleu wmt22.${lp}.ref.${tgt} -i mbr.wmt22.${lp}.aggregate.n1024.epsilon0.02.seed0.comet22.${tgt} -m chrf -b
32 | sacrebleu wmt22.${lp}.ref.${tgt} -i mbr.wmt22.${lp}.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.comet22.${tgt} -m chrf -b
33 |
34 | echo "coarse-to-fine"
35 | sacrebleu wmt22.${lp}.ref.${tgt} -i mbr.wmt22.${lp}.coarse_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.${tgt} -m chrf -b
36 | sacrebleu wmt22.${lp}.ref.${tgt} -i mbr.wmt22.${lp}.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.${tgt} -m chrf -b
37 |
38 | echo
39 | done
40 |
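41 | # Note (assumption): run this script from experiments/reference_aggregation, after run_mbr.py, the
42 | # baselines, and scripts/save_src_and_ref.py have populated the translations/ directory.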
--------------------------------------------------------------------------------
/experiments/reference_aggregation/scripts/evaluate-comet.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | model=$1
4 | if [ "$model" != "Unbabel/wmt22-comet-da" ] && [ "$model" != "Unbabel/eamt22-cometinho-da" ] && [ "$model" != "Unbabel/XCOMET-XL" ]; then
5 | echo "Invalid model. Please choose one of Unbabel/wmt22-comet-da, Unbabel/eamt22-cometinho-da, Unbabel/XCOMET-XL"
6 | exit 1
7 | fi
8 |
9 | cd translations
10 |
11 | declare -a language_pairs=("en-de" "de-en" "en-ru" "ru-en")
12 |
13 | for lp in "${language_pairs[@]}"
14 | do
15 | echo $lp
16 |
17 | IFS='-' read -ra ADDR <<< "$lp"
18 | src=${ADDR[0]}
19 | tgt=${ADDR[1]}
20 |
21 | echo "baselines"
22 | comet-score -s wmt22.${lp}.src.${src} -r wmt22.${lp}.ref.${tgt} -t wmt22.${lp}.beam4.${tgt} --only_system --model $1
23 | comet-score -s wmt22.${lp}.src.${src} -r wmt22.${lp}.ref.${tgt} -t wmt22.${lp}.epsilon0.02.seed0.${tgt} --only_system --model $1
24 |
25 | echo "chrf"
26 | comet-score -s wmt22.${lp}.src.${src} -r wmt22.${lp}.ref.${tgt} -t mbr.wmt22.${lp}.pairwise.n1024.epsilon0.02.seed0.chrf.${tgt} --only_system --model $1
27 | comet-score -s wmt22.${lp}.src.${src} -r wmt22.${lp}.ref.${tgt} -t mbr.wmt22.${lp}.aggregate.n1024.epsilon0.02.seed0.chrf.${tgt} --only_system --model $1
28 | comet-score -s wmt22.${lp}.src.${src} -r wmt22.${lp}.ref.${tgt} -t mbr.wmt22.${lp}.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf.${tgt} --only_system --model $1
29 |
30 | echo "cometinho"
31 | comet-score -s wmt22.${lp}.src.${src} -r wmt22.${lp}.ref.${tgt} -t mbr.wmt22.${lp}.pairwise.n1024.epsilon0.02.seed0.cometinho.${tgt} --only_system --model $1
32 | comet-score -s wmt22.${lp}.src.${src} -r wmt22.${lp}.ref.${tgt} -t mbr.wmt22.${lp}.aggregate.n1024.epsilon0.02.seed0.cometinho.${tgt} --only_system --model $1
33 | comet-score -s wmt22.${lp}.src.${src} -r wmt22.${lp}.ref.${tgt} -t mbr.wmt22.${lp}.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.cometinho.${tgt} --only_system --model $1
34 |
35 | echo "comet22"
36 | comet-score -s wmt22.${lp}.src.${src} -r wmt22.${lp}.ref.${tgt} -t mbr.wmt22.${lp}.pairwise.n1024.epsilon0.02.seed0.comet22.${tgt} --only_system --model $1
37 | comet-score -s wmt22.${lp}.src.${src} -r wmt22.${lp}.ref.${tgt} -t mbr.wmt22.${lp}.aggregate.n1024.epsilon0.02.seed0.comet22.${tgt} --only_system --model $1
38 | comet-score -s wmt22.${lp}.src.${src} -r wmt22.${lp}.ref.${tgt} -t mbr.wmt22.${lp}.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.comet22.${tgt} --only_system --model $1
39 |
40 | echo "coarse-to-fine"
41 | comet-score -s wmt22.${lp}.src.${src} -r wmt22.${lp}.ref.${tgt} -t mbr.wmt22.${lp}.coarse_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.${tgt} --only_system --model $1
42 | comet-score -s wmt22.${lp}.src.${src} -r wmt22.${lp}.ref.${tgt} -t mbr.wmt22.${lp}.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.${tgt} --only_system --model $1
43 |
44 | echo
45 | done
46 |
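47 | # Example usage (a sketch; run from experiments/reference_aggregation once translations/ is populated):
48 | #   bash scripts/evaluate-comet.sh Unbabel/wmt22-comet-da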
--------------------------------------------------------------------------------
/experiments/reference_aggregation/scripts/plot_accuracy_reverse.py:
--------------------------------------------------------------------------------
1 | """
2 | Same as plot_accuracy.py, but pairs the accuracies with reversed s labels (1024->1 etc.), which makes the output easier to plot.
3 | """
4 | import argparse
5 |
6 | from experiments.reference_aggregation.plot_accuracy import main
7 |
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument('--testset', choices=['wmt21', 'wmt22'], required=True)
10 | parser.add_argument('--language-pair', choices=['de-en', 'en-de', 'en-ru', 'ru-en'], required=True)
11 | parser.add_argument('--seed', type=int, choices=range(10), required=True,
12 | help='Index of the random seed in the list of random seeds')
13 | parser.add_argument('--utility', choices=['chrf', 'cometinho', 'comet22'], required=True)
14 | parser.add_argument('--coarse-utility', choices=['chrf', 'cometinho', 'comet22'], default=None,
15 | help='Utility used for coarse-grained method (default: same as fine-grained)')
16 | parser.add_argument('--topk', type=int, default=20,
17 | help='Number of top translations that have been saved in the jsonl file')
18 | parser.add_argument('--method', choices=['n_by_s', 'aggregate'], required=True)
19 | parser.add_argument('--num-samples', type=int, default=1024)
20 | parser.add_argument('--epsilon-cutoff', type=float, default=0.02)
21 | parser.add_argument('--accuracy-topk', type=int, default=None,
22 |                     help='Number of top translations that are used to compute the accuracy (default: same as --topk)')
23 | parser.add_argument('--limit-segments', type=int, default=None,
24 | help='Limit number of segments that are processed (used for testing)')
25 | args = parser.parse_args()
26 |
27 | if args.coarse_utility is None:
28 | args.coarse_utility = args.utility
29 | if args.accuracy_topk is None:
30 | args.accuracy_topk = args.topk
31 |
32 | series = main(testset=args.testset, language_pair=args.language_pair, seed_no=args.seed, fine_utility_name=args.utility,
33 | coarse_utility_name=args.coarse_utility, topk=args.topk, method=args.method, num_samples=args.num_samples,
34 | epsilon_cutoff=args.epsilon_cutoff, accuracy_topk=args.accuracy_topk, limit_segments=args.limit_segments, )
35 |
36 | s_values = [s for s, _ in series]
37 | reversed_s_values = list(reversed(s_values))
38 | series_str = "".join(
39 |     f"({s},{accuracy:.5f})" for s, (_, accuracy) in zip(reversed_s_values, series))
40 | print(
41 | f"Testset: {args.testset}, language pair: {args.language_pair}, seed: {args.seed}, fine utility: {args.utility}, coarse utility: {args.coarse_utility}, topk: {args.topk}, method: {args.method}")
42 | print(f"Top-{args.accuracy_topk} accuracy:")
43 | print(series_str)
44 | print()
45 |
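46 | # Example invocation (a sketch; assumes the corresponding validation output has been produced):
47 | #   python -m experiments.reference_aggregation.scripts.plot_accuracy_reverse \
48 | #       --testset wmt22 --language-pair en-de --seed 0 --utility comet22 --method aggregate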
--------------------------------------------------------------------------------
/experiments/reference_aggregation/scripts/print_data_stats_table.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import jsonlines
4 |
5 | header = "\\begin{tabularx}{\\textwidth}{Xrrr}\n\\toprule\n"
6 | header += "& \\# Segments & \\# Samples per segment & \\# Unique samples per segment \\\\\n\\midrule\n"
7 | footer = "\\bottomrule\n\\end{tabularx}"
8 |
9 | samples_dir = Path(__file__).parent.parent / "samples"
10 |
11 | body = ""
12 | body += "\\textit{newstest21} & & & \\\\\n"
13 | for lang_pair in ["en-de", "de-en", "en-ru", "ru-en"]:
14 | path = samples_dir / f"samples.wmt21.{lang_pair}.n1024.epsilon0.02.seed0.jsonl"
15 | assert path.exists(), f"Path {path} does not exist"
16 | with jsonlines.open(path) as reader:
17 | data = list(reader)
18 | num_segments = len(data)
19 | num_samples = len(data[0]["samples"])
20 | avg_num_unique_samples = sum(
21 |         len(set(segment["samples"])) for segment in data) / num_segments
22 | body += "\\textsc{" + lang_pair.replace('-', '–') + "} & " + str(num_segments) + " & " + str(
23 | num_samples) + " & " + "{:.1f}".format(avg_num_unique_samples) + " \\\\\n"
24 | body += "\\addlinespace\n"
25 | body += "\\textit{newstest22} & & & \\\\\n"
26 | for lang_pair in ["en-de", "de-en", "en-ru", "ru-en"]:
27 | path = samples_dir / f"samples.wmt22.{lang_pair}.n1024.epsilon0.02.seed0.jsonl"
28 | assert path.exists(), f"Path {path} does not exist"
29 | with jsonlines.open(path) as reader:
30 | data = list(reader)
31 | num_segments = len(data)
32 | num_samples = len(data[0]["samples"])
33 | avg_num_unique_samples = sum(
34 |         len(set(segment["samples"])) for segment in data) / num_segments
35 | body += "\\textsc{" + lang_pair.replace('-', '–') + "} & " + str(num_segments) + " & " + str(
36 | num_samples) + " & " + "{:.1f}".format(avg_num_unique_samples) + " \\\\\n"
37 |
38 | print(header + body + footer)
39 |
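40 | # Running this module prints a LaTeX tabularx table (number of segments, samples per segment,
41 | # and average unique samples per segment) for newstest21 and newstest22 to stdout.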
--------------------------------------------------------------------------------
/experiments/reference_aggregation/scripts/save_src_and_ref.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 |
4 | from experiments.reference_aggregation.experiment_utils import Testset
5 |
6 |
7 | def main(testset: str, language_pair: str, limit_segments: int = None, out_dir: Path = None) -> (Path, Path):
8 | if out_dir is None:
9 | out_dir = Path(__file__).parent
10 |
11 | translations_dir = out_dir / "translations"
12 | translations_dir.mkdir(exist_ok=True)
13 |
14 | dataset = Testset.from_wmt(testset, language_pair, limit_segments=limit_segments)
15 |
16 | src_out_path = translations_dir / f"{dataset}.src.{dataset.src_lang}"
17 |
18 | with open(src_out_path, "w") as f:
19 | for src in dataset.source_sentences:
20 | f.write(src + "\n")
21 |
22 | ref_out_path = translations_dir / f"{dataset}.ref.{dataset.tgt_lang}"
23 | with open(ref_out_path, "w") as f:
24 | for ref in dataset.references:
25 | f.write(ref + "\n")
26 |
27 | return src_out_path, ref_out_path
28 |
29 |
30 | if __name__ == '__main__':
31 | parser = argparse.ArgumentParser()
32 | parser.add_argument('--testset', choices=['wmt21', 'wmt22'], required=True)
33 | parser.add_argument('--language-pair', choices=['de-en', 'en-de', 'en-ru', 'ru-en'], required=True)
34 | parser.add_argument('--limit-segments', type=int, default=None,
35 | help='Limit number of segments that are processed (used for testing)')
36 | args = parser.parse_args()
37 |
38 | src_path, ref_path = main(testset=args.testset, language_pair=args.language_pair,
39 | limit_segments=args.limit_segments, )
40 | print(f"Source sentences saved to {src_path}")
41 | print(f"References saved to {ref_path}")
42 |
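43 | # Example invocation (a sketch):
44 | #   python -m experiments.reference_aggregation.scripts.save_src_and_ref --testset wmt22 --language-pair en-de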
--------------------------------------------------------------------------------
/experiments/reference_aggregation/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/reference_aggregation/tests/__init__.py
--------------------------------------------------------------------------------
/experiments/reference_aggregation/tests/test_beam_search.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from unittest import TestCase
3 |
4 |
5 | class BeamSearchTestCase(TestCase):
6 |
7 | def setUp(self):
8 | self.testset = "wmt21"
9 | self.language_pair = "en-de"
10 | self.test_dir = Path(__file__).parent / "out"
11 | self.test_dir.mkdir(exist_ok=True)
12 |
13 | def test_beam_search(self):
14 | from experiments.reference_aggregation.baseline_beam_search import main
15 | out_path = main(self.testset, self.language_pair, limit_segments=4, out_dir=self.test_dir)
16 | self.assertTrue(out_path.exists())
17 | self.assertIn(self.test_dir, out_path.parents)
18 | self.assertTrue(out_path.name.endswith(".de"))
19 | translations = out_path.read_text().splitlines()
20 | self.assertEqual(len(translations), 4)
21 | print(translations[0])
22 |
--------------------------------------------------------------------------------
/experiments/reference_aggregation/tests/test_chrf.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | from experiments.reference_aggregation.mbr_utils import ChrfUtility
4 |
5 |
6 | class ChrfTestCase(TestCase):
7 |
8 | def setUp(self):
9 | self.chrf = ChrfUtility()
10 |
11 | def test_rank_samples_n_by_n(self):
12 | source_sequence = "This is a sample sentence"
13 | samples = ["Dies ist ein Beispiel.", "Dies ist ein Beispielsatz", "Dieser Satz macht keinen Sinn.",
14 | "Dies ist ein Test."]
15 | references = samples
16 |
17 | indices = self.chrf.rank_samples_n_by_s(source_sequence, samples, references, s=4)
18 | self.assertEqual(len(samples), len(indices))
19 | self.assertListEqual([1, 0, 3, 2], indices.tolist())
20 |
21 | # Test sample order invariance
22 | indices = self.chrf.rank_samples_n_by_s(source_sequence, samples[::-1], references, s=4)
23 | self.assertListEqual([2, 3, 0, 1], indices.tolist())
24 |
25 | # Test reference order invariance
26 | indices = self.chrf.rank_samples_n_by_s(source_sequence, samples, references[::-1], s=4)
27 | self.assertListEqual([1, 0, 3, 2], indices.tolist())
28 |
29 | def test_rank_samples_n_by_1(self):
30 | source_sequence = "This is a sample sentence"
31 | samples = ["Dies ist ein Beispiel.", "Dies ist ein Beispielsatz", "Dieser Satz macht keinen Sinn.",
32 | "Dies ist ein Test."]
33 | references = samples
34 |
35 | indices = self.chrf.rank_samples_n_by_s(source_sequence, samples, references, s=1)
36 | self.assertEqual(0, indices[0]) # Perfect match with itself
37 | self.assertListEqual([0, 1, 3, 2], indices.tolist())
38 |
39 | # Test sample order invariance
40 | indices = self.chrf.rank_samples_n_by_s(source_sequence, samples[::-1], references, s=1)
41 | self.assertListEqual([3, 2, 0, 1], indices.tolist())
42 |
43 | def test_rank_samples_aggregate(self):
44 | source_sequence = "This is a sample sentence"
45 | samples = ["Dies ist ein Beispiel.", "Dies ist ein Beispielsatz", "Dieser Satz macht keinen Sinn.",
46 | "Dies ist ein Test."]
47 | references = samples
48 |
49 | indices = self.chrf.rank_samples_aggregate(source_sequence, samples, references, s=1)
50 | self.assertEqual(len(samples), len(indices))
51 | self.assertListEqual([1, 0, 3, 2], indices.tolist())
52 |
53 | # Test sample order invariance
54 | indices = self.chrf.rank_samples_aggregate(source_sequence, samples[::-1], references, s=1)
55 | self.assertListEqual([2, 3, 0, 1], indices.tolist())
56 |
57 | # Test reference order invariance
58 | indices = self.chrf.rank_samples_aggregate(source_sequence, samples, references[::-1], s=1)
59 | self.assertListEqual([1, 0, 3, 2], indices.tolist())
60 |
61 | def test_rank_samples_aggregate_partial(self):
62 | source_sequence = "This is a sample sentence"
63 | samples = ["Dies ist ein Beispiel.", "Dies ist ein Beispielsatz", "Dieser Satz macht keinen Sinn.",
64 | "Dies ist ein Test."]
65 | references = samples
66 |
67 | indices = self.chrf.rank_samples_aggregate(source_sequence, samples, references, s=2)
68 | self.assertEqual(len(samples), len(indices))
69 | self.assertListEqual([1, 0, 3, 2], indices.tolist())
70 |
71 | # Test sample order invariance
72 | indices = self.chrf.rank_samples_aggregate(source_sequence, samples[::-1], references, s=2)
73 | self.assertListEqual([2, 3, 0, 1], indices.tolist())
74 |
75 | # Test (partial) reference order invariance: change order of references within partitions
76 | indices = self.chrf.rank_samples_aggregate(source_sequence, samples,
77 | references[:2][::-1] + references[2:][::-1], s=2)
78 | self.assertListEqual([1, 0, 3, 2], indices.tolist())
79 |
80 | def test_rank_samples_disaggregated_is_equivalent_to_n_by_n(self):
81 | source_sequence = "This is a sample sentence"
82 | samples = ["Dies ist ein Beispiel.", "Dies ist ein Beispielsatz", "Dieser Satz macht keinen Sinn.",
83 | "Dies ist ein Test."]
84 | references = samples
85 |
86 | n_by_n_indices = self.chrf.rank_samples_n_by_s(source_sequence, samples, references, s=4)
87 | aggregate_indices = self.chrf.rank_samples_aggregate(source_sequence, samples, references, s=4)
88 | self.assertListEqual(n_by_n_indices.tolist(), aggregate_indices.tolist())
89 |
--------------------------------------------------------------------------------
/experiments/reference_aggregation/tests/test_comet.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | from experiments.reference_aggregation.mbr_utils import CometUtility
4 |
5 |
6 | class CometTestCase(TestCase):
7 |
8 | def setUp(self):
9 | self.comet = CometUtility("eamt22-cometinho-da")
10 |
11 | def test_compute_features(self):
12 | self.assertEqual(0, len(self.comet.embeddings))
13 | self.comet.compute_features({"This is a test.", "Dies ist ein Test."})
14 | self.assertEqual(2, len(self.comet.embeddings))
15 | self.comet.clear_features()
16 | self.assertEqual(0, len(self.comet.embeddings))
17 |
18 | def test_rank_samples_n_by_n(self):
19 | source_sequence = "This is a sample sentence"
20 | samples = ["Dies ist ein Beispiel.", "Dies ist ein Beispielsatz", "Dieser Satz macht keinen Sinn.",
21 | "Dies ist ein Test."]
22 | references = samples
23 | self.comet.compute_features({source_sequence} | set(samples) | set(references))
24 |
25 | indices = self.comet.rank_samples_n_by_s(source_sequence, samples, references, s=4)
26 | self.assertEqual(len(samples), len(indices))
27 | self.assertListEqual([1, 0, 3, 2], indices.tolist())
28 |
29 | # Test sample order invariance
30 | indices = self.comet.rank_samples_n_by_s(source_sequence, samples[::-1], references, s=4)
31 | self.assertListEqual([2, 3, 0, 1], indices.tolist())
32 |
33 | # Test reference order invariance
34 | indices = self.comet.rank_samples_n_by_s(source_sequence, samples, references[::-1], s=4)
35 | self.assertListEqual([1, 0, 3, 2], indices.tolist())
36 |
37 | def test_rank_samples_n_by_1(self):
38 | source_sequence = "This is a sample sentence"
39 | samples = ["Dies ist ein Beispiel.", "Dies ist ein Beispielsatz", "Dieser Satz macht keinen Sinn.",
40 | "Dies ist ein Test."]
41 | references = samples
42 | self.comet.compute_features({source_sequence} | set(samples) | set(references))
43 |
44 | indices = self.comet.rank_samples_n_by_s(source_sequence, samples, references, s=1)
45 | self.assertEqual(0, indices[0]) # Perfect match with itself
46 | self.assertListEqual([0, 1, 3, 2], indices.tolist())
47 |
48 | # Test sample order invariance
49 | indices = self.comet.rank_samples_n_by_s(source_sequence, samples[::-1], references, s=1)
50 | self.assertListEqual([3, 2, 0, 1], indices.tolist())
51 |
52 | def test_rank_samples_aggregate(self):
53 | source_sequence = "This is a sample sentence"
54 | samples = ["Dies ist ein Beispiel.", "Dies ist ein Beispielsatz", "Dieser Satz macht keinen Sinn.",
55 | "Dies ist ein Test."]
56 | references = samples
57 | self.comet.compute_features({source_sequence} | set(samples) | set(references))
58 |
59 | indices = self.comet.rank_samples_aggregate(source_sequence, samples, references, s=1)
60 | self.assertEqual(len(samples), len(indices))
61 | self.assertListEqual([1, 0, 3, 2], indices.tolist())
62 |
63 | # Test sample order invariance
64 | indices = self.comet.rank_samples_aggregate(source_sequence, samples[::-1], references, s=1)
65 | self.assertListEqual([2, 3, 0, 1], indices.tolist())
66 |
67 | # Test reference order invariance
68 | indices = self.comet.rank_samples_aggregate(source_sequence, samples, references[::-1], s=1)
69 | self.assertListEqual([1, 0, 3, 2], indices.tolist())
70 |
71 | def test_rank_samples_aggregate_partial(self):
72 | source_sequence = "This is a sample sentence"
73 | samples = ["Dies ist ein Beispiel.", "Dies ist ein Beispielsatz", "Dieser Satz macht keinen Sinn.",
74 | "Dies ist ein Test."]
75 | references = samples
76 | self.comet.compute_features({source_sequence} | set(samples) | set(references))
77 |
78 | indices = self.comet.rank_samples_aggregate(source_sequence, samples, references, s=2)
79 | self.assertEqual(len(samples), len(indices))
80 | self.assertListEqual([1, 0, 3, 2], indices.tolist())
81 |
82 | # Test sample order invariance
83 | indices = self.comet.rank_samples_aggregate(source_sequence, samples[::-1], references, s=2)
84 | self.assertListEqual([2, 3, 0, 1], indices.tolist())
85 |
86 | # Test (partial) reference order invariance: change order of references within partitions
87 | indices = self.comet.rank_samples_aggregate(source_sequence, samples,
88 | references[:2][::-1] + references[2:][::-1], s=2)
89 | self.assertListEqual([1, 0, 3, 2], indices.tolist())
90 |
91 | def test_rank_samples_disaggregated_is_equivalent_to_n_by_n(self):
92 | source_sequence = "This is a sample sentence"
93 | samples = ["Dies ist ein Beispiel.", "Dies ist ein Beispielsatz", "Dieser Satz macht keinen Sinn.",
94 | "Dies ist ein Test."]
95 | references = samples
96 | self.comet.compute_features({source_sequence} | set(samples) | set(references))
97 |
98 | n_by_n_indices = self.comet.rank_samples_n_by_s(source_sequence, samples, references, s=4)
99 | aggregate_indices = self.comet.rank_samples_aggregate(source_sequence, samples, references, s=4)
100 | self.assertListEqual(n_by_n_indices.tolist(), aggregate_indices.tolist())
101 |
--------------------------------------------------------------------------------
/experiments/reference_aggregation/tests/test_epsilon_sampling.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from unittest import TestCase
3 |
4 | from experiments.reference_aggregation.fairseq_utils import load_model
5 |
6 |
7 | class EpsilonSamplingTestCase(TestCase):
8 |
9 | def setUp(self):
10 | self.testset = "wmt21"
11 | self.language_pair = "en-de"
12 | self.test_dir = Path(__file__).parent / "out"
13 | self.test_dir.mkdir(exist_ok=True)
14 |
15 | def test_epsilon_sampling(self):
16 | model = load_model(self.language_pair)
17 | source_sentence = "This is a test."
18 | num_samples = 4
19 | # ε=0.02
20 | samples = model.sample(num_samples * [source_sentence], seed=42, sampling_epsilon_cutoff=0.02)
21 | self.assertEqual(len(samples), num_samples)
22 | self.assertIsInstance(samples[0], str)
23 | print(samples[0])
24 | # ε=0
25 | samples = model.sample(num_samples * [source_sentence], seed=42, sampling_epsilon_cutoff=0)
26 | self.assertEqual(len(samples), num_samples)
27 | self.assertIsInstance(samples[0], str)
28 |
29 | def test_extract_translations(self):
30 | # Generate samples
31 | from experiments.reference_aggregation.generate_samples import main as generate_samples
32 | jsonl_path = generate_samples(self.testset, self.language_pair, seed_no=0, num_samples=8, epsilon_cutoff=0.02,
33 | limit_segments=4, out_dir=self.test_dir)
34 | self.assertTrue(jsonl_path.exists())
35 | # Extract
36 | from experiments.reference_aggregation.baseline_epsilon_sampling import main
37 | out_path = main(self.testset, self.language_pair, num_samples=8, epsilon_cutoff=0.02, seed_no=0,
38 | out_dir=self.test_dir)
39 | self.assertTrue(out_path.exists())
40 | self.assertIn(self.test_dir, out_path.parents)
41 | self.assertTrue(out_path.name.endswith(".de"))
42 | translations = out_path.read_text().splitlines()
43 | self.assertEqual(len(translations), 4)
44 | print(translations[0])
45 |
--------------------------------------------------------------------------------
/experiments/reference_aggregation/tests/test_generate_samples.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from unittest import TestCase
3 |
4 | import jsonlines
5 |
6 |
7 | class GenerateSamplesTestCase(TestCase):
8 |
9 | def setUp(self):
10 | self.testset = "wmt21"
11 | self.language_pair = "en-de"
12 | self.test_dir = Path(__file__).parent / "out"
13 | self.test_dir.mkdir(exist_ok=True)
14 |
15 | def test_generate_samples(self):
16 | from experiments.reference_aggregation.generate_samples import main
17 | out_path = main(self.testset, self.language_pair, seed_no=0, num_samples=8, epsilon_cutoff=0.02,
18 | limit_segments=4, out_dir=self.test_dir)
19 | self.assertTrue(out_path.exists())
20 | self.assertIn(self.test_dir, out_path.parents)
21 | self.assertTrue(out_path.name.endswith(".jsonl"))
22 | with jsonlines.open(out_path) as f:
23 | data = list(f)
24 | self.assertEqual(len(data), 4)
25 | for line in data:
26 | self.assertIn("samples", line)
27 | self.assertEqual(len(line["samples"]), 8)
28 | self.assertIsInstance(line["samples"][0], str)
29 | print(data[0]["samples"][0])
30 |
--------------------------------------------------------------------------------
/experiments/reference_aggregation/tests/test_run_mbr.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from unittest import TestCase
3 |
4 | from experiments.reference_aggregation.run_mbr import main
5 |
6 |
7 | class MBRTestCase(TestCase):
8 |
9 | def setUp(self):
10 | self.testset = "wmt21"
11 | self.language_pair = "en-de"
12 | self.test_dir = Path(__file__).parent / "out"
13 | self.test_dir.mkdir(exist_ok=True)
14 |
15 | def test_run_mbr_pairwise_chrf(self):
16 | out_path = main(method="pairwise", topk=None, testset=self.testset, language_pair=self.language_pair, seed_no=0,
17 | fine_utility_name="chrf", num_samples=8, epsilon_cutoff=0.02, limit_segments=4,
18 | out_dir=self.test_dir)
19 | self.assertTrue(out_path.exists())
20 | self.assertIn(self.test_dir, out_path.parents)
21 | self.assertTrue(out_path.name.endswith(".de"))
22 | translations = out_path.read_text().splitlines()
23 | self.assertEqual(len(translations), 4)
24 | print(translations[0])
25 |
26 | def test_run_mbr_aggregate_chrf(self):
27 | out_path = main(method="aggregate", topk=None, testset=self.testset, language_pair=self.language_pair,
28 | seed_no=0, fine_utility_name="chrf", num_samples=8, epsilon_cutoff=0.02, limit_segments=4,
29 | out_dir=self.test_dir)
30 | translations = out_path.read_text().splitlines()
31 | self.assertEqual(len(translations), 4)
32 | print(translations[0])
33 |
34 | def test_run_mbr_aggregate_to_fine_chrf(self):
35 | out_path = main(method="aggregate_to_fine", topk=2, testset=self.testset, language_pair=self.language_pair,
36 | seed_no=0, fine_utility_name="chrf", num_samples=8, epsilon_cutoff=0.02, limit_segments=4,
37 | out_dir=self.test_dir)
38 | translations = out_path.read_text().splitlines()
39 | self.assertEqual(len(translations), 4)
40 | print(translations[0])
41 |
42 |     def test_run_mbr_coarse_to_fine_chrf_to_cometinho(self):
43 | out_path = main(method="coarse_to_fine", topk=2, testset=self.testset, language_pair=self.language_pair,
44 | seed_no=0, coarse_utility_name="chrf", fine_utility_name="cometinho", num_samples=8,
45 | epsilon_cutoff=0.02, limit_segments=4, out_dir=self.test_dir)
46 | translations = out_path.read_text().splitlines()
47 | self.assertEqual(len(translations), 4)
48 | print(translations[0])
49 |
50 |     def test_run_mbr_aggregate_to_fine_chrf_to_cometinho(self):
51 | out_path = main(method="aggregate_to_fine", topk=2, testset=self.testset, language_pair=self.language_pair,
52 | seed_no=0, coarse_utility_name="chrf", fine_utility_name="cometinho", num_samples=8,
53 | epsilon_cutoff=0.02, limit_segments=4, out_dir=self.test_dir)
54 | translations = out_path.read_text().splitlines()
55 | self.assertEqual(len(translations), 4)
56 | print(translations[0])
57 |
--------------------------------------------------------------------------------
/experiments/reference_aggregation/tests/test_save_src_and_ref.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from unittest import TestCase
3 |
4 |
5 | class SaveSrcAndRefTestCase(TestCase):
6 |
7 | def setUp(self):
8 | self.testset = "wmt21"
9 | self.language_pair = "en-de"
10 | self.test_dir = Path(__file__).parent / "out"
11 | self.test_dir.mkdir(exist_ok=True)
12 |
13 | def test_save_src_and_ref(self):
14 | from experiments.reference_aggregation.scripts.save_src_and_ref import main
15 | src_path, ref_path = main(self.testset, self.language_pair, limit_segments=4, out_dir=self.test_dir)
16 | self.assertTrue(src_path.exists())
17 | self.assertIn(self.test_dir, src_path.parents)
18 | self.assertTrue(src_path.name.endswith(".en"))
19 | self.assertTrue(ref_path.exists())
20 | self.assertIn(self.test_dir, ref_path.parents)
21 | self.assertTrue(ref_path.name.endswith(".de"))
22 | source_sentences = src_path.read_text().splitlines()
23 | self.assertEqual(len(source_sentences), 4)
24 | print(source_sentences[0])
25 | reference_sentences = ref_path.read_text().splitlines()
26 | self.assertEqual(len(reference_sentences), 4)
27 | print(reference_sentences[0])
28 |
--------------------------------------------------------------------------------
/experiments/reference_aggregation/tests/test_testset.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | from experiments.reference_aggregation.experiment_utils import Testset
4 |
5 |
6 | class TestsetTestCase(TestCase):
7 |
8 | def setUp(self):
9 | self.testsets = ["wmt21", "wmt22"]
10 | self.language_pairs = ["en-de", "de-en", "en-ru", "ru-en"]
11 |
12 | def test_load_testsets(self):
13 | for testset in self.testsets:
14 | for language_pair in self.language_pairs:
15 | data = Testset.from_wmt(testset, language_pair)
16 | self.assertEqual(language_pair, f"{data.src_lang}-{data.tgt_lang}")
17 | self.assertEqual(len(data.source_sentences), len(data.references))
18 |
--------------------------------------------------------------------------------
/experiments/reference_aggregation/tests/test_validation.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from pathlib import Path
3 | from unittest import TestCase
4 |
5 | import jsonlines
6 |
7 | logging.basicConfig(level=logging.INFO)
8 |
9 |
10 | class ValidationTestCase(TestCase):
11 |
12 | def setUp(self):
13 | self.testset = "wmt21"
14 | self.language_pair = "en-de"
15 | self.test_dir = Path(__file__).parent / "out"
16 | self.test_dir.mkdir(exist_ok=True)
17 |
18 | def test_run_validation_cometinho(self):
19 | from experiments.reference_aggregation.validation import main
20 | jsonl_path = main(self.testset, self.language_pair, seed_no=0, utility_name="cometinho", topk=4, num_samples=8,
21 | limit_segments=4, out_dir=self.test_dir)
22 | self.assertTrue(jsonl_path.exists())
23 | with jsonlines.open(jsonl_path) as f:
24 | data = list(f)
25 |
26 | n_by_s_lines = [line for line in data if line["method"] == "n_by_s"]
27 | s_values = [line["s"] for line in n_by_s_lines]
28 | self.assertEqual([8, 4, 2, 1], s_values)
29 |
30 | aggregate_lines = [line for line in data if line["method"] == "aggregate"]
31 | s_values = [line["s"] for line in aggregate_lines]
32 | self.assertEqual([8, 4, 2, 1], s_values)
33 |
34 | for line in data:
35 | self.assertEqual(4, len(line["rankings"]))
36 | self.assertEqual(4, len(line["rankings"][0]))
37 | self.assertEqual(4, len(set(line["rankings"][0])))
38 |
39 | test_translation_paths = [
40 | self.test_dir / "translations" / f"validation.{self.testset}.{self.language_pair}.n8.epsilon0.02.seed0.cometinho.n_by_s.s8.{self.language_pair.split('-')[1]}",
41 | self.test_dir / "translations" / f"validation.{self.testset}.{self.language_pair}.n8.epsilon0.02.seed0.cometinho.n_by_s.s4.{self.language_pair.split('-')[1]}",
42 | self.test_dir / "translations" / f"validation.{self.testset}.{self.language_pair}.n8.epsilon0.02.seed0.cometinho.n_by_s.s2.{self.language_pair.split('-')[1]}",
43 | self.test_dir / "translations" / f"validation.{self.testset}.{self.language_pair}.n8.epsilon0.02.seed0.cometinho.n_by_s.s1.{self.language_pair.split('-')[1]}",
44 | self.test_dir / "translations" / f"validation.{self.testset}.{self.language_pair}.n8.epsilon0.02.seed0.cometinho.aggregate.s8.{self.language_pair.split('-')[1]}",
45 | self.test_dir / "translations" / f"validation.{self.testset}.{self.language_pair}.n8.epsilon0.02.seed0.cometinho.aggregate.s4.{self.language_pair.split('-')[1]}",
46 | self.test_dir / "translations" / f"validation.{self.testset}.{self.language_pair}.n8.epsilon0.02.seed0.cometinho.aggregate.s2.{self.language_pair.split('-')[1]}",
47 | self.test_dir / "translations" / f"validation.{self.testset}.{self.language_pair}.n8.epsilon0.02.seed0.cometinho.aggregate.s1.{self.language_pair.split('-')[1]}", ]
48 | for translation_path in test_translation_paths:
49 | self.assertTrue(translation_path.exists())
50 | self.assertIn(self.test_dir, translation_path.parents)
51 | self.assertTrue(translation_path.name.endswith(".de"))
52 | translations = translation_path.read_text().splitlines()
53 | self.assertEqual(len(translations), 4)
54 |
55 | def test_run_validation_topk_first_identical(self):
56 | """
57 |         The top translation should be the same for any value of k.
58 | """
59 | from experiments.reference_aggregation.validation import main
60 |
61 | # cometinho
62 | top1_jsonl_path = main(self.testset, self.language_pair, seed_no=0, utility_name="cometinho", topk=1,
63 | num_samples=8, limit_segments=1, out_dir=self.test_dir)
64 | with jsonlines.open(top1_jsonl_path) as f:
65 | top1_best_indices = [line["rankings"][0][0] for line in f]
66 | for k in [1, 2, 4, 8]:
67 | jsonl_path = main(self.testset, self.language_pair, seed_no=0, utility_name="cometinho", topk=k,
68 | num_samples=8, limit_segments=1, out_dir=self.test_dir)
69 | with jsonlines.open(jsonl_path) as f:
70 | best_indices = [line["rankings"][0][0] for line in f]
71 | self.assertEqual(top1_best_indices, best_indices)
72 |
73 | # chrf
74 | top1_jsonl_path = main(self.testset, self.language_pair, seed_no=0, utility_name="chrf", topk=1, num_samples=8,
75 | limit_segments=1, out_dir=self.test_dir)
76 | with jsonlines.open(top1_jsonl_path) as f:
77 | top1_best_indices = [line["rankings"][0][0] for line in f]
78 | for k in [1, 2, 4, 8]:
79 | jsonl_path = main(self.testset, self.language_pair, seed_no=0, utility_name="chrf", topk=k, num_samples=8,
80 | limit_segments=1, out_dir=self.test_dir)
81 | with jsonlines.open(jsonl_path) as f:
82 | best_indices = [line["rankings"][0][0] for line in f]
83 | self.assertEqual(top1_best_indices, best_indices)
84 |
85 | def test_plot_accuracy(self):
86 | # Run validation.py
87 | from experiments.reference_aggregation.validation import main as validation
88 | jsonl_path = validation(self.testset, self.language_pair, seed_no=0, utility_name="cometinho", topk=8,
89 | num_samples=8, limit_segments=4, out_dir=self.test_dir)
90 | self.assertTrue(jsonl_path.exists())
91 |
92 | # Top-8
93 | from experiments.reference_aggregation.plot_accuracy import main as plot_accuracy
94 | series_n_by_s_top8 = plot_accuracy(self.testset, self.language_pair, seed_no=0, fine_utility_name="cometinho",
95 | topk=8, method="n_by_s", num_samples=8, accuracy_topk=8, limit_segments=4,
96 | out_dir=self.test_dir)
97 | series_aggregate_top8 = plot_accuracy(self.testset, self.language_pair, seed_no=0,
98 | fine_utility_name="cometinho", topk=8, method="aggregate", num_samples=8,
99 | accuracy_topk=8, limit_segments=4, out_dir=self.test_dir)
100 |
101 | # Top-1
102 | series_n_by_s_top1 = plot_accuracy(self.testset, self.language_pair, seed_no=0, fine_utility_name="cometinho",
103 | topk=8, accuracy_topk=1, method="n_by_s", num_samples=8, limit_segments=4,
104 | out_dir=self.test_dir)
105 | series_aggregate_top1 = plot_accuracy(self.testset, self.language_pair, seed_no=0,
106 | fine_utility_name="cometinho", topk=8, accuracy_topk=1,
107 | method="aggregate", num_samples=8, limit_segments=4,
108 | out_dir=self.test_dir)
109 |
110 | # Assert that top-1 accuracy <= top-8 accuracy
111 | for (s, accuracy_top8), (_, accuracy_top1) in zip(series_n_by_s_top8, series_n_by_s_top1):
112 | self.assertLessEqual(accuracy_top1, accuracy_top8)
113 | for (s, accuracy_top8), (_, accuracy_top1) in zip(series_aggregate_top8, series_aggregate_top1):
114 | self.assertLessEqual(accuracy_top1, accuracy_top8)
115 |
116 | # Top-4
117 | series_n_by_s_top4 = plot_accuracy(self.testset, self.language_pair, seed_no=0, fine_utility_name="cometinho",
118 | topk=8, accuracy_topk=4, method="n_by_s", num_samples=8, limit_segments=4,
119 | out_dir=self.test_dir)
120 | series_aggregate_top4 = plot_accuracy(self.testset, self.language_pair, seed_no=0,
121 | fine_utility_name="cometinho", topk=8, accuracy_topk=4,
122 | method="aggregate", num_samples=8, limit_segments=4,
123 | out_dir=self.test_dir)
124 |
125 | # Assert that top-4 accuracy <= top-8 accuracy
126 | for (s, accuracy_top8), (_, accuracy_top4) in zip(series_n_by_s_top8, series_n_by_s_top4):
127 | self.assertLessEqual(accuracy_top4, accuracy_top8)
128 | for (s, accuracy_top8), (_, accuracy_top4) in zip(series_aggregate_top8, series_aggregate_top4):
129 | self.assertLessEqual(accuracy_top4, accuracy_top8)
130 |
131 | # Assert that top-1 accuracy <= top-4 accuracy
132 | for (s, accuracy_top4), (_, accuracy_top1) in zip(series_n_by_s_top4, series_n_by_s_top1):
133 | self.assertLessEqual(accuracy_top1, accuracy_top4)
134 | for (s, accuracy_top4), (_, accuracy_top1) in zip(series_aggregate_top4, series_aggregate_top1):
135 | self.assertLessEqual(accuracy_top1, accuracy_top4)
136 |
--------------------------------------------------------------------------------
/experiments/reference_aggregation/validation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 | from pathlib import Path
4 | from typing import List, Optional
5 |
6 | import jsonlines
7 | from tqdm import tqdm
8 |
9 | from experiments.reference_aggregation.experiment_utils import Testset
10 | from experiments.reference_aggregation.mbr_utils import load_utility
11 |
12 |
13 | def main(testset: str, language_pair: str, seed_no: int, utility_name: str, chrf_eps_smoothing: bool = False,
14 |          topk: int = 20, num_samples: int = 1024, epsilon_cutoff: float = 0.02, limit_segments: Optional[int] = None,
15 |          out_dir: Optional[Path] = None) -> Path:
16 | if out_dir is None:
17 | out_dir = Path(__file__).parent
18 |
19 | assert topk <= num_samples
20 |
21 | dataset = Testset.from_wmt(testset, language_pair, limit_segments=limit_segments)
22 |
23 | samples_dir = out_dir / "samples"
24 | assert samples_dir.exists()
25 | samples_path = samples_dir / f"samples.{dataset}.n{num_samples}.epsilon{epsilon_cutoff}.seed{seed_no}.jsonl"
26 | assert samples_path.exists()
27 | with jsonlines.open(samples_path) as f:
28 | samples = [line["samples"] for line in f]
29 | samples = [sample[:num_samples] for sample in samples]
30 | if limit_segments is not None:
31 | samples = samples[:limit_segments]
32 |
33 | assert len(samples) == len(dataset.source_sentences)
34 | assert all(len(sample) == num_samples for sample in samples)
35 |
36 | references = samples
37 |
38 | # s = n/1, n/2, n/4, n/8, ..., n/n
39 | s_values = [int(num_samples / 2 ** i) for i in range(int(math.log2(num_samples)) + 1)]
40 | assert s_values[0] == num_samples
41 | assert s_values[-1] == 1
42 |
43 | utility = load_utility(utility_name)
44 |
45 | if utility_name == "chrf" and chrf_eps_smoothing:
46 | utility.eps_smoothing = True
47 |
48 | # Compute rankings for n-by-s and aggregate, for each s
49 | n_by_s_rankings: List[List[List[int]]] = [] # segments x s_values x topk
50 | aggregate_rankings: List[List[List[int]]] = [] # segments x s_values x topk
51 | for i in tqdm(list(range(len(dataset.source_sentences))), desc="segments"):
52 |
53 | # For COMET: compute embeddings
54 | if hasattr(utility, "compute_features"):
55 | utility.clear_features()
56 | input_sequences = {dataset.source_sentences[i]} | set(samples[i]) | set(references[i])
57 | utility.compute_features(input_sequences)
58 |
59 | n_by_s_rankings.append([])
60 | for s in s_values:
61 | n_by_s_ranking = utility.rank_samples_n_by_s(dataset.source_sentences[i], samples[i], references[i], s=s)
62 | n_by_s_ranking = n_by_s_ranking[:topk]
63 | n_by_s_rankings[-1].append(n_by_s_ranking.tolist())
64 | aggregate_rankings.append([])
65 | for s in s_values:
66 | aggregate_ranking = utility.rank_samples_aggregate(dataset.source_sentences[i], samples[i], references[i],
67 | s=s)
68 | aggregate_ranking = aggregate_ranking[:topk]
69 | aggregate_rankings[-1].append(aggregate_ranking.tolist())
70 |
71 | # Save top-k rankings to jsonl file
72 | output_dir = out_dir / "validation_output"
73 | output_dir.mkdir(exist_ok=True)
74 | output_path = output_dir / f"validation.{dataset}.n{num_samples}.epsilon{epsilon_cutoff}.seed{seed_no}.{utility_name}{'-eps' if chrf_eps_smoothing else ''}.top{topk}.jsonl"
75 | with jsonlines.open(output_path, mode="w") as f:
76 | for i, s in enumerate(s_values):
77 | f.write({"method": "n_by_s", "s": s, "rankings": [ranking[i] for ranking in n_by_s_rankings]})
78 | for i, s in enumerate(s_values):
79 | f.write({"method": "aggregate", "s": s, "rankings": [ranking[i] for ranking in aggregate_rankings]})
80 |
81 | translations_dir = out_dir / "translations"
82 | translations_dir.mkdir(exist_ok=True)
83 | translations_prefix = f"validation.{dataset}.n{num_samples}.epsilon{epsilon_cutoff}.seed{seed_no}.{utility_name}{'-eps' if chrf_eps_smoothing else ''}"
84 |
85 | # Save top-1 translations for n-by-s
86 | for j, s in enumerate(s_values):
87 | n_by_s_translations_path = translations_dir / f"{translations_prefix}.n_by_s.s{s}.{dataset.tgt_lang}"
88 | with open(n_by_s_translations_path, "w") as f:
89 | for i, rankings in enumerate(n_by_s_rankings):
90 | ranking = rankings[j]
91 | f.write(samples[i][ranking[0]] + "\n")
92 |
93 | # Save top-1 translations for aggregate
94 | for j, s in enumerate(s_values):
95 | aggregate_translations_path = translations_dir / f"{translations_prefix}.aggregate.s{s}.{dataset.tgt_lang}"
96 | with open(aggregate_translations_path, "w") as f:
97 | for i, rankings in enumerate(aggregate_rankings):
98 | ranking = rankings[j]
99 | f.write(samples[i][ranking[0]] + "\n")
100 |
101 | return output_path
102 |
103 |
104 | if __name__ == '__main__':
105 | parser = argparse.ArgumentParser()
106 | parser.add_argument('--testset', choices=['wmt21', 'wmt22'], required=True)
107 | parser.add_argument('--language-pair', choices=['de-en', 'en-de', 'en-ru', 'ru-en'], required=True)
108 | parser.add_argument('--seed', type=int, choices=range(10), required=True,
109 | help='Index of the random seed in the list of random seeds')
110 | parser.add_argument('--utility', choices=['chrf', 'cometinho', 'comet22'], required=True)
111 | parser.add_argument('--chrf-eps-smoothing', action='store_true',
112 | help='Use epsilon smoothing for ChrF (default: False = effective order smoothing)')
113 | parser.add_argument('--topk', type=int, default=20, help='Number of top translations to save in the jsonl file')
114 | parser.add_argument('--num-samples', type=int, default=1024)
115 | parser.add_argument('--epsilon-cutoff', type=float, default=0.02)
116 | parser.add_argument('--limit-segments', type=int, default=None,
117 | help='Limit number of segments that are processed (used for testing)')
118 | args = parser.parse_args()
119 |
120 | jsonl_path = main(testset=args.testset, language_pair=args.language_pair, seed_no=args.seed,
121 | utility_name=args.utility, chrf_eps_smoothing=args.chrf_eps_smoothing, topk=args.topk,
122 | num_samples=args.num_samples, epsilon_cutoff=args.epsilon_cutoff,
123 | limit_segments=args.limit_segments, )
124 | print(f"Saved results file to {jsonl_path}")
125 |
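126 | # Example invocation (a sketch; assumes the corresponding samples file exists in samples/):
127 | #   python -m experiments.reference_aggregation.validation --testset wmt21 --language-pair en-de \
128 | #       --seed 0 --utility chrf --topk 20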
--------------------------------------------------------------------------------
/experiments/requirements.txt:
--------------------------------------------------------------------------------
1 | jsonlines==4.0.0
2 | datasets==2.14.6
3 | sacrebleu==2.3.1
4 | sacremoses==0.0.53 # For OpusMT
5 | nltk==3.8.1
6 | rouge_score==0.1.2
7 |
--------------------------------------------------------------------------------
/minimum-bayes-risk-decoding.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/minimum-bayes-risk-decoding.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "mbr"
3 | version = "0.2.0"
4 | authors = [
5 | { name="Jannis Vamvas", email="vamvas@cl.uzh.ch" },
6 | ]
7 | description = "Minimum Bayes risk decoding for Hugging Face Transformers"
8 | readme = "README.md"
9 | requires-python = ">=3.9"
10 | dependencies = [
11 | "transformers<4.39",
12 | "evaluate",
13 | "cachetools",
14 | "tqdm",
15 | "fastchrf",
16 | ]
17 | classifiers = [
18 | "Programming Language :: Python :: 3",
19 | "License :: OSI Approved :: MIT License",
20 | "Operating System :: OS Independent",
21 | ]
22 |
23 | [project.urls]
24 | "Homepage" = "https://github.com/ZurichNLP/mbr"
25 | "Bug Tracker" = "https://github.com/ZurichNLP/mbr/issues"
26 | [build-system]
27 | requires = ["hatchling"]
28 | build-backend = "hatchling.build"
29 |
--------------------------------------------------------------------------------
/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | sacrebleu==2.4.0
2 | unbabel-comet==2.2.1
3 | git+https://github.com/google-research/bleurt.git
4 | sentencepiece==0.1.99 # M2M100 model
5 |
--------------------------------------------------------------------------------
/requirements-test.txt:
--------------------------------------------------------------------------------
1 | sacrebleu==2.4.0
2 | unbabel-comet==2.2.1
3 |
--------------------------------------------------------------------------------
/src/mbr/__init__.py:
--------------------------------------------------------------------------------
1 | from mbr.generation.configuration_utils import MBRConfig
2 | from mbr.generation.utils import MBROutput, MBRGenerationMixin
3 | from mbr.metrics.base import MetricOutput, MetricRunner
4 | from mbr.modeling import MBR
5 |
6 |
7 | __version__ = "0.2.0"
8 |
--------------------------------------------------------------------------------
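Taken together, these exports form the public API of the package. A minimal end-to-end sketch, assuming (as the pipeline tests below suggest) that `generate()` accepts `mbr_config` and `tokenizer` keyword arguments:

```python
# Minimal sketch of the public API: wrap a model class, configure MBR, generate.
# Assumes generate() forwards `mbr_config` and `tokenizer` to the MBR machinery,
# as the pipeline tests in this repository suggest.
from transformers import AutoTokenizer, GPT2LMHeadModel

from mbr import MBR, MBRConfig

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
tokenizer.pad_token = tokenizer.eos_token
model = MBR(GPT2LMHeadModel).from_pretrained("distilgpt2").eval()

mbr_config = MBRConfig(num_samples=5, metric="fastchrf")
inputs = tokenizer("Hello,", return_tensors="pt")
output_ids = model.generate(
    **inputs,
    do_sample=True,
    num_beams=1,
    max_new_tokens=20,
    mbr_config=mbr_config,
    tokenizer=tokenizer,
)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
```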
/src/mbr/generation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/src/mbr/generation/__init__.py
--------------------------------------------------------------------------------
/src/mbr/generation/configuration_utils.py:
--------------------------------------------------------------------------------
1 | from transformers import __version__ as transformers_version
2 | from transformers.utils import logging
3 |
4 | logger = logging.get_logger(__name__)
5 |
6 |
7 | class MBRConfig:
8 | r"""
9 | Class that holds a configuration for minimum Bayes risk decoding (MBR). Pass this config when calling
10 | `MBRGenerationMixin.generate()`:
11 |
12 | Example:
13 |
14 | ```python
15 | >>> config = MBRConfig(num_samples=10, num_references=10, metric="fastchrf")
16 | >>> model.generate(..., mbr_config=config)
17 | ```
18 |
19 | The class is inspired by `transformers.GenerationConfig`.
20 | Note that `MBRConfig` does not control the sampling strategy. Pass separate `GenerationConfig` objects to control
21 | sampling:
22 |
23 | ```python
24 | >>> generation_config = GenerationConfig(do_sample=True, num_beams=1, top_p=0.9)
25 | >>> references_config = GenerationConfig(do_sample=True, num_beams=1, epsilon_cutoff=0.02)
26 | >>> model.generate(..., mbr_config=config, generation_config=generation_config, references_config=references_config)
27 | ```
28 |
29 |     Args:
30 | num_samples (`int`, *optional*, defaults to 10):
31 | Number of samples generated. 1 means no MBR decoding.
32 | num_references (`int`, *optional*, defaults to `num_samples`):
33 | Number of pseudo-references used for MBR decoding.
34 | metric (`str` or `~evaluate.Metric`, *optional*, defaults to 'fastchrf'):
35 | Metric used for MBR decoding.
36 | metric_config_name (`str`, *optional*, defaults to None):
37 | Metric configuration to pass to `evaluate.load` (e.g., the model for a trained metric, such as
38 | "eamt22-cometinho-da"). If not specified, the default configuration is used.
39 | metric_output_field (`str`, *optional*, defaults to 'score'):
40 |             Field of the metric output that is used as the metric score.
41 | metric_kwargs (optional):
42 |             Additional arguments for the metric's `compute` method. The default `MetricRunner` requires them to be hashable.
43 | metric_cache_size (`int`, *optional*, defaults to `num_samples` * `num_references`):
44 | Size of the cache for the metric. Set to `None` to disable caching (not recommended).
45 | lower_is_better (`bool`, *optional*, defaults to `False`):
46 | Set to true if lower metric scores are better (e.g., perplexity).
47 |
48 | > Parameters that define the output variables of `generate`
49 |
50 | output_attentions (`bool`, *optional*, defaults to `False`):
51 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
52 | tensors for more details.
53 | output_hidden_states (`bool`, *optional*, defaults to `False`):
54 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
55 | more details.
56 | output_scores (`bool`, *optional*, defaults to `False`):
57 | Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
58 | output_all_samples (`bool`, *optional*, defaults to `False`):
59 | Whether or not to return all sampled sequences. See `all_sampled_sequences` under returned tensors for more
60 | details.
61 | output_reference_sequences (`bool`, *optional*, defaults to `False`):
62 | Whether or not to return the reference sequences. See `reference_sequences` under returned tensors for more
63 | details.
64 | output_metric_scores (`bool`, *optional*, defaults to `False`):
65 | Whether or not to return the metric scores. See `metric_scores` under returned tensors for more details.
66 | return_dict_in_generate (`bool`, *optional*, defaults to `False`):
67 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
68 | """
69 |
70 | def __init__(self, **kwargs):
71 | # Parameters that control the generation strategy used
72 | self.num_samples = kwargs.pop("num_samples", 10)
73 | self.num_references = kwargs.pop("num_references", self.num_samples)
74 | self.metric = kwargs.pop("metric", "fastchrf")
75 | self.metric_config_name = kwargs.pop("metric_config_name", None)
76 | self.metric_output_field = kwargs.pop("metric_output_field", "score")
77 | self.metric_kwargs = kwargs.pop("metric_kwargs", {})
78 | self.metric_cache_size = kwargs.pop("metric_cache_size", self.num_samples * self.num_references)
79 | self.lower_is_better = kwargs.pop("lower_is_better", False)
80 |
81 | # Parameters that define the output variables of `generate`
82 | self.output_attentions = kwargs.pop("output_attentions", False)
83 | self.output_hidden_states = kwargs.pop("output_hidden_states", False)
84 | self.output_scores = kwargs.pop("output_scores", False)
85 | self.output_all_samples = kwargs.pop("output_all_samples", False)
86 | self.output_reference_sequences = kwargs.pop("output_reference_sequences", False)
87 | self.output_metric_scores = kwargs.pop("output_metric_scores", False)
88 | self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
89 |
90 | # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub
91 | # interface.
92 | self._from_model_config = kwargs.pop("_from_model_config", False)
93 | self._commit_hash = kwargs.pop("_commit_hash", None)
94 | self.transformers_version = kwargs.pop("transformers_version", transformers_version)
95 | import mbr
96 | self.mbr_version = kwargs.pop("mbr_version", mbr.__version__)
97 |
98 | # Additional attributes without default values
99 | if not self._from_model_config:
100 | # we don't want to copy values from the model config if we're initializing an `MBRConfig` from a
101 | # model's default configuration file
102 | for key, value in kwargs.items():
103 | try:
104 | setattr(self, key, value)
105 | except AttributeError as err:
106 | logger.error(f"Can't set {key} with value {value} for {self}")
107 | raise err
108 |
109 | # Validate the values of the attributes
110 | self.validate(is_init=True)
111 |
112 | def validate(self, is_init=False):
113 | """
114 |         Validates the values of the attributes of the [`MBRConfig`] instance. Raises exceptions in the presence
115 | of parameterization that can be detected as incorrect from the configuration instance alone.
116 |
117 | Note that some parameters are best validated at generate runtime, as they may depend on other inputs and/or the
118 | model, such as parameters related to the generation length.
119 | """
120 | if self.metric_cache_size <= 0:
121 | raise ValueError(f"`metric_cache_size` ({self.metric_cache_size}) must be greater than 0.")
122 |
--------------------------------------------------------------------------------
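A short sketch of how the defaults and the `validate()` check above behave at construction time:

```python
from mbr import MBRConfig

# num_references falls back to num_samples, and the metric cache defaults to
# holding all num_samples * num_references pairwise scores.
config = MBRConfig(num_samples=8, metric="chrf", metric_output_field="score")
assert config.num_references == 8
assert config.metric_cache_size == 64

# validate() runs during __init__ and rejects a non-positive cache size.
try:
    MBRConfig(metric_cache_size=0)
except ValueError as err:
    print(err)  # `metric_cache_size` (0) must be greater than 0.
```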
/src/mbr/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from mbr import MBRConfig
2 | from mbr.metrics.base import metric_is_source_based, MetricRunner
3 |
4 |
5 | def load_metric_runner(mbr_config: MBRConfig, tokenizer=None) -> MetricRunner:
6 | if mbr_config.metric in {"fastchrf", "aggregate_chrf", "fastchrf.aggregate_chrf"}:
7 | from mbr.metrics.fastchrf import FastChrfMetricRunner
8 | return FastChrfMetricRunner(mbr_config, tokenizer, compute_pairwise_average=False)
9 | elif mbr_config.metric in {"pairwise_chrf", "fastchrf.pairwise_chrf"}:
10 | from mbr.metrics.fastchrf import FastChrfMetricRunner
11 | return FastChrfMetricRunner(mbr_config, tokenizer, compute_pairwise_average=True)
12 | elif mbr_config.metric == "comet":
13 | from mbr.metrics.comet import CometMetricRunner
14 | return CometMetricRunner(mbr_config, tokenizer,
15 | device=0,
16 | batch_size_embed=64,
17 | batch_size_estimate=64,
18 | progress_bar=True,
19 | )
20 | elif mbr_config.metric == "aggregate_comet":
21 | from mbr.metrics.comet import AggregateCometMetricRunner
22 | mbr_config.metric = "comet"
23 | return AggregateCometMetricRunner(mbr_config, tokenizer,
24 | device=0,
25 | batch_size_embed=64,
26 | batch_size_estimate=64,
27 | progress_bar=True,
28 | )
29 | else:
30 | return MetricRunner(mbr_config, tokenizer)
31 |
--------------------------------------------------------------------------------
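The dispatch above maps the metric name in the config to a `MetricRunner` subclass; a small sketch of the chrF-based branches (the COMET branches additionally move the scorer to GPU device 0):

```python
# Sketch of load_metric_runner(): the metric name selects the runner class.
from transformers import AutoTokenizer

from mbr import MBRConfig
from mbr.metrics import load_metric_runner

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")

runner = load_metric_runner(MBRConfig(metric="fastchrf"), tokenizer)
print(type(runner).__name__)  # FastChrfMetricRunner (aggregate fastChrF scores)

runner = load_metric_runner(MBRConfig(metric="pairwise_chrf"), tokenizer)
print(type(runner).__name__)  # FastChrfMetricRunner (exact pairwise ChrF, averaged)

runner = load_metric_runner(MBRConfig(metric="chrf"), tokenizer)
print(type(runner).__name__)  # MetricRunner (generic evaluate-based fallback)
```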
/src/mbr/metrics/base.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Tuple, Union, List, Optional
3 |
4 | import evaluate
5 | import torch
6 | from cachetools.func import fifo_cache
7 | from datasets import Metric
8 | from evaluate import EvaluationModule
9 | from transformers import PreTrainedTokenizerBase
10 | from transformers.utils import ModelOutput
11 |
12 | from mbr import MBRConfig
13 |
14 | MetricType = Union[Metric, EvaluationModule]
15 |
16 |
17 | @dataclass
18 | class MetricOutput(ModelOutput):
19 | """
20 | Args:
21 | scores (`torch.FloatTensor` of shape `(batch_size, num_samples)`):
22 | The metric scores for each sample (aggregated over all references).
23 | scores_per_reference (`torch.FloatTensor` of shape `(batch_size, num_samples, num_references)`):
24 | The pairwise metric scores for each sample and reference. `None` if the metric is computed corpus-level.
25 | """
26 | scores: torch.FloatTensor
27 | scores_per_reference: Optional[torch.FloatTensor] = None
28 |
29 |
30 | class MetricRunner:
31 | """
32 | Applies the metric to samples and references (and optionally inputs) and calculates a metric score for each sample.
33 | This implementation uses the most basic approach, where samples and references are compared pairwise.
34 | Some metrics may support multi-reference evaluation or batching. Consider creating a subclass to make use of these
35 | features.
36 | """
37 |
38 | def __init__(self, mbr_config: MBRConfig, tokenizer: PreTrainedTokenizerBase):
39 | self.mbr_config = mbr_config
40 |         # Ensure that mbr_config.metric_kwargs is hashable (because _compute_metric() is cached with fifo_cache)
41 | if mbr_config.metric_kwargs:
42 | try:
43 | hash(tuple(self.mbr_config.metric_kwargs))
44 | except TypeError as e:
45 |                 raise TypeError("mbr_config.metric_kwargs must be hashable.") from e
46 | self.tokenizer = tokenizer
47 | self.metric = self._load_metric()
48 | self.metric_is_source_based = metric_is_source_based(self.metric)
49 | self._compute_metric_cached = fifo_cache(maxsize=self.mbr_config.metric_cache_size)(self._compute_metric)
50 |
51 | def _load_metric(self) -> MetricType:
52 | metric = self.mbr_config.metric
53 | if isinstance(metric, EvaluationModule):
54 | return metric
55 | elif isinstance(metric, str):
56 | metric = evaluate.load(metric, self.mbr_config.metric_config_name)
57 | else:
58 | raise ValueError(f"Invalid metric type: {type(metric)}")
59 | if metric.name == "comet":
60 | metric.scorer.eval()
61 | return metric
62 |
63 | def __call__(self,
64 | input_ids: torch.LongTensor,
65 | sample_ids: Tuple[torch.LongTensor],
66 | reference_ids: Tuple[torch.LongTensor],
67 | ) -> MetricOutput:
68 | r"""
69 | Args:
70 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
71 | The input sequence ids.
72 | sample_ids (`tuple(torch.LongTensor)`):
73 | Tuple (one element for `num_samples`) of tensors of shape `(batch_size, sequence_length)` containing
74 | the sampled sequences.
75 | reference_ids:
76 | Tuple (one element for `num_references`) of tensors of shape `(batch_size, sequence_length)` containing
77 | the reference sequences.
78 |
79 | Returns:
80 | `MetricOutput` containing the metric scores.
81 | """
82 |
83 | # Detokenize
84 | if self.metric_is_source_based:
85 | str_inputs = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True) # shape: (batch_size,)
86 | else:
87 | str_inputs = None
88 | str_samples = [] # num_samples x batch_size
89 | for sample in sample_ids:
90 | str_samples.append(self.tokenizer.batch_decode(sample, skip_special_tokens=True))
91 | str_references = [] # num_references x batch_size
92 | for reference in reference_ids:
93 | str_references.append(self.tokenizer.batch_decode(reference, skip_special_tokens=True))
94 |
95 | if len(str_samples[0]) != len(str_references[0]):
96 | raise ValueError("Batch size of samples and references must match")
97 | if len(str_samples) != self.mbr_config.num_samples:
98 | raise ValueError("Number of samples must match `mbr_config.num_samples`")
99 | if len(str_references) != self.mbr_config.num_references:
100 | raise ValueError("Number of references must match `mbr_config.num_references`")
101 |
102 | # Compute metric
103 | scores_per_reference = self._compute_str_metric(str_samples, str_references, str_inputs)
104 |
105 | return MetricOutput(
106 | scores=scores_per_reference.mean(dim=-1),
107 | scores_per_reference=scores_per_reference,
108 | )
109 |
110 | def _compute_str_metric(self,
111 | samples: List[List[str]],
112 | references: List[List[str]],
113 | inputs: List[str] = None,
114 | ) -> torch.FloatTensor:
115 | batch_size = len(samples[0])
116 | metric_scores = torch.zeros((batch_size, len(samples), len(references)))
117 | for i in range(batch_size):
118 | for j in range(len(samples)):
119 | sample = samples[j][i]
120 | for k in range(len(references)):
121 | reference = references[k][i]
122 | if inputs is not None:
123 | score = self.compute_metric(
124 | sources=(inputs[i],),
125 | predictions=(sample,),
126 | references=(reference,),
127 | **self.mbr_config.metric_kwargs,
128 | )
129 | else:
130 | score = self.compute_metric(
131 | predictions=(sample,),
132 | references=(reference,),
133 | **self.mbr_config.metric_kwargs,
134 | )
135 | metric_scores[i, j, k] = score
136 | return metric_scores
137 |
138 | def _compute_metric(self, *args, **kwargs) -> float:
139 | # Call _compute() instead of compute() for performance reasons.
140 | # Since we are comparing individual samples, we do not need the overhead of compute().
141 | output = self.metric._compute(*args, **kwargs)
142 | if self.mbr_config.metric_output_field not in output:
143 |             raise ValueError(f"Metric output does not contain '{self.mbr_config.metric_output_field}'. "
144 | f"Use `mbr_config.metric_output_field` to specify the correct field. "
145 | f"Available fields: {list(output.keys())}"
146 | )
147 | score = output[self.mbr_config.metric_output_field]
148 | if isinstance(score, list):
149 | score = score[0]
150 | return score
151 |
152 | def compute_metric(self, *args, **kwargs) -> float:
153 | return self._compute_metric_cached(*args, **kwargs)
154 |
155 |
156 | def metric_is_source_based(metric: MetricType) -> bool:
157 | return "sources" in metric.features
158 |
--------------------------------------------------------------------------------
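The calling convention of `MetricRunner` mirrors the unit tests: token ids for the inputs, plus one tensor per sample and per reference. A minimal sketch:

```python
# Minimal sketch of calling the generic MetricRunner with chrF.
from transformers import AutoTokenizer

from mbr import MBRConfig, MetricRunner

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
tokenizer.pad_token = tokenizer.eos_token

config = MBRConfig(metric="chrf", metric_output_field="score", num_samples=2, num_references=1)
runner = MetricRunner(config, tokenizer)

inputs = ["This is an input sentence."]
samples = [["This is a sample sentence."], ["Something else entirely."]]  # num_samples x batch_size
references = [["This is a reference sentence."]]                          # num_references x batch_size

input_ids = tokenizer(inputs, return_tensors="pt", padding=True).input_ids
sample_ids = tuple(tokenizer(s, return_tensors="pt", padding=True).input_ids for s in samples)
reference_ids = tuple(tokenizer(r, return_tensors="pt", padding=True).input_ids for r in references)

output = runner(input_ids, sample_ids, reference_ids)
print(output.scores.shape)                # (batch_size, num_samples) = (1, 2)
print(output.scores_per_reference.shape)  # (batch_size, num_samples, num_references) = (1, 2, 1)
```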
/src/mbr/metrics/comet.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | from typing import List, Tuple, Dict, Set
3 |
4 | import torch
5 | from cachetools import FIFOCache
6 | from comet.models import RegressionMetric
7 | from tqdm import tqdm
8 |
9 | from mbr import MetricRunner, MetricOutput
10 |
11 |
12 | class CometMetricRunner(MetricRunner):
13 | """
14 | Efficient usage of COMET for MBR, based on https://github.com/Unbabel/COMET
15 |
16 | The implementation is inspired by https://github.com/chanberg/COMET-mbr and
17 | https://github.com/Unbabel/COMET/blob/master/comet/cli/mbr.py
18 | """
19 |
20 | def __init__(self,
21 | *args,
22 | device=None,
23 | batch_size_embed: int = 1,
24 | batch_size_estimate: int = 1,
25 | progress_bar: bool = False,
26 | **kwargs,
27 | ):
28 | super().__init__(*args, **kwargs)
29 | if self.metric.__class__.__name__ != "COMET":
30 | raise ValueError(
31 | f"CometMetricRunner expects an evaluate.COMET metric, got {self.metric.__class__.__name__}")
32 | self.comet = self.metric
33 | if self.mbr_config.metric_output_field not in ["mean_score", "scores"]:
34 | raise ValueError(f"CometMetricRunner expects metric_output_field to be 'mean_score' or 'scores', "
35 | f"got {self.mbr_config.metric_output_field}")
36 | if self.mbr_config.metric_kwargs:
37 | raise NotImplementedError("CometMetricRunner does not support metric_kwargs")
38 | if not isinstance(self.comet.scorer, RegressionMetric):
39 | raise NotImplementedError("CometMetricRunner only supports COMET models that are an instance of "
40 | "comet.models.RegressionMetric")
41 | if device is not None:
42 | self.comet.scorer = self.comet.scorer.to(device)
43 | self.comet.scorer.eval()
44 | self.batch_size_embed = batch_size_embed
45 | self.batch_size_estimate = batch_size_estimate
46 | self.progress_bar = progress_bar
47 | # We use a key-value cache, which is needed if the metric is called multiple times with similar inputs
48 | # (e.g. for MBR with iterative pruning).
49 | self.embedding_cache = FIFOCache(maxsize=self.mbr_config.metric_cache_size)
50 | self.score_cache = FIFOCache(maxsize=self.mbr_config.metric_cache_size)
51 |
52 | @torch.no_grad()
53 | def _compute_str_metric(self,
54 | samples: List[List[str]],
55 | references: List[List[str]],
56 | inputs: List[str] = None,
57 | ) -> torch.FloatTensor:
58 | if inputs is None:
59 | raise NotImplementedError("CometMetricRunner requires source sequences (`inputs`) to be provided")
60 | batch_size = len(samples[0])
61 | metric_scores = torch.zeros((batch_size, len(samples), len(references)))
62 | for i in tqdm(list(range(batch_size)), desc="comet", disable=not self.progress_bar):
63 | # Embed all sequences
64 | all_samples = [sample[i] for sample in samples]
65 | all_references = [reference[i] for reference in references]
66 | all_sequences = set(all_samples + all_references + inputs)
67 |
68 | all_embeddings: Dict[str, torch.FloatTensor] = {}
69 | # Populate embeddings from cache
70 | for sequence in list(all_sequences):
71 | if sequence in self.embedding_cache:
72 | all_embeddings[sequence] = self.embedding_cache[sequence]
73 | all_sequences.remove(sequence)
74 |
75 | # Compute embeddings for remaining sequences
76 | if all_sequences:
77 | all_sequences = list(all_sequences)
78 | encodings = self.comet.scorer.encoder.prepare_sample(all_sequences).to(self.comet.scorer.device)
79 | batches = itertools.zip_longest(range(0, len(all_sequences), self.batch_size_embed),
80 | range(self.batch_size_embed, len(all_sequences), self.batch_size_embed))
81 | for start_idx, end_idx in batches:
82 | embeddings = self.comet.scorer.get_sentence_embedding(
83 | input_ids=encodings["input_ids"][start_idx:end_idx],
84 | attention_mask=encodings["attention_mask"][start_idx:end_idx],
85 | )
86 | for j in range(start_idx, end_idx if end_idx is not None else len(all_sequences)):
87 | embedding = embeddings[j - start_idx]
88 | all_embeddings[all_sequences[j]] = embedding
89 | self.embedding_cache[all_sequences[j]] = embedding
90 |
91 | # Collect all input triples in a list
92 | input_triples: Set[Tuple[str, str, str]] = set()
93 | for j in range(len(samples)):
94 | for k in range(len(references)):
95 | input_triples.add((inputs[i], samples[j][i], references[k][i]))
96 |
97 | input_triple_scores: Dict[Tuple[str, str, str], torch.FloatTensor] = {}
98 | # Populate scores from cache
99 | for triple in list(input_triples):
100 | if triple in self.score_cache:
101 | input_triple_scores[triple] = self.score_cache[triple]
102 | input_triples.remove(triple)
103 |
104 | # Compute scores for remaining input triples
105 | input_triples: List = list(input_triples)
106 | batches = itertools.zip_longest(range(0, len(input_triples), self.batch_size_estimate),
107 | range(self.batch_size_estimate, len(input_triples),
108 | self.batch_size_estimate))
109 | for start_idx, end_idx in batches:
110 | batch = input_triples[start_idx:end_idx]
111 | batch_scores = self.comet.scorer.estimate(
112 | src_sentemb=torch.stack([all_embeddings[triple[0]] for triple in batch]),
113 | mt_sentemb=torch.stack([all_embeddings[triple[1]] for triple in batch]),
114 | ref_sentemb=torch.stack([all_embeddings[triple[2]] for triple in batch]),
115 | )
116 | for j in range(start_idx, end_idx if end_idx is not None else len(input_triples)):
117 | triple = batch[j - start_idx]
118 | score = batch_scores.score[j - start_idx]
119 | input_triple_scores[triple] = score
120 | self.score_cache[triple] = score
121 |
122 | for j in range(len(samples)):
123 | for k in range(len(references)):
124 | metric_scores[i, j, k] = input_triple_scores[(inputs[i], samples[j][i], references[k][i])]
125 |
126 | return metric_scores
127 |
128 |
129 | class AggregateCometMetricRunner(CometMetricRunner):
130 | """
131 | Implements reference aggregation as described in "Linear-time Minimum Bayes Risk Decoding with Reference Aggregation"
132 | (Vamvas & Sennrich, 2024) https://arxiv.org/abs/2402.04251
133 | """
134 |
135 | def __call__(self,
136 | input_ids: torch.LongTensor,
137 | sample_ids: Tuple[torch.LongTensor],
138 | reference_ids: Tuple[torch.LongTensor],
139 | ) -> MetricOutput:
140 | r"""
141 | Args:
142 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
143 | The input sequence ids.
144 | sample_ids (`tuple(torch.LongTensor)`):
145 | Tuple (one element for `num_samples`) of tensors of shape `(batch_size, sequence_length)` containing
146 | the sampled sequences.
147 | reference_ids:
148 | Tuple (one element for `num_references`) of tensors of shape `(batch_size, sequence_length)` containing
149 | the reference sequences.
150 |
151 | Returns:
152 | `MetricOutput` containing the metric scores.
153 | """
154 |
155 | # Detokenize
156 | if self.metric_is_source_based:
157 | str_inputs = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True) # shape: (batch_size,)
158 | else:
159 | str_inputs = None
160 | str_samples = [] # num_samples x batch_size
161 | for sample in sample_ids:
162 | str_samples.append(self.tokenizer.batch_decode(sample, skip_special_tokens=True))
163 | str_references = [] # num_references x batch_size
164 | for reference in reference_ids:
165 | str_references.append(self.tokenizer.batch_decode(reference, skip_special_tokens=True))
166 |
167 | if len(str_samples[0]) != len(str_references[0]):
168 | raise ValueError("Batch size of samples and references must match")
169 |
170 | # Compute metric
171 | scores = self._compute_str_metric(str_samples, str_references, str_inputs)
172 |
173 | return MetricOutput(
174 | scores=scores,
175 | scores_per_reference=None,
176 | )
177 |
178 | @torch.no_grad()
179 | def _compute_str_metric(self,
180 | samples: List[List[str]],
181 | references: List[List[str]],
182 | inputs: List[str] = None,
183 | ) -> torch.FloatTensor:
184 | if inputs is None:
185 |             raise NotImplementedError("AggregateCometMetricRunner requires source sequences (`inputs`) to be provided")
186 | batch_size = len(samples[0])
187 | metric_scores = torch.zeros((batch_size, len(samples)))
188 | for i in tqdm(list(range(batch_size)), desc="comet", disable=not self.progress_bar):
189 | # Embed all sequences
190 | all_samples = [sample[i] for sample in samples]
191 | all_references = [reference[i] for reference in references]
192 | all_sequences = set(all_samples + all_references + inputs)
193 |
194 | all_embeddings: Dict[str, torch.FloatTensor] = {}
195 | # Populate embeddings from cache
196 | for sequence in list(all_sequences):
197 | if sequence in self.embedding_cache:
198 | all_embeddings[sequence] = self.embedding_cache[sequence]
199 | all_sequences.remove(sequence)
200 |
201 | # Compute embeddings for remaining sequences
202 | if all_sequences:
203 | all_sequences = list(all_sequences)
204 | encodings = self.comet.scorer.encoder.prepare_sample(all_sequences).to(self.comet.scorer.device)
205 | batches = itertools.zip_longest(range(0, len(all_sequences), self.batch_size_embed),
206 | range(self.batch_size_embed, len(all_sequences), self.batch_size_embed))
207 | for start_idx, end_idx in batches:
208 | embeddings = self.comet.scorer.get_sentence_embedding(
209 | input_ids=encodings["input_ids"][start_idx:end_idx],
210 | attention_mask=encodings["attention_mask"][start_idx:end_idx],
211 | )
212 | for j in range(start_idx, end_idx if end_idx is not None else len(all_sequences)):
213 | embedding = embeddings[j - start_idx]
214 | all_embeddings[all_sequences[j]] = embedding
215 | self.embedding_cache[all_sequences[j]] = embedding
216 |
217 | # Compute average reference embedding
218 | avg_reference_embedding = torch.stack([all_embeddings[reference] for reference in all_references]).mean(dim=0)
219 |
220 | # Collect all input triples in a list
221 | input_triples: Set[Tuple[str, str, str]] = set()
222 | for j in range(len(samples)):
223 | input_triples.add((inputs[i], samples[j][i], "avg"))
224 |
225 | input_triple_scores: Dict[Tuple[str, str, str], torch.FloatTensor] = {}
226 | # Populate scores from cache
227 | for triple in list(input_triples):
228 | if triple in self.score_cache:
229 | input_triple_scores[triple] = self.score_cache[triple]
230 | input_triples.remove(triple)
231 |
232 | # Compute scores for remaining input triples
233 | input_triples: List = list(input_triples)
234 | batches = itertools.zip_longest(range(0, len(input_triples), self.batch_size_estimate),
235 | range(self.batch_size_estimate, len(input_triples),
236 | self.batch_size_estimate))
237 | for start_idx, end_idx in batches:
238 | batch = input_triples[start_idx:end_idx]
239 | batch_scores = self.comet.scorer.estimate(
240 | src_sentemb=torch.stack([all_embeddings[triple[0]] for triple in batch]),
241 | mt_sentemb=torch.stack([all_embeddings[triple[1]] for triple in batch]),
242 | ref_sentemb=avg_reference_embedding.unsqueeze(0).repeat(len(batch), 1),
243 | )
244 | for j in range(start_idx, end_idx if end_idx is not None else len(input_triples)):
245 | triple = batch[j - start_idx]
246 | score = batch_scores.score[j - start_idx]
247 | input_triple_scores[triple] = score
248 | self.score_cache[triple] = score
249 |
250 | for j in range(len(samples)):
251 | metric_scores[i, j] = input_triple_scores[(inputs[i], samples[j][i], "avg")]
252 |
253 | return metric_scores
254 |
--------------------------------------------------------------------------------
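`AggregateCometMetricRunner` averages the reference embeddings before scoring, so it returns one score per sample and no per-reference breakdown. A sketch following the unit tests below (assumes the `unbabel-comet` extra is installed and the small `eamt22-cometinho-da` checkpoint can be downloaded):

```python
# Sketch of reference aggregation with COMET (see Vamvas & Sennrich, 2024).
import evaluate
from transformers import AutoTokenizer

from mbr import MBRConfig
from mbr.metrics.comet import AggregateCometMetricRunner

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
tokenizer.pad_token = tokenizer.eos_token

config = MBRConfig(num_samples=3, num_references=2, metric_output_field="mean_score")
config.metric = evaluate.load("comet", "eamt22-cometinho-da")
config.metric.scorer.eval()
runner = AggregateCometMetricRunner(config, tokenizer)


def encode(batch):
    return tokenizer(batch, return_tensors="pt", padding=True).input_ids


inputs = ["This is an input sentence."]
samples = [["This is a sample sentence."], ["Something else entirely."], ["A third sample."]]
references = [["This is a reference sentence."], ["Another reference sentence."]]

output = runner(encode(inputs), tuple(encode(s) for s in samples), tuple(encode(r) for r in references))
print(output.scores.shape)          # (batch_size, num_samples) = (1, 3)
print(output.scores_per_reference)  # None: references were aggregated into one embedding
```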
/src/mbr/metrics/fastchrf.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 |
3 | import torch
4 | from fastchrf import pairwise_chrf, aggregate_chrf
5 | from transformers import PreTrainedTokenizerBase
6 |
7 | from mbr import MetricRunner, MBRConfig, MetricOutput
8 |
9 |
10 | class FastChrfMetricRunner(MetricRunner):
11 | """
12 | MetricRunner for fastChrF. See https://github.com/jvamvas/fastChrF for more information.
13 |
14 | Args:
15 | mbr_config
16 | tokenizer
17 |         compute_pairwise_average: Default: False. If True, use fastchrf.pairwise_chrf() to calculate exact ChrF scores
18 |           for each sample-reference pair and then average them; this corresponds to a fast implementation of the
19 |           original ChrF metric. If False, use fastchrf.aggregate_chrf() to directly calculate aggregate fastChrF scores
20 |           across all references; note that the result will be different from the original ChrF metric.
21 | """
22 |
23 | def __init__(self,
24 | mbr_config: MBRConfig,
25 | tokenizer: PreTrainedTokenizerBase,
26 | compute_pairwise_average: bool = False,
27 | ):
28 | self.mbr_config = mbr_config
29 | self.tokenizer = tokenizer
30 | self.metric_is_source_based = False
31 | self.char_order = mbr_config.metric_kwargs.get("char_order", 6)
32 | self.beta = mbr_config.metric_kwargs.get("beta", 2)
33 | self.remove_whitespace = mbr_config.metric_kwargs.get("remove_whitespace", True)
34 | self.eps_smoothing = mbr_config.metric_kwargs.get("eps_smoothing", False)
35 | self.compute_pairwise_average = compute_pairwise_average
36 |
37 | def __call__(self,
38 | input_ids: torch.LongTensor,
39 | sample_ids: Tuple[torch.LongTensor],
40 | reference_ids: Tuple[torch.LongTensor],
41 | ) -> MetricOutput:
42 | r"""
43 | Args:
44 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
45 | The input sequence ids.
46 | sample_ids (`tuple(torch.LongTensor)`):
47 | Tuple (one element for `num_samples`) of tensors of shape `(batch_size, sequence_length)` containing
48 | the sampled sequences.
49 | reference_ids:
50 | Tuple (one element for `num_references`) of tensors of shape `(batch_size, sequence_length)` containing
51 | the reference sequences.
52 |
53 | Returns:
54 | `MetricOutput` containing the metric scores.
55 | """
56 |
57 | # Detokenize
58 | str_samples = [] # num_samples x batch_size
59 | for sample in sample_ids:
60 | str_samples.append(self.tokenizer.batch_decode(sample, skip_special_tokens=True))
61 | str_references = [] # num_references x batch_size
62 | for reference in reference_ids:
63 | str_references.append(self.tokenizer.batch_decode(reference, skip_special_tokens=True))
64 |
65 | if len(str_samples[0]) != len(str_references[0]):
66 | raise ValueError("Batch size of samples and references must match")
67 | if len(str_samples) != self.mbr_config.num_samples:
68 | raise ValueError("Number of samples must match `mbr_config.num_samples`")
69 | if len(str_references) != self.mbr_config.num_references:
70 | raise ValueError("Number of references must match `mbr_config.num_references`")
71 |
72 | # Transpose to batch_size x num_samples/num_references
73 | str_samples = list(zip(*str_samples))
74 | str_references = list(zip(*str_references))
75 |
76 | if self.compute_pairwise_average:
77 | output = self._compute_pairwise_chrf(str_samples, str_references)
78 | else:
79 | output = self._compute_aggregate_chrf(str_samples, str_references)
80 | return output
81 |
82 | def _compute_pairwise_chrf(self, samples: List[List[str]], references: List[List[str]]) -> MetricOutput:
83 | scores_per_reference = pairwise_chrf(
84 | samples,
85 | references,
86 | char_order=self.char_order,
87 | beta=self.beta,
88 | remove_whitespace=self.remove_whitespace,
89 | eps_smoothing=self.eps_smoothing,
90 | )
91 | scores_per_reference = torch.tensor(scores_per_reference)
92 | scores = scores_per_reference.mean(dim=-1)
93 | return MetricOutput(
94 | scores=scores,
95 | scores_per_reference=scores_per_reference,
96 | )
97 |
98 | def _compute_aggregate_chrf(self, samples: List[List[str]], references: List[List[str]]) -> MetricOutput:
99 | scores = aggregate_chrf(
100 | samples,
101 | references,
102 | char_order=self.char_order,
103 | beta=self.beta,
104 | remove_whitespace=self.remove_whitespace,
105 | eps_smoothing=self.eps_smoothing,
106 | )
107 | scores = torch.tensor(scores)
108 | return MetricOutput(
109 | scores=scores,
110 | scores_per_reference=None,
111 | )
112 |
--------------------------------------------------------------------------------
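The two modes differ in what they return: aggregate mode yields a single fastChrF score per sample, while pairwise mode also exposes the exact per-reference ChrF scores. A short sketch:

```python
# Sketch contrasting the aggregate and pairwise fastChrF modes.
from transformers import AutoTokenizer

from mbr import MBRConfig
from mbr.metrics.fastchrf import FastChrfMetricRunner

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
tokenizer.pad_token = tokenizer.eos_token
config = MBRConfig(num_samples=2, num_references=2)

samples = [["A cat sat on the mat.", "Hello there."],
           ["The cat sat on a mat.", "Hi there."]]         # num_samples x batch_size
references = [["The cat sat on the mat.", "Hello there!"],
              ["A cat is on the mat.", "Hello!"]]          # num_references x batch_size

input_ids = tokenizer(["source 1", "source 2"], return_tensors="pt", padding=True).input_ids
sample_ids = tuple(tokenizer(s, return_tensors="pt", padding=True).input_ids for s in samples)
reference_ids = tuple(tokenizer(r, return_tensors="pt", padding=True).input_ids for r in references)

# Aggregate mode: one fastChrF score per sample, no per-reference breakdown.
aggregate_runner = FastChrfMetricRunner(config, tokenizer, compute_pairwise_average=False)
print(aggregate_runner(input_ids, sample_ids, reference_ids).scores_per_reference)  # None

# Pairwise mode: exact ChrF per sample-reference pair, averaged into `scores`.
pairwise_runner = FastChrfMetricRunner(config, tokenizer, compute_pairwise_average=True)
print(pairwise_runner(input_ids, sample_ids, reference_ids).scores_per_reference.shape)  # (2, 2, 2)
```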
/src/mbr/modeling.py:
--------------------------------------------------------------------------------
1 | from transformers import GenerationMixin
2 |
3 | from mbr.generation.utils import MBRGenerationMixin
4 |
5 |
6 | def MBR(model_class: type) -> type:
7 | """
8 | Utility function for converting a model class into a class that inherits from `~generation.MBRGenerationMixin`.
9 | """
10 | if not issubclass(model_class, GenerationMixin):
11 | raise ValueError(
12 | f"MBR() can only be applied to classes that inherit from `transformers.GenerationMixin`, "
13 | f"but got {model_class}."
14 | )
15 | return type("MBR" + model_class.__name__, (MBRGenerationMixin, model_class), {})
16 |
--------------------------------------------------------------------------------
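Because `MBR()` builds the wrapped class with `type()`, the mixin ends up first in the method resolution order; a quick sketch of the factory behaviour:

```python
# Sketch of the MBR() class factory: the mixin precedes the model class in the MRO.
from transformers import GPT2LMHeadModel, PreTrainedTokenizerFast

from mbr import MBR

MBRGPT2 = MBR(GPT2LMHeadModel)
print(MBRGPT2.__name__)             # MBRGPT2LMHeadModel
print(MBRGPT2.__mro__[1].__name__)  # MBRGenerationMixin

# Classes that do not inherit from transformers.GenerationMixin are rejected.
try:
    MBR(PreTrainedTokenizerFast)
except ValueError as err:
    print(err)
```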
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_config.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | from mbr import MBRConfig
4 |
5 |
6 | class MBRConfigTestCase(TestCase):
7 |
8 | def test_default_config(self):
9 | config = MBRConfig()
10 | self.assertEqual(config.num_samples, 10)
11 | self.assertEqual(config.num_references, 10)
12 |
--------------------------------------------------------------------------------
/tests/test_metrics.py:
--------------------------------------------------------------------------------
1 | import os
2 | import unittest
3 | from unittest import TestCase
4 |
5 | import evaluate
6 | import torch
7 | from transformers import AutoTokenizer
8 |
9 | from mbr import MetricRunner, MBRConfig
10 | from mbr.metrics import metric_is_source_based
11 |
12 |
13 | class MetricUtilsTestCase(TestCase):
14 |
15 | def setUp(self):
16 | self.mbr_config = MBRConfig(
17 | metric="chrf",
18 | metric_output_field="score",
19 | num_samples=3,
20 | num_references=2,
21 | )
22 | self.tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
23 | self.tokenizer.pad_token = self.tokenizer.eos_token
24 | self.metric_runner = MetricRunner(self.mbr_config, self.tokenizer)
25 | self.inputs = [ # shape: (batch_size,)
26 | "This is an input sentence.",
27 | "This is another input sentence.",
28 | ]
29 | self.samples = [ # num_samples x batch_size
30 | ["This is a sample sentence.", "Something totally different."],
31 | ["This is a sample sentence.", "This a third sample sentence."],
32 | ["Something totally different.", "This is a fourth sample sentence."],
33 | ]
34 | self.references = [ # num_references x batch_size
35 | ["This is a reference sentence.", "This is another reference sentence."],
36 | ["This is a reference sentence.", "This is a fourth reference sentence."],
37 | ]
38 | self.input_ids = self.tokenizer(self.inputs, return_tensors="pt", padding=True).input_ids
39 | self.sample_ids = tuple([self.tokenizer(sample, return_tensors="pt", padding=True).input_ids for sample in self.samples])
40 | self.reference_ids = tuple([self.tokenizer(reference, return_tensors="pt", padding=True).input_ids for reference in self.references])
41 |
42 | def test_is_source_based__chrf(self):
43 | chrf = evaluate.load("chrf")
44 | self.assertFalse(metric_is_source_based(chrf))
45 |
46 | def test_is_source_based__comet(self):
47 | comet = evaluate.load("comet", "eamt22-cometinho-da")
48 | self.assertTrue(metric_is_source_based(comet))
49 |
50 | @unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies")
51 | def test_is_source_based__bleurt(self):
52 | bleurt = evaluate.load("bleurt")
53 | self.assertFalse(metric_is_source_based(bleurt))
54 |
55 | def test_load_metric(self):
56 | self.mbr_config.metric = "chrf"
57 | metric = self.metric_runner._load_metric()
58 | self.assertIsInstance(metric, evaluate.Metric)
59 | self.assertEqual(metric.name, "chr_f")
60 | self.mbr_config.metric = evaluate.load("chrf")
61 | metric = self.metric_runner._load_metric()
62 | self.assertIsInstance(metric, evaluate.Metric)
63 | self.assertEqual(metric.name, "chr_f")
64 |
65 | def test_metric_config_name(self):
66 | self.mbr_config.metric = "comet"
67 | self.mbr_config.metric_config_name = "eamt22-cometinho-da"
68 | self.mbr_config.metric_output_field = "mean_score"
69 | metric = self.metric_runner._load_metric()
70 | self.assertIsInstance(metric, evaluate.Metric)
71 | self.assertEqual(metric.name, "comet")
72 | # Test custom metric_config_name
73 | self.assertEqual(metric.scorer.encoder.__class__.__name__, "MiniLMEncoder")
74 |
75 | def test_compute_metric__chrf(self):
76 | metric_output = self.metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
77 | self.assertTrue(torch.is_floating_point(metric_output.scores))
78 | self.assertTrue(torch.is_floating_point(metric_output.scores_per_reference))
79 | torch.testing.assert_close(metric_output.scores_per_reference.mean(dim=-1), metric_output.scores)
80 | self.assertEqual(metric_output.scores.shape, (2, 3)) # batch_size x num_samples
81 | self.assertEqual(metric_output.scores_per_reference.shape, (2, 3, 2)) # batch_size x num_samples x num_references
82 | # Duplicate samples should have the same scores
83 | torch.testing.assert_close(metric_output.scores[0, 0], metric_output.scores[0, 1])
84 | torch.testing.assert_close(metric_output.scores_per_reference[0, 0, 0], metric_output.scores_per_reference[0, 1, 0])
85 | # The metric scores should rank as expected, given the test strings in self.samples and self.references
86 | self.assertGreater(metric_output.scores[0, 0], metric_output.scores[0, 2])
87 | self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 1])
88 | self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 2])
89 |
90 | def test_compute_metric__comet(self):
91 | self.mbr_config.metric = evaluate.load("comet", "eamt22-cometinho-da")
92 | self.mbr_config.metric.scorer.eval()
93 | self.mbr_config.metric_output_field = "mean_score"
94 | self.metric_runner = MetricRunner(self.mbr_config, self.tokenizer)
95 | self.assertEqual(self.metric_runner.metric.name, "comet")
96 | metric_output = self.metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
97 | self.assertTrue(torch.is_floating_point(metric_output.scores))
98 | self.assertTrue(torch.is_floating_point(metric_output.scores_per_reference))
99 | torch.testing.assert_close(metric_output.scores_per_reference.mean(dim=-1), metric_output.scores)
100 | self.assertEqual(metric_output.scores.shape, (2, 3)) # batch_size x num_samples
101 | self.assertEqual(metric_output.scores_per_reference.shape, (2, 3, 2)) # batch_size x num_samples x num_references
102 | # Duplicate samples should have the same scores
103 | torch.testing.assert_close(metric_output.scores[0, 0], metric_output.scores[0, 1])
104 | torch.testing.assert_close(metric_output.scores_per_reference[0, 0, 0], metric_output.scores_per_reference[0, 1, 0])
105 | # The metric scores should rank as expected, given the test strings in self.samples and self.references
106 | self.assertGreater(metric_output.scores[0, 0], metric_output.scores[0, 2])
107 | self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 1])
108 | self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 2])
109 |
110 | @unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies")
111 | def test_compute_metric__bleurt(self):
112 | self.mbr_config.metric = evaluate.load("bleurt")
113 | self.mbr_config.metric_output_field = "scores"
114 | self.metric_runner = MetricRunner(self.mbr_config, self.tokenizer)
115 | self.assertEqual(self.metric_runner.metric.name, "bleurt")
116 | metric_output = self.metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
117 | self.assertTrue(torch.is_floating_point(metric_output.scores))
118 | self.assertTrue(torch.is_floating_point(metric_output.scores_per_reference))
119 | torch.testing.assert_close(metric_output.scores_per_reference.mean(dim=-1), metric_output.scores)
120 | self.assertEqual(metric_output.scores.shape, (2, 3)) # batch_size x num_samples
121 | self.assertEqual(metric_output.scores_per_reference.shape, (2, 3, 2)) # batch_size x num_samples x num_references
122 | # Duplicate samples should have the same scores
123 | torch.testing.assert_close(metric_output.scores[0, 0], metric_output.scores[0, 1])
124 | torch.testing.assert_close(metric_output.scores_per_reference[0, 0, 0], metric_output.scores_per_reference[0, 1, 0])
125 | # The metric scores should rank as expected, given the test strings in self.samples and self.references
126 | self.assertGreater(metric_output.scores[0, 0], metric_output.scores[0, 2])
127 | self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 1])
128 | self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 2])
129 |
130 | def test_comet_metric_runner(self):
131 | from mbr.metrics.comet import CometMetricRunner
132 | self.mbr_config.metric = evaluate.load("comet", "eamt22-cometinho-da")
133 | self.mbr_config.metric.scorer.eval()
134 | self.mbr_config.metric_output_field = "mean_score"
135 | base_metric_runner = MetricRunner(self.mbr_config, self.tokenizer)
136 | self.assertEqual(base_metric_runner.metric.name, "comet")
137 | self.assertFalse(base_metric_runner.metric.scorer.training)
138 | comet_metric_runner = CometMetricRunner(self.mbr_config, self.tokenizer)
139 | self.assertFalse(comet_metric_runner.metric.scorer.training)
140 | # Output should be the same as the base MetricRunner
141 | base_metric_scores = base_metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
142 | metric_scores = comet_metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
143 | torch.testing.assert_close(base_metric_scores, metric_scores)
144 |
145 | def test_comet_metric_runner__cache(self):
146 | """Output should be identical irrespective of cache size"""
147 | from mbr.metrics.comet import CometMetricRunner
148 | self.mbr_config.metric = evaluate.load("comet", "eamt22-cometinho-da")
149 | self.mbr_config.metric_output_field = "mean_score"
150 | base_metric_runner = MetricRunner(self.mbr_config, self.tokenizer)
151 | base_metric_scores = base_metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
152 | self.assertEqual(base_metric_runner.metric.name, "comet")
153 | for cache_size in [1, 4, 8]:
154 | self.mbr_config.metric_cache_size = cache_size
155 | comet_metric_runner = CometMetricRunner(self.mbr_config, self.tokenizer)
156 | metric_scores = comet_metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
157 | torch.testing.assert_close(base_metric_scores, metric_scores)
158 |
159 | def test_comet_metric_runner__aggregate(self):
160 | from mbr.metrics.comet import AggregateCometMetricRunner
161 | self.mbr_config.metric = evaluate.load("comet", "eamt22-cometinho-da")
162 | self.mbr_config.metric.scorer.eval()
163 | self.mbr_config.metric_output_field = "mean_score"
164 | base_metric_runner = MetricRunner(self.mbr_config, self.tokenizer)
165 | self.assertEqual(base_metric_runner.metric.name, "comet")
166 | self.assertFalse(base_metric_runner.metric.scorer.training)
167 | comet_metric_runner = AggregateCometMetricRunner(self.mbr_config, self.tokenizer)
168 | self.assertFalse(comet_metric_runner.metric.scorer.training)
169 | metric_output = comet_metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
170 | self.assertTrue(torch.is_floating_point(metric_output.scores))
171 | self.assertIsNone(metric_output.scores_per_reference)
172 | self.assertEqual(metric_output.scores.shape, (2, 3)) # batch_size x num_samples
173 | # Duplicate samples should have the same scores
174 | torch.testing.assert_close(metric_output.scores[0, 0], metric_output.scores[0, 1])
175 | # The metric scores should rank as expected, given the test strings in self.samples and self.references
176 | self.assertGreater(metric_output.scores[0, 0], metric_output.scores[0, 2])
177 | self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 1])
178 | self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 2])
179 |
180 | def test_fastchrf_metric_runner__aggregate(self):
181 | from mbr.metrics.fastchrf import FastChrfMetricRunner
182 | metric_runner = FastChrfMetricRunner(self.mbr_config, self.tokenizer, compute_pairwise_average=False)
183 | metric_output = metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
184 | self.assertTrue(torch.is_floating_point(metric_output.scores))
185 | self.assertIsNone(metric_output.scores_per_reference)
186 | self.assertEqual(metric_output.scores.shape, (2, 3)) # batch_size x num_samples
187 | # Duplicate samples should have the same scores
188 | torch.testing.assert_close(metric_output.scores[0, 0], metric_output.scores[0, 1])
189 | # The metric scores should rank as expected, given the test strings in self.samples and self.references
190 | self.assertGreater(metric_output.scores[0, 0], metric_output.scores[0, 2])
191 | self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 1])
192 | self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 2])
193 |
194 | def test_fastchrf_metric_runner__pairwise(self):
195 | from mbr.metrics.fastchrf import FastChrfMetricRunner
196 | metric_runner = FastChrfMetricRunner(self.mbr_config, self.tokenizer, compute_pairwise_average=True)
197 | metric_output = metric_runner(self.input_ids, self.sample_ids, self.reference_ids)
198 | self.assertTrue(torch.is_floating_point(metric_output.scores))
199 | self.assertTrue(torch.is_floating_point(metric_output.scores_per_reference))
200 | self.assertEqual(metric_output.scores.shape, (2, 3)) # batch_size x num_samples
201 | self.assertEqual(metric_output.scores_per_reference.shape, (2, 3, 2)) # batch_size x num_samples x num_references
202 | # Duplicate samples should have the same scores
203 | torch.testing.assert_close(metric_output.scores[0, 0], metric_output.scores[0, 1])
204 | torch.testing.assert_close(metric_output.scores_per_reference[0, 0, 0], metric_output.scores_per_reference[0, 1, 0])
205 | # The metric scores should rank as expected, given the test strings in self.samples and self.references
206 | self.assertGreater(metric_output.scores[0, 0], metric_output.scores[0, 2])
207 | self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 1])
208 | self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 2])
209 |
--------------------------------------------------------------------------------
/tests/test_pipelines.py:
--------------------------------------------------------------------------------
1 | import os
2 | import unittest
3 | from unittest import TestCase
4 |
5 | from transformers import AutoTokenizer, pipeline, GPT2LMHeadModel, M2M100ForConditionalGeneration, set_seed
6 |
7 | from mbr import MBRConfig
8 | from mbr import MBR
9 |
10 |
11 | class TextGenerationTestCase(TestCase):
12 |
13 | def setUp(self):
14 | set_seed(42)
15 | self.model = MBR(GPT2LMHeadModel).from_pretrained("distilgpt2").eval()
16 | self.tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
17 | self.pipeline = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer)
18 |
19 | def test_pipeline(self):
20 | mbr_config = MBRConfig(
21 | num_samples=5,
22 | )
23 | output = self.pipeline(
24 | "Hello,",
25 | mbr_config=mbr_config,
26 | tokenizer=self.tokenizer,
27 | )
28 | self.assertEqual(1, len(output))
29 | self.assertIn("generated_text", output[0])
30 |
31 |
32 | @unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies")
33 | class TranslationTestCase(TestCase):
34 |
35 | def setUp(self):
36 | set_seed(42)
37 | self.model = MBR(M2M100ForConditionalGeneration).from_pretrained("alirezamsh/small100").eval()
38 | self.tokenizer = AutoTokenizer.from_pretrained("alirezamsh/small100")
39 | self.pipeline = pipeline("translation_en_to_fr", model=self.model, tokenizer=self.tokenizer)
40 | self.tokenizer.tgt_lang = "fr"
41 |
42 | def test_pipeline(self):
43 | mbr_config = MBRConfig(
44 | num_samples=5,
45 | )
46 | output = self.pipeline(
47 | "Could you translate this for me, please?",
48 | mbr_config=mbr_config,
49 | tokenizer=self.tokenizer,
50 | do_sample=True,
51 | num_beams=1,
52 | )
53 | self.assertEqual(1, len(output))
54 | self.assertIn("translation_text", output[0])
55 |
--------------------------------------------------------------------------------