├── .github └── workflows │ └── unittest.yml ├── .gitignore ├── LICENSE ├── README.md ├── experiments ├── README.md ├── bertsch-et-al-2023-mbr │ ├── README.md │ ├── results │ │ └── figures │ │ │ ├── CNN DM original.png │ │ │ └── CNN DM reproduction.png │ └── run_experiment.py ├── chrf-vs-fastchrf │ ├── README.md │ └── run_experiment.py ├── freitag-et-al-2023-epsilon │ ├── README.md │ ├── results │ │ └── figures │ │ │ ├── All Results DE–EN (original).png │ │ │ ├── All Results DE–EN (reproduction).png │ │ │ ├── All Results EN–DE (original).png │ │ │ ├── All Results EN–DE (reproduction).png │ │ │ ├── Main Comparison DE–EN (original).png │ │ │ ├── Main Comparison DE–EN (reproduction).png │ │ │ ├── Main Comparison EN–DE (original).png │ │ │ └── Main Comparison EN–DE (reproduction).png │ └── run_experiment.py ├── müller-sennrich-2021-understanding │ ├── README.md │ ├── results │ │ └── figures │ │ │ ├── AZE–ENG (original).png │ │ │ ├── AZE–ENG (reproduction).png │ │ │ ├── BEL–RUS (original).png │ │ │ ├── BEL–RUS (reproduction).png │ │ │ ├── DAN–EPO (original).png │ │ │ ├── DAN–EPO (reproduction).png │ │ │ ├── DEU–FRA (original).png │ │ │ └── DEU–FRA (reproduction).png │ └── run_experiment.py ├── reference_aggregation │ ├── README.md │ ├── baseline_beam_search.py │ ├── baseline_epsilon_sampling.py │ ├── experiment_utils.py │ ├── fairseq_utils.py │ ├── generate_samples.py │ ├── mbr_utils.py │ ├── plot_accuracy.py │ ├── requirements.txt │ ├── results │ │ ├── chrf.log │ │ ├── comet-xl.log │ │ ├── comet22.log │ │ └── cometinho.log │ ├── run_mbr.py │ ├── scripts │ │ ├── benchmark_time.sh │ │ ├── evaluate-bleurt.ipynb │ │ ├── evaluate-chrf.sh │ │ ├── evaluate-comet.sh │ │ ├── plot_accuracy_reverse.py │ │ ├── print_data_stats_table.py │ │ └── save_src_and_ref.py │ ├── tests │ │ ├── __init__.py │ │ ├── test_beam_search.py │ │ ├── test_chrf.py │ │ ├── test_comet.py │ │ ├── test_epsilon_sampling.py │ │ ├── test_generate_samples.py │ │ ├── test_run_mbr.py │ │ ├── test_save_src_and_ref.py │ │ ├── test_testset.py │ │ └── test_validation.py │ └── validation.py └── requirements.txt ├── minimum-bayes-risk-decoding.png ├── pyproject.toml ├── requirements-dev.txt ├── requirements-test.txt ├── src └── mbr │ ├── __init__.py │ ├── generation │ ├── __init__.py │ ├── configuration_utils.py │ └── utils.py │ ├── metrics │ ├── __init__.py │ ├── base.py │ ├── comet.py │ └── fastchrf.py │ └── modeling.py └── tests ├── __init__.py ├── test_config.py ├── test_generate.py ├── test_metrics.py └── test_pipelines.py /.github/workflows/unittest.yml: -------------------------------------------------------------------------------- 1 | name: unittest 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | build: 11 | 12 | runs-on: ubuntu-latest 13 | strategy: 14 | matrix: 15 | python-version: ["3.11"] 16 | 17 | steps: 18 | - uses: actions/checkout@v3 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v3 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install torch --extra-index-url https://download.pytorch.org/whl/cpu 27 | pip install . 28 | pip install -r requirements-test.txt 29 | - name: Lint with flake8 30 | run: | 31 | pip install flake8 32 | # stop the build if there are Python syntax errors or undefined names 33 | flake8 . 
--count --select=E9,F63,F7,F82 --show-source --statistics --exclude experiments 34 | - name: Test 35 | run: SKIP_SLOW_TESTS=True python -m unittest discover -s tests 36 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mbr 🔥 2 | [![Main](https://github.com/ZurichNLP/mbr/workflows/unittest/badge.svg)](https://github.com/ZurichNLP/mbr/actions/workflows/unittest.yml) 3 | [![PyPI](https://img.shields.io/pypi/v/mbr)](https://pypi.python.org/pypi/mbr/) 4 | 5 | **mbr** adds Sampling-based Minimum Bayes Risk decoding to [Hugging Face transformers](https://github.com/huggingface/transformers). 
Originally proposed by [Eikema & Aziz (2022)](https://aclanthology.org/2022.emnlp-main.754/), this technique is a risk-minimizing algorithm for generating text with a language model. This repository implements several optimizations for MBR decoding. Most notably, **mbr** introduces reference aggregation ([Vamvas & Sennrich, 2024](https://arxiv.org/abs/2402.04251)). 6 | 7 | Pronounce: _ember_ /ˈɛm.bɚ/ 8 | 9 | ## Installation 10 | 11 | ```bash 12 | pip install mbr 13 | ``` 14 | 15 | Requirements: 16 | - Python >= 3.9 17 | - PyTorch 18 | - Hugging Face transformers < 4.39 19 | 20 | ## Usage 21 | The main components of **mbr** are: 22 | - `mbr.MBRGenerationMixin`: overrides a model's `generate` method to add MBR decoding. 23 | - `mbr.MBRConfig`: specifies the parameters of MBR decoding, e.g., the number of samples to generate and the metric to optimize. 24 | 25 | ### 1. Load a Hugging Face transformers model 26 | Models need to inherit from `MBRGenerationMixin` for MBR decoding to work. Here are two ways to achieve this, using the Llama model as an example: 27 | 28 | **Variant A:** 29 | 30 | ```python 31 | from transformers import LlamaForCausalLM 32 | 33 | from mbr import MBRGenerationMixin 34 | 35 | class MBRLlamaForCausalLM(MBRGenerationMixin, LlamaForCausalLM): 36 | pass 37 | ``` 38 | 39 | Then, you can use `MBRLlamaForCausalLM` as you would use `LlamaForCausalLM`: 40 | 41 | ```python 42 | model = MBRLlamaForCausalLM.from_pretrained(...) 43 | ``` 44 | 45 | **Variant B:** 46 | 47 | ```python 48 | from mbr import MBR 49 | model = MBR(LlamaForCausalLM).from_pretrained(...) 50 | ``` 51 | 52 | ### 2. Configure MBR decoding 53 | 54 | Create an `MBRConfig` object to pass to the model's `generate` method: 55 | 56 | ```python 57 | from mbr import MBRConfig 58 | 59 | mbr_config = MBRConfig( 60 | num_samples=10, 61 | metric="chrf", 62 | ) 63 | ``` 64 | 65 | ### 3. Generate text as usual 66 | Call the model's `generate` method directly, or use the Pipeline API. Make sure to pass the `mbr_config`, as well as the model's tokenizer. 67 | 68 | ```python 69 | from transformers import pipeline 70 | 71 | generator = pipeline("text-generation", model=model, tokenizer=tokenizer) 72 | output = generator("Hello,", mbr_config=mbr_config, tokenizer=tokenizer) 73 | ``` 74 | 75 | ## How MBR decoding works 76 | The following research papers, among many others, provide a description of Sampling-based Minimum Bayes Risk decoding: 77 | - [Sampling-Based Approximations to Minimum Bayes Risk Decoding for Neural Machine Translation](https://aclanthology.org/2022.emnlp-main.754) (Eikema & Aziz, EMNLP 2022) 78 | - [Understanding the Properties of Minimum Bayes Risk Decoding in Neural Machine Translation](https://aclanthology.org/2021.acl-long.22) (Müller & Sennrich, ACL-IJCNLP 2021) 79 | 80 | In practice, MBR decoding is most commonly implemented as follows (using machine translation as an example): 81 | 82 | - Instead of searching for the single most probable output sequence (e.g., using beam search), generate a number of samples. 83 | - Score each sample against the other samples using a metric (e.g., BLEU). 84 | - Return the sample with the highest score. Intuitively, this can be seen as returning the median of all samples.
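For illustration, here is a minimal, self-contained sketch of this selection step. It is not part of the **mbr** package: the helper name `mbr_select` is made up for this example, and sentence-level ChrF from sacreBLEU (`pip install sacrebleu`) stands in for whichever utility metric you actually configure.

```python
from sacrebleu.metrics import CHRF

def mbr_select(samples: list[str]) -> str:
    """Return the sample with the highest average utility, using all
    other samples as pseudo-references (naive O(n^2) illustration)."""
    chrf = CHRF()
    expected_utilities = []
    for i, candidate in enumerate(samples):
        utilities = [
            chrf.sentence_score(candidate, [reference]).score
            for j, reference in enumerate(samples)
            if j != i  # do not score a sample against itself
        ]
        expected_utilities.append(sum(utilities) / len(utilities))
    return samples[expected_utilities.index(max(expected_utilities))]

samples = [
    "The cat sat on the mat.",
    "A cat was sitting on the mat.",
    "The cat is sitting on the mat.",
]
print(mbr_select(samples))
```

The **mbr** package performs the same kind of selection in a batched and cached way on token IDs; see the Optimizations section below.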
85 | 86 | Illustration of MBR decoding 87 | 88 | 89 | The terminology around MBR decoding varies: 90 | 91 | | Term used in this codebase | Alternative terms | 92 | |:---------------------------|:-----------------------------------------------------| 93 | | samples | candidates, hypotheses | 94 | | references | pseudo-references, evidence | 95 | | metric score | expected utility
(negative) expected risk, error | 96 | 97 | ## Details 98 | 99 | ### Configuring the sampling 100 | The generation of the samples can be customized by passing a `generation_config` to the `generate` method or to the pipeline call: 101 | 102 | ```python 103 | from transformers import GenerationConfig 104 | 105 | generation_config = GenerationConfig.from_pretrained("mymodel", 106 | do_sample=True, 107 | num_beams=1, 108 | epsilon_cutoff=0.02, 109 | ) 110 | model.generate(..., generation_config=generation_config) 111 | ``` 112 | 113 | ### Separate set of references 114 | By default, the samples themselves are used as references (or a subset of the samples if `num_references` is smaller than `num_samples`). 115 | 116 | You could also sample the reference set independently, using a custom generation config for the references: 117 | 118 | ```python 119 | from transformers import GenerationConfig 120 | 121 | references_config = GenerationConfig.from_pretrained("mymodel", 122 | do_sample=True, 123 | num_beams=1, 124 | top_p=0.9, 125 | ) 126 | model.generate(..., references_config=references_config) 127 | ``` 128 | 129 | ### Choosing a metric 130 | By default, **mbr** uses [fastChrF](https://github.com/jvamvas/fastChrF), which is optimized for efficient comparison of many samples to many references. 131 | 132 | You can also plug in metrics from the [**Hugging Face Evaluate**](https://github.com/huggingface/evaluate) library. 133 | 134 | A full list of metrics can be found [here](https://huggingface.co/metrics). Some typical choices are: 135 | - [COMET](https://huggingface.co/spaces/evaluate-metric/comet) ([Rei et al., 2020](https://aclanthology.org/2020.emnlp-main.213/)) 136 | - [BLEURT](https://huggingface.co/spaces/evaluate-metric/bleurt) ([Sellam et al., 2020](https://aclanthology.org/2020.acl-main.704)) 137 | 138 | To use a metric from Hugging Face, either specify the metric's name (e.g., `"comet"`, `"bleurt"`) or pass an `evaluate.Metric` object directly. 139 | 140 | Since different metrics return differently structured dicts, you need to specify the `metric_output_field` that should be used as the metric score. 141 | 142 | ```python 143 | from evaluate import load 144 | 145 | metric = load('bleu') 146 | mbr_config = MBRConfig( 147 | metric=metric, 148 | metric_output_field="bleu", # the BLEU metric returns a dict with a "bleu" field 149 | ... 150 | ) 151 | ``` 152 | 153 | ### Customizing the metric computation 154 | Internally, **mbr** will call the metric's `compute` method to calculate the metric score for each sample. 155 | 156 | By default, **mbr** will call `compute` separately for each sample–reference pair. 157 | Since this requires many `compute` calls, it can make sense to optimize the metric computation. Different metrics will require different optimization strategies. 158 | To override the default way of calling the metric, define a `MetricRunner` subclass and pass an instance to the `generate` method: 159 | 160 | ```python 161 | from mbr import MetricRunner 162 | 163 | class MyMetricRunner(MetricRunner): 164 | 165 | def __call__(self, 166 | input_ids: torch.LongTensor, 167 | sample_ids: Tuple[torch.LongTensor], 168 | reference_ids: Tuple[torch.LongTensor], 169 | ) -> torch.FloatTensor: 170 | ...
# TODO: implement your efficient metric computation here 171 | 172 | model.generate(..., metric_runner=MyMetricRunner()) 173 | ``` 174 | 175 | For **COMET**, an optimized implementation is already provided in `CometMetricRunner`: 176 | 177 | ```python 178 | from mbr.metrics.comet import CometMetricRunner 179 | 180 | mbr_config = MBRConfig( 181 | ..., 182 | metric="comet", 183 | metric_output_field="mean_score", 184 | ) 185 | 186 | metric_runner = CometMetricRunner(mbr_config, tokenizer) 187 | model.generate(..., metric_runner=metric_runner) 188 | ``` 189 | 190 | ### Optimizations 191 | MBR decoding is notoriously slow. **mbr** implements some optimizations: 192 | - Cached encoder outputs: For encoder-decoder models, the encoder outputs are computed only once and reused during sampling. 193 | - Optimized ChrF metric: [fastChrF](https://github.com/jvamvas/fastChrF), a streamlined ChrF variant for MBR implemented in Rust, is used by default. 194 | - Cached metrics: Most metrics are computed only once for each unique sample–reference pair (since there will be duplicate samples and references). 195 | - Optimized COMET metric: Inspired by [Amrhein & Sennrich (2022)](https://aclanthology.org/2022.aacl-main.83/), `CometMetricRunner` caches sequence embeddings and reuses them for all pairwise comparisons. 196 | - Reference aggregation for COMET ([Vamvas & Sennrich, 2024](https://arxiv.org/abs/2402.04251)): Consider using `mbr.metrics.comet.AggregateCometMetricRunner` instead of the default `CometMetricRunner` if you have many references. 197 | 198 | ## Example scripts 199 | 200 | The [experiments](experiments) directory contains the code for reproductions of experiments from the following papers: 201 | 202 | - [MBR for (low-resource) machine translation](experiments/müller-sennrich-2021-understanding) ([Müller & Sennrich, 2021](https://aclanthology.org/2021.acl-long.22/)) 203 | - [MBR with neural metrics and epsilon sampling for machine translation](experiments/freitag-et-al-2023-epsilon) ([Freitag et al., 2023](https://arxiv.org/abs/2305.09860)) 204 | - [MBR for summarization](experiments/bertsch-et-al-2023-mbr) ([Bertsch et al., 2023](https://arxiv.org/abs/2310.01387)) 205 | 206 | ### Code for research papers 207 | - [Code for the paper "Linear-time Minimum Bayes Risk Decoding with Reference Aggregation" (Vamvas & Sennrich, 2024)](experiments/reference_aggregation) 208 | 209 | ## Related projects 210 | - https://github.com/roxot/mbr-nmt: Original implementation ([demo](https://colab.research.google.com/github/probabll/demo-mbr-nmt/blob/main/German-English.ipynb)) 211 | - https://github.com/ZurichNLP/understanding-mbr: MBR with Sockeye 212 | - https://github.com/ZurichNLP/mbr-sensitivity and https://github.com/Unbabel/COMET#minimum-bayes-risk-decoding: COMET metric for MBR 213 | - https://github.com/rainavyas/mbr_gec: MBR for Grammatical Error Correction 214 | 215 | ## Changelog 216 | 217 | - v0.3.0 (draft) 218 | - New feature: Reference Aggregation ([Vamvas & Sennrich, 2024](https://arxiv.org/abs/2402.04251)): 219 | - Set [fastChrF](https://github.com/jvamvas/fastChrF) with reference aggregation as the default metric 220 | - Add `AggregateCometMetricRunner` to allow for reference aggregation with COMET 221 | - **Bugfix**: Disable dropout for COMET metric 222 | 223 | - v0.2.0 224 | - **Breaking change:** Rename `MBRGenerationConfig` to `MBRConfig` 225 | - **Breaking change:** `MetricRunner` now returns a `MetricOutput` dict instead of the raw tensor of scores.
226 | - Make the size of the metric cache configurable via `MBRConfig.metric_cache_size` 227 | - Allow that the number of references can be larger than the number of samples (if generated separately from the samples). 228 | - Remove `GenerationConfig` as parent class of `MBRConfig` 229 | 230 | ## Citation 231 | When using this code for research, please cite the following paper: 232 | 233 | ```bibtex 234 | @misc{vamvas-sennrich-2024-linear, 235 | title={Linear-time Minimum Bayes Risk Decoding with Reference Aggregation}, 236 | author={Jannis Vamvas and Rico Sennrich}, 237 | year={2024}, 238 | eprint={2402.04251}, 239 | archivePrefix={arXiv}, 240 | primaryClass={cs.CL} 241 | } 242 | ``` 243 | -------------------------------------------------------------------------------- /experiments/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Experiments using the [**mbr**](https://github.com/ZurichNLP/mbr) package 3 | 4 | **Code for research papers:** 5 | - Linear-time Minimum Bayes Risk Decoding with Reference Aggregation (Vamvas & Sennrich, 2024) 6 | 7 | **Reproductions of experiments from related research papers:** 8 | 9 | - It's MBR All the Way Down: Modern Generation Techniques Through the Lens of Minimum Bayes Risk (Bertsch et al., 2023) 10 | - Epsilon Sampling Rocks: Investigating Sampling Strategies for Minimum Bayes Risk Decoding for Machine Translation (Freitag et al., 2023) 11 | - Understanding the Properties of Minimum Bayes Risk Decoding in Neural Machine Translation (Müller & Sennrich, ACL-IJCNLP 2021) 12 | 13 | **Other experiments** 14 | - Comparison of [fastChrF](https://github.com/jvamvas/fastChrF) to standard sentence-level ChrF ([Popović, 2015](https://aclanthology.org/W15-3049/)) as a metric for MBR. 15 | -------------------------------------------------------------------------------- /experiments/bertsch-et-al-2023-mbr/README.md: -------------------------------------------------------------------------------- 1 | This directory uses the [**mbr**](https://github.com/ZurichNLP/mbr) package to reproduce an experiment from the paper [It's MBR All the Way Down: Modern Generation Techniques Through the Lens of Minimum Bayes Risk](https://arxiv.org/abs/2310.01387) (Bertsch et al., 2023). 
2 | 3 | ## Setup 4 | * Task: Summarization 5 | * Language: English 6 | * Model: facebook/bart-large-cnn ([Lewis et al., 2020](https://aclanthology.org/2020.acl-main.703/)) 7 | * MBR metric: ROUGE-1 ([Lin, 2004](https://aclanthology.org/W04-1013/)) 8 | * Number of samples: 30 9 | * Number of references: 30 10 | * Sampling approach: sampling with temperature 0.5 11 | * Reference sampling approach: sampling with temperature 1.0 12 | * Test set: CNN/DailyMail ([Nallapati et al., 2016](https://aclanthology.org/K16-1028/)) 13 | * Evaluation metric: ROUGE-1 14 | * Baselines: greedy decoding, beam search 15 | 16 | ## Results 17 | 18 | | Paper | Reproduction | 19 | |:---------------------------------------------------------------:|:---:| 20 | | ![CNN DM original](results/figures/CNN%20DM%20original.png) | ![CNN DM reproduction](results/figures/CNN%20DM%20reproduction.png) | 21 | -------------------------------------------------------------------------------- /experiments/bertsch-et-al-2023-mbr/results/figures/CNN DM original.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/bertsch-et-al-2023-mbr/results/figures/CNN DM original.png -------------------------------------------------------------------------------- /experiments/bertsch-et-al-2023-mbr/results/figures/CNN DM reproduction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/bertsch-et-al-2023-mbr/results/figures/CNN DM reproduction.png -------------------------------------------------------------------------------- /experiments/bertsch-et-al-2023-mbr/run_experiment.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from pathlib import Path 3 | 4 | import evaluate 5 | import jsonlines 6 | import torch 7 | from datasets import load_dataset 8 | from tqdm import tqdm 9 | from transformers import BartForConditionalGeneration, AutoTokenizer, pipeline, GenerationConfig 10 | 11 | from mbr import MBR, MBRConfig 12 | 13 | results_file = jsonlines.open(Path(__file__).parent / f"results.jsonl", "w") 14 | 15 | model_name = "facebook/bart-large-cnn" 16 | model = MBR(BartForConditionalGeneration).from_pretrained(model_name) 17 | tokenizer = AutoTokenizer.from_pretrained(model_name) 18 | summarization_pipeline = pipeline( 19 | "summarization", 20 | model=model, 21 | tokenizer=tokenizer, 22 | device=(0 if torch.cuda.is_available() else -1), 23 | ) 24 | evaluation_metric_rouge = evaluate.load("rouge") 25 | 26 | dataset = load_dataset("cnn_dailymail", "3.0.0", split="test") 27 | 28 | # MBR 29 | mbr_config = MBRConfig() 30 | mbr_config.num_samples = 30 31 | mbr_config.num_references = 30 32 | mbr_config.metric = "rouge" 33 | mbr_config.metric_output_field = "rouge1" 34 | # efficiency settings 35 | mbr_config.metric_kwargs = {"rouge_types": ("rouge1",), "use_aggregator": False} 36 | 37 | generation_config = GenerationConfig.from_pretrained(model_name) 38 | generation_config.do_sample = True 39 | generation_config.num_beams = 1 40 | generation_config.temperature = 0.5 41 | generation_config.early_stopping = False 42 | generation_config.length_penalty = 1.0 43 | 44 | references_config = GenerationConfig.from_pretrained(model_name) 45 | references_config.do_sample = True 46 | references_config.num_beams = 1 47 | 
references_config.temperature = 1.0 48 | references_config.early_stopping = False 49 | references_config.length_penalty = 1.0 50 | 51 | summaries = [] 52 | outputs = summarization_pipeline( 53 | dataset["article"], 54 | mbr_config=mbr_config, 55 | generation_config=generation_config, 56 | references_config=references_config, 57 | tokenizer=tokenizer, 58 | truncation=True, 59 | progress_bar=True, 60 | batch_size=32, 61 | ) 62 | for output in outputs: 63 | summaries.append(output["summary_text"]) 64 | rouge_score = evaluation_metric_rouge.compute(predictions=summaries, references=dataset["highlights"]) 65 | results_file.write({ 66 | "method": "mbr rouge-1", 67 | "rouge": rouge_score, 68 | "summaries": summaries, 69 | }) 70 | 71 | # Baselines 72 | model = BartForConditionalGeneration.from_pretrained(model_name).to(summarization_pipeline.device) 73 | summarization_pipeline.model = model 74 | base_generation_config = GenerationConfig.from_pretrained(model_name) 75 | generation_configs = {} 76 | 77 | # greedy 78 | generation_config = deepcopy(base_generation_config) 79 | generation_config.do_sample = False 80 | generation_config.num_beams = 1 81 | generation_config.early_stopping = False 82 | generation_config.length_penalty = 1.0 83 | generation_configs["greedy"] = generation_config 84 | 85 | # beam search k=5 86 | generation_config = deepcopy(base_generation_config) 87 | generation_config.do_sample = False 88 | generation_config.num_beams = 5 89 | generation_configs["beam search k=5"] = generation_config 90 | 91 | # beam search k=10 92 | generation_config = deepcopy(base_generation_config) 93 | generation_config.do_sample = False 94 | generation_config.num_beams = 10 95 | generation_configs["beam search k=10"] = generation_config 96 | 97 | for method, generation_config in generation_configs.items(): 98 | print(method, flush=True) 99 | summaries = [] 100 | outputs = summarization_pipeline( 101 | dataset["article"], 102 | generation_config=generation_config, 103 | truncation=True, 104 | batch_size=1, 105 | ) 106 | for output in tqdm(outputs): 107 | summaries.append(output["summary_text"]) 108 | rouge_score = evaluation_metric_rouge.compute(predictions=summaries, references=dataset["highlights"]) 109 | results_file.write({ 110 | "method": method, 111 | "rouge": rouge_score, 112 | "summaries": summaries, 113 | }) 114 | 115 | results_file.close() 116 | -------------------------------------------------------------------------------- /experiments/chrf-vs-fastchrf/README.md: -------------------------------------------------------------------------------- 1 | Comparison of [fastChrF](https://github.com/jvamvas/fastChrF) to standard sentence-level ChrF ([Popović, 2015](https://aclanthology.org/W15-3049/)) as a metric for MBR. 2 | 3 | ## Setup 4 | * Task: Machine translation 5 | * Translation directions: en–de, de–en, en–ru, ru–en 6 | * Model: [facebook/wmt19-*](https://huggingface.co/facebook/wmt19-en-de) ([Ng et al., 2019](https://aclanthology.org/W19-5333/)). 
7 | * MBR metrics: `fastchrf.pairwise_chrf` (a fast implementation of standard ChrF) and `fastchrf.aggregate_chrf` (a streamlined ChrF variant for MBR) 8 | * Number of samples: 256 9 | * Sampling approach: epsilon sampling with ε=0.02 10 | * Samples and references are the same 11 | * Test set: newstest2019 12 | * Evaluation metrics: chrF ([sacreBLEU](https://github.com/mjpost/sacrebleu)) and COMET-22 ([Rei et al., 2022](https://aclanthology.org/2022.wmt-1.52/)) 13 | * Baseline: beam search with beam size 4 14 | 15 | ## Results 16 | | Language Pair | Method | ChrF | COMET | duration (s) | 17 | |---------------|--------------------------------------|---------:|----------:|-------------:| 18 | | en-de | MBR with `fastchrf.pairwise_chrf` | 67.7 | 0.867 | 7798 | 19 | | en-de | MBR with `fastchrf.aggregate_chrf` | 67.7 | 0.867 | 7480 | 20 | | en-de | Beam search | 67.7 | 0.868 | 62 | 21 | | de-en | MBR with `fastchrf.pairwise_chrf` | 65.4 | 0.851 | 6894 | 22 | | de-en | MBR with `fastchrf.aggregate_chrf` | 65.6 | 0.850 | 6849 | 23 | | de-en | Beam search | 65.1 | 0.851 | 53 | 24 | | en-ru | MBR with `fastchrf.pairwise_chrf` | 57.5 | 0.862 | 7802 | 25 | | en-ru | MBR with `fastchrf.aggregate_chrf` | 57.5 | 0.862 | 7465 | 26 | | en-ru | Beam search | 56.9 | 0.863 | 64 | 27 | | ru-en | MBR with `fastchrf.pairwise_chrf` | 64.2 | 0.847 | 7541 | 28 | | ru-en | MBR with `fastchrf.aggregate_chrf` | 64.3 | 0.848 | 6689 | 29 | | ru-en | Beam search | 63.5 | 0.847 | 61 | 30 | | **Average** | **MBR with `fastchrf.pairwise_chrf`** | **63.7** | **0.857** | **7509** | 31 | | **Average** | **MBR with `fastchrf.aggregate_chrf`** | **63.7** | **0.857** | **7121** | 32 | | **Average** | **Beam search** | **63.3** | **0.857** | **60** | -------------------------------------------------------------------------------- /experiments/chrf-vs-fastchrf/run_experiment.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | from copy import deepcopy 4 | from pathlib import Path 5 | 6 | import evaluate 7 | import jsonlines 8 | import sacrebleu 9 | import torch 10 | from datasets import load_dataset 11 | from tqdm import tqdm 12 | from transformers import FSMTForConditionalGeneration, AutoTokenizer, pipeline, set_seed, GenerationConfig 13 | 14 | from mbr import MBR, MBRConfig 15 | 16 | language_pair = sys.argv[1] 17 | assert language_pair in ["de-en", "en-de", "en-ru", "ru-en"] 18 | 19 | batch_size = 32 20 | 21 | results_file = jsonlines.open(Path(__file__).parent / f"results_{language_pair}.jsonl", "w") 22 | 23 | model_name = f"facebook/wmt19-{language_pair}" 24 | model = MBR(FSMTForConditionalGeneration).from_pretrained(model_name) 25 | tokenizer = AutoTokenizer.from_pretrained(model_name) 26 | mt_pipeline = pipeline( 27 | "translation_" + language_pair.split("-")[0] + "_to_" + language_pair.split("-")[1], 28 | model=model, 29 | tokenizer=tokenizer, 30 | device=(0 if torch.cuda.is_available() else -1), 31 | ) 32 | evaluation_metric_chrf = evaluate.load("chrf") 33 | evaluation_metric_comet = evaluate.load("comet", "Unbabel/wmt22-comet-da") 34 | 35 | src_path = sacrebleu.get_source_file("wmt19", language_pair) 36 | ref_path = sacrebleu.get_reference_files("wmt19", language_pair)[0] 37 | dataset = load_dataset("text", data_files={"test": src_path}) 38 | references = Path(ref_path).read_text().splitlines() 39 | assert len(dataset["test"]) == len(references) 40 | 41 | # MBR 42 | generation_config = GenerationConfig.from_pretrained(model_name) 43 | 
generation_config.do_sample = True 44 | generation_config.num_beams = 1 45 | generation_config.early_stopping = False 46 | generation_config.epsilon_cutoff = 0.02 47 | 48 | base_mbr_config = MBRConfig( 49 | num_samples=256, 50 | num_references=256, 51 | ) 52 | base_mbr_config.metric_cache_size = batch_size * base_mbr_config.num_samples * base_mbr_config.num_references 53 | mbr_configs = {} 54 | 55 | # MBR with fastchrf.pairwise_chrf 56 | mbr_config = deepcopy(base_mbr_config) 57 | mbr_config.metric = "fastchrf-pairwise" 58 | mbr_configs["MBR with fastchrf.pairwise_chrf"] = mbr_config 59 | 60 | # MBR with fastchrf.aggregate_chrf 61 | mbr_config = deepcopy(base_mbr_config) 62 | mbr_config.metric = "fastchrf-aggregate" 63 | mbr_configs["MBR with fastchrf.aggregate_chrf"] = mbr_config 64 | 65 | for method, mbr_config in mbr_configs.items(): 66 | 67 | set_seed(42) 68 | time_start = time.time() 69 | outputs = mt_pipeline( 70 | dataset["test"]["text"], 71 | mbr_config=mbr_config, 72 | generation_config=generation_config, 73 | tokenizer=tokenizer, 74 | batch_size=batch_size, 75 | progress_bar=True 76 | ) 77 | translations = [] 78 | for batch in tqdm(outputs): 79 | if isinstance(batch, dict): 80 | batch = [batch] 81 | translations += [translation["translation_text"] for translation in batch] 82 | time_end = time.time() 83 | 84 | chrf_score = evaluation_metric_chrf.compute( 85 | predictions=translations, 86 | references=references, 87 | ) 88 | comet_score = evaluation_metric_comet.compute( 89 | predictions=translations, 90 | references=references, 91 | sources=dataset["test"]["text"], 92 | gpus=0, 93 | ) 94 | results_file.write({ 95 | "language_pair": language_pair, 96 | "method": method, 97 | "chrf": chrf_score["score"], 98 | "comet22": comet_score["mean_score"], 99 | "duration": time_end - time_start, 100 | "translations": translations, 101 | }) 102 | 103 | # Beam search 104 | model = FSMTForConditionalGeneration.from_pretrained(model_name).half().to(mt_pipeline.device) 105 | mt_pipeline.model = model 106 | generation_config = GenerationConfig.from_pretrained(model_name) 107 | generation_config.num_beams = 4 108 | 109 | set_seed(42) 110 | time_start = time.time() 111 | outputs = mt_pipeline( 112 | dataset["test"]["text"], 113 | generation_config=generation_config, 114 | batch_size=batch_size, 115 | ) 116 | translations = [] 117 | for batch in tqdm(outputs): 118 | if isinstance(batch, dict): 119 | batch = [batch] 120 | translations += [translation["translation_text"] for translation in batch] 121 | time_end = time.time() 122 | 123 | chrf_score = evaluation_metric_chrf.compute( 124 | predictions=translations, 125 | references=references, 126 | ) 127 | comet_score = evaluation_metric_comet.compute( 128 | predictions=translations, 129 | references=references, 130 | sources=dataset["test"]["text"], 131 | gpus=0, 132 | ) 133 | results_file.write({ 134 | "language_pair": language_pair, 135 | "method": f"beam search (beam size {generation_config.num_beams})", 136 | "chrf": chrf_score["score"], 137 | "comet22": comet_score["mean_score"], 138 | "duration": time_end - time_start, 139 | "translations": translations, 140 | }) 141 | 142 | results_file.close() 143 | -------------------------------------------------------------------------------- /experiments/freitag-et-al-2023-epsilon/README.md: -------------------------------------------------------------------------------- 1 | This directory uses the [**mbr**](https://github.com/ZurichNLP/mbr) package to reproduce an experiment from the paper [Epsilon Sampling 
Rocks: Investigating Sampling Strategies for Minimum Bayes Risk Decoding for Machine Translation](https://arxiv.org/abs/2305.09860) (Freitag et al., 2023). 2 | 3 | ## Setup 4 | * Task: Machine translation 5 | * Translation directions: en–de, de–en 6 | * MBR metric: Neural metric (paper: BLEURT, this reproduction: Cometinho) 7 | * Number of samples: 1024 8 | * Various sampling approaches 9 | * Samples and references are the same 10 | * Test set: newstest2021 11 | * Evaluation metric: COMET ([Rei et al., 2020](https://aclanthology.org/2020.emnlp-main.213/)) 12 | * Baseline: beam search with beam size 4 13 | 14 | ## Differences to the paper 15 | * The paper used custom models trained without label smoothing, this reproduction uses an open-source model ([Ng et al., WMT 2019](https://aclanthology.org/W19-5333/)). 16 | * The paper used BLEURT ([Sellam et al., 2020](https://aclanthology.org/2020.acl-main.704/)) as a metric, this reproduction uses Cometinho ([Rei et al., 2022](https://aclanthology.org/2022.eamt-1.9/)). 17 | 18 | ## Results 19 | 20 | Comparison between ancestral sampling and epsilon sampling: 21 | 22 | | Paper | Reproduction | 23 | |:---------------------------------------------------------------:|:---:| 24 | | ![Main Comparison EN–DE (original)](results/figures/Main%20Comparison%20EN–DE%20(original).png) | ![Main Comparison EN–DE (reproduction)](results/figures/Main%20Comparison%20EN–DE%20(reproduction).png) | 25 | | ![Main Comparison DE–EN (original)](results/figures/Main%20Comparison%20DE–EN%20(original).png) | ![Main Comparison DE–EN (reproduction)](results/figures/Main%20Comparison%20DE–EN%20(reproduction).png) | 26 | 27 | Comparison between beam search and various sampling approaches: 28 | 29 | | Paper | Reproduction | 30 | |:---------------------------------------------------------------:|:---:| 31 | | ![All Results EN–DE (original)](results/figures/All%20Results%20EN–DE%20(original).png) | ![All Results EN–DE (reproduction)](results/figures/All%20Results%20EN–DE%20(reproduction).png) | 32 | | ![All Results DE–EN (original)](results/figures/All%20Results%20DE–EN%20(original).png) | ![All Results DE–EN (reproduction)](results/figures/All%20Results%20DE–EN%20(reproduction).png) | 33 | 34 | Although the models used in this reproduction seem to be less suitable for sampling, especially at higher temperatures, the main comparison between ancestral sampling and epsilon sampling has the same trend as in the paper. 
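For completeness, a typical invocation of this reproduction (inferred from `run_experiment.py` in this directory, which reads the language pair from its first command-line argument; only `en-de` and `de-en` have models configured) might look like this:

```bash
# run from experiments/freitag-et-al-2023-epsilon/
python run_experiment.py en-de
python run_experiment.py de-en
```

As currently written, the script saves a `results_{language_pair}3.jsonl` file next to itself containing the COMET score and translations for each decoding method.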
35 | -------------------------------------------------------------------------------- /experiments/freitag-et-al-2023-epsilon/results/figures/All Results DE–EN (original).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/freitag-et-al-2023-epsilon/results/figures/All Results DE–EN (original).png -------------------------------------------------------------------------------- /experiments/freitag-et-al-2023-epsilon/results/figures/All Results DE–EN (reproduction).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/freitag-et-al-2023-epsilon/results/figures/All Results DE–EN (reproduction).png -------------------------------------------------------------------------------- /experiments/freitag-et-al-2023-epsilon/results/figures/All Results EN–DE (original).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/freitag-et-al-2023-epsilon/results/figures/All Results EN–DE (original).png -------------------------------------------------------------------------------- /experiments/freitag-et-al-2023-epsilon/results/figures/All Results EN–DE (reproduction).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/freitag-et-al-2023-epsilon/results/figures/All Results EN–DE (reproduction).png -------------------------------------------------------------------------------- /experiments/freitag-et-al-2023-epsilon/results/figures/Main Comparison DE–EN (original).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/freitag-et-al-2023-epsilon/results/figures/Main Comparison DE–EN (original).png -------------------------------------------------------------------------------- /experiments/freitag-et-al-2023-epsilon/results/figures/Main Comparison DE–EN (reproduction).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/freitag-et-al-2023-epsilon/results/figures/Main Comparison DE–EN (reproduction).png -------------------------------------------------------------------------------- /experiments/freitag-et-al-2023-epsilon/results/figures/Main Comparison EN–DE (original).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/freitag-et-al-2023-epsilon/results/figures/Main Comparison EN–DE (original).png -------------------------------------------------------------------------------- /experiments/freitag-et-al-2023-epsilon/results/figures/Main Comparison EN–DE (reproduction).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/freitag-et-al-2023-epsilon/results/figures/Main Comparison EN–DE (reproduction).png 
-------------------------------------------------------------------------------- /experiments/freitag-et-al-2023-epsilon/run_experiment.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from copy import deepcopy 3 | from pathlib import Path 4 | 5 | import evaluate 6 | import jsonlines 7 | import sacrebleu 8 | import torch 9 | from datasets import load_dataset 10 | from tqdm import tqdm 11 | from transformers import FSMTForConditionalGeneration, AutoTokenizer, pipeline, set_seed, GenerationConfig 12 | 13 | from mbr import MBR, MBRConfig 14 | from mbr.metrics.comet import CometMetricRunner 15 | 16 | set_seed(42) 17 | 18 | language_pair = sys.argv[1] 19 | 20 | models = { 21 | "en-de": "facebook/wmt19-en-de", 22 | "en-zh": None, 23 | "de-en": "facebook/wmt19-de-en", 24 | "zh-en": None, 25 | } 26 | 27 | testsets = { 28 | "en-de": "wmt21/C", 29 | "en-zh": "wmt21/B", 30 | "de-en": "wmt21", 31 | "zh-en": "wmt21", 32 | } 33 | 34 | results_file = jsonlines.open(Path(__file__).parent / f"results_{language_pair}3.jsonl", "w") 35 | 36 | model_name = models[language_pair] 37 | model = MBR(FSMTForConditionalGeneration).from_pretrained(model_name) 38 | tokenizer = AutoTokenizer.from_pretrained(model_name) 39 | mt_pipeline = pipeline( 40 | "translation_" + language_pair.split("-")[0] + "_to_" + language_pair.split("-")[1], 41 | model=model, 42 | tokenizer=tokenizer, 43 | device=(0 if torch.cuda.is_available() else -1), 44 | ) 45 | evaluation_metric_comet = evaluate.load("comet", "wmt20-comet-da") 46 | 47 | src_path = sacrebleu.get_source_file(testsets[language_pair], language_pair) 48 | ref_path = sacrebleu.get_reference_files(testsets[language_pair], language_pair)[0] 49 | dataset = load_dataset("text", data_files={"test": src_path}) 50 | references = Path(ref_path).read_text().splitlines() 51 | assert len(dataset["test"]) == len(references) 52 | 53 | # MBR 54 | mbr_config = MBRConfig() 55 | mbr_config.num_samples = 1024 56 | mbr_config.metric = "comet" 57 | mbr_config.metric_config_name = "eamt22-cometinho-da" 58 | mbr_config.metric_output_field = "mean_score" 59 | 60 | metric_runner = CometMetricRunner( 61 | mbr_config, 62 | tokenizer, 63 | device=mt_pipeline.device, 64 | batch_size_embed=64, 65 | batch_size_estimate=64, 66 | progress_bar=True, 67 | ) 68 | 69 | base_generation_config = GenerationConfig.from_pretrained(model_name) 70 | base_generation_config.do_sample = True 71 | base_generation_config.num_beams = 1 72 | base_generation_config.early_stopping = False 73 | generation_configs = {} 74 | 75 | # # MBR – Ancestral (τ=1.0) 76 | # generation_config = deepcopy(base_generation_config) 77 | # generation_config.temperature = 1.0 78 | # generation_configs["mbr ancestral (τ=1.0)"] = generation_config 79 | # 80 | # # MBR – Top-k (k=10, τ=1.0) 81 | # generation_config = deepcopy(base_generation_config) 82 | # generation_config.top_k = 10 83 | # generation_config.temperature = 1.0 84 | # generation_configs["mbr top-k (k=10, τ=1.0)"] = generation_config 85 | 86 | # # MBR – Top-k (k=50, τ=1.0) 87 | # generation_config = deepcopy(base_generation_config) 88 | # generation_config.top_k = 50 89 | # generation_config.temperature = 1.0 90 | # generation_configs["mbr top-k (k=50, τ=1.0)"] = generation_config 91 | # 92 | # # MBR – Nucleus (p=0.9, τ=1.5) 93 | # generation_config = deepcopy(base_generation_config) 94 | # generation_config.top_p = 0.9 95 | # generation_config.temperature = 1.5 96 | # generation_configs["mbr nucleus (p=0.9, τ=1.5)"] = 
generation_config 97 | # 98 | # MBR – Epsilon (ε=0.02, τ=1.0) 99 | generation_config = deepcopy(base_generation_config) 100 | generation_config.epsilon_cutoff = 0.02 101 | generation_config.temperature = 1.0 102 | generation_configs["mbr epsilon (ε=0.02, τ=1.0)"] = generation_config 103 | 104 | # MBR – Epsilon (ε=0.02, τ=2.0) 105 | generation_config = deepcopy(base_generation_config) 106 | generation_config.epsilon_cutoff = 0.02 107 | generation_config.temperature = 2.0 108 | generation_configs["mbr epsilon (ε=0.02, τ=2.0)"] = generation_config 109 | 110 | for method, generation_config in generation_configs.items(): 111 | outputs = mt_pipeline( 112 | dataset["test"]["text"], 113 | mbr_config=mbr_config, 114 | generation_config=generation_config, 115 | tokenizer=tokenizer, 116 | metric_runner=metric_runner, 117 | batch_size=32, 118 | progress_bar=True 119 | ) 120 | translations = [] 121 | for batch in tqdm(outputs): 122 | if isinstance(batch, dict): 123 | batch = [batch] 124 | translations += [translation["translation_text"] for translation in batch] 125 | comet_score = evaluation_metric_comet.compute( 126 | predictions=translations, 127 | references=references, 128 | sources=dataset["test"]["text"], 129 | ) 130 | results_file.write({ 131 | "language_pair": language_pair, 132 | "method": method, 133 | "comet20": comet_score["mean_score"], 134 | "translations": translations, 135 | }) 136 | 137 | # Beam search 138 | model = FSMTForConditionalGeneration.from_pretrained(model_name).half().to(mt_pipeline.device) 139 | mt_pipeline.model = model 140 | generation_config = GenerationConfig.from_pretrained(model_name) 141 | generation_config.num_beams = 4 142 | 143 | outputs = mt_pipeline( 144 | dataset["test"]["text"], 145 | generation_config=generation_config, 146 | batch_size=32, 147 | ) 148 | translations = [] 149 | for batch in tqdm(outputs): 150 | if isinstance(batch, dict): 151 | batch = [batch] 152 | translations += [translation["translation_text"] for translation in batch] 153 | comet_score = evaluation_metric_comet.compute( 154 | predictions=translations, 155 | references=references, 156 | sources=dataset["test"]["text"], 157 | ) 158 | results_file.write({ 159 | "language_pair": language_pair, 160 | "method": "beam search", 161 | "comet20": comet_score["mean_score"], 162 | "translations": translations, 163 | }) 164 | 165 | results_file.close() 166 | -------------------------------------------------------------------------------- /experiments/müller-sennrich-2021-understanding/README.md: -------------------------------------------------------------------------------- 1 | This directory uses the [**mbr**](https://github.com/ZurichNLP/mbr) package to reproduce an experiment from the paper [Understanding the Properties of Minimum Bayes Risk Decoding in Neural Machine Translation](https://aclanthology.org/2021.acl-long.22) (Müller & Sennrich, ACL-IJCNLP 2021). 
2 | 3 | ## Setup 4 | * Task: Machine translation 5 | * Translation directions: dan–epo, aze–eng, bel–rus, deu–fra 6 | * MBR metric: ChrF2 ([Popović, 2015](https://aclanthology.org/W15-3049/)) 7 | * Number of samples: 5–100 8 | * Sampling approach: ancestral sampling 9 | * Samples and references are the same 10 | * Test set: Tatoeba ([Tiedemann, 2020](https://aclanthology.org/2020.wmt-1.139/)) 11 | * Evaluation metric: ChrF2 12 | * Baseline: beam search with beam size 5 13 | 14 | ## Differences to the paper 15 | * The paper used custom models trained without label smoothing, this reproduction uses open-source models from Opus-MT ([Tiedemann & Thottingal, 2020](https://aclanthology.org/2020.eamt-1.61)). 16 | * The paper reports averages over 2 runs, this reproduction uses a single run. 17 | 18 | ## Results 19 | 20 | | Paper | Reproduction | 21 | |:---------------------------------------------------------------:|:---:| 22 | | ![AZE–ENG (original)](results/figures/AZE–ENG%20(original).png) | ![AZE–ENG (reproduction)](results/figures/AZE–ENG%20(reproduction).png) | 23 | | ![BEL–RUS (original)](results/figures/BEL–RUS%20(original).png) | ![BEL–RUS (reproduction)](results/figures/BEL–RUS%20(reproduction).png) | 24 | | ![DAN–EPO (original)](results/figures/DAN–EPO%20(original).png) | ![DAN–EPO (reproduction)](results/figures/DAN–EPO%20(reproduction).png) | 25 | | ![DEU–FRA (original)](results/figures/DEU–FRA%20(original).png) | ![DEU–FRA (reproduction)](results/figures/DEU–FRA%20(reproduction).png) | -------------------------------------------------------------------------------- /experiments/müller-sennrich-2021-understanding/results/figures/AZE–ENG (original).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/müller-sennrich-2021-understanding/results/figures/AZE–ENG (original).png -------------------------------------------------------------------------------- /experiments/müller-sennrich-2021-understanding/results/figures/AZE–ENG (reproduction).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/müller-sennrich-2021-understanding/results/figures/AZE–ENG (reproduction).png -------------------------------------------------------------------------------- /experiments/müller-sennrich-2021-understanding/results/figures/BEL–RUS (original).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/müller-sennrich-2021-understanding/results/figures/BEL–RUS (original).png -------------------------------------------------------------------------------- /experiments/müller-sennrich-2021-understanding/results/figures/BEL–RUS (reproduction).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/müller-sennrich-2021-understanding/results/figures/BEL–RUS (reproduction).png -------------------------------------------------------------------------------- /experiments/müller-sennrich-2021-understanding/results/figures/DAN–EPO (original).png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/müller-sennrich-2021-understanding/results/figures/DAN–EPO (original).png -------------------------------------------------------------------------------- /experiments/müller-sennrich-2021-understanding/results/figures/DAN–EPO (reproduction).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/müller-sennrich-2021-understanding/results/figures/DAN–EPO (reproduction).png -------------------------------------------------------------------------------- /experiments/müller-sennrich-2021-understanding/results/figures/DEU–FRA (original).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/müller-sennrich-2021-understanding/results/figures/DEU–FRA (original).png -------------------------------------------------------------------------------- /experiments/müller-sennrich-2021-understanding/results/figures/DEU–FRA (reproduction).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/müller-sennrich-2021-understanding/results/figures/DEU–FRA (reproduction).png -------------------------------------------------------------------------------- /experiments/müller-sennrich-2021-understanding/run_experiment.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | import jsonlines 5 | import sacrebleu 6 | import torch 7 | from datasets import load_dataset 8 | from tqdm import tqdm 9 | from transformers import MarianMTModel, AutoTokenizer, pipeline, set_seed 10 | from transformers.pipelines.base import KeyDataset 11 | 12 | from mbr import MBR, MBRConfig 13 | 14 | set_seed(42) 15 | 16 | language_pair = sys.argv[1] 17 | 18 | opus_mt_models = { 19 | "dan-epo": "Helsinki-NLP/opus-mt-da-eo", 20 | "aze-eng": "Helsinki-NLP/opus-mt-az-en", 21 | "bel-rus": "Helsinki-NLP/opus-mt-tc-big-zle-zle", 22 | "deu-fra": "Helsinki-NLP/opus-mt-de-fr", 23 | } 24 | 25 | language_codes = { 26 | "dan": "da", 27 | "epo": "eo", 28 | "aze": "az", 29 | "eng": "en", 30 | "bel": "be", 31 | "rus": "ru", 32 | "deu": "de", 33 | "fra": "fr", 34 | } 35 | 36 | results_file = jsonlines.open(Path(__file__).parent / f"results_{language_pair}.jsonl", "w") 37 | 38 | model_name = opus_mt_models[language_pair] 39 | model = MBR(MarianMTModel).from_pretrained(model_name) 40 | model = model.half() 41 | tokenizer = AutoTokenizer.from_pretrained(model_name) 42 | src_code = language_codes[language_pair.split("-")[0]] 43 | tgt_code = language_codes[language_pair.split("-")[1]] 44 | mt_pipeline = pipeline( 45 | "translation_" + src_code + "_to_" + tgt_code, 46 | model=model, 47 | tokenizer=tokenizer, 48 | device=0 if torch.cuda.is_available() else -1, 49 | ) 50 | 51 | dataset = load_dataset("Helsinki-NLP/tatoeba_mt", language_pair=language_pair) 52 | references = dataset["test"]["targetString"] 53 | 54 | # MBR 55 | mbr_config = MBRConfig() 56 | mbr_config.metric = "chrf" 57 | mbr_config.metric_output_field = "score" 58 | batch_size = 64 if "-big-" in model_name else 256 59 | 60 | for num_samples in range(5, 101, 5): 61 | mbr_config.num_samples = num_samples 62 | mbr_config.num_references 
= num_samples 63 | 64 | outputs = mt_pipeline( 65 | KeyDataset(dataset["test"], "sourceString"), 66 | mbr_config=mbr_config, 67 | tokenizer=tokenizer, 68 | do_sample=True, 69 | num_beams=1, 70 | batch_size=batch_size, 71 | ) 72 | translations = [] 73 | for batch in tqdm(outputs, total=len(dataset["test"]) // batch_size): 74 | translations += [translation["translation_text"] for translation in batch] 75 | chrf_score = sacrebleu.corpus_chrf(translations, [references]) 76 | results_file.write({ 77 | "language_pair": language_pair, 78 | "method": "mbr", 79 | "num_samples": num_samples, 80 | "chrf": chrf_score.score, 81 | "translations": translations, 82 | }) 83 | 84 | # Beam search 85 | model = MarianMTModel.from_pretrained(model_name).to(mt_pipeline.device) 86 | mt_pipeline.model = model 87 | 88 | outputs = mt_pipeline( 89 | KeyDataset(dataset["test"], "sourceString"), 90 | num_beams=5, 91 | batch_size=32, 92 | ) 93 | translations = [] 94 | for batch in tqdm(outputs): 95 | translations += [translation["translation_text"] for translation in batch] 96 | chrf_score = sacrebleu.corpus_chrf(translations, [references]) 97 | results_file.write({ 98 | "language_pair": language_pair, 99 | "method": "beam_search", 100 | "chrf": chrf_score.score, 101 | "translations": translations, 102 | }) 103 | 104 | results_file.close() 105 | -------------------------------------------------------------------------------- /experiments/reference_aggregation/README.md: -------------------------------------------------------------------------------- 1 | ## Code for the paper ["Linear-time Minimum Bayes Risk Decoding with Reference Aggregation"](https://arxiv.org/abs/2402.04251) 2 | 3 | - The research code in this directory implements reference aggregation, an efficiency method for MBR that uses aggregate reference representations for faster utility estimation. 4 | - We apply reference aggregation to two metrics: ChrF and COMET. 5 | - Unlike the **mbr** package, the code in this directory is purely research-oriented (= reproducing the tables and figures in our paper) and not optimized for usability. 6 | 7 | ## Installation 8 | - Requires Python >= 3.9 and PyTorch. 9 | - `pip install -r requirements.txt` 10 | 11 | ## Reproducing the experiments 12 | 13 | ### Creating the samples 14 | - Warning: The following code downloads a large translation model from PyTorch Hub (if not already present) and generates 1024 samples per segment, which will take some time. 15 | - Samples will be stored in a JSON lines file in the directory `samples/`. 16 | ```bash 17 | python generate_samples.py --testset wmt21 --language-pair en-de --seed 0 18 | ``` 19 | 20 | ### Figure 1: Top-20 accuracy 21 | #### Generating the translations 22 | - Performing this analysis is computationally heavy because we run it for many different values of _s_ (x-axis of Figure 1). 23 | - We run N-by-N MBR, N-by-S MBR and Reference Aggregation in a single script, and all values of _s_, so that the embedding part of COMET only needs to run once. 24 | - The results are stored in a JSON lines file in the directory `validation_output/`. Each line describes the output for one method and one value of _s_. 25 | - In addition, the top translations will be stored in text files (one translation per line) in the `translations/` directory, to allow for easy evaluation. 26 | - The utility metric is either `"chrf"`, `"cometinho"` or `"comet22"`. 
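- For orientation, the sketch below shows roughly how the JSON lines output in `validation_output/` can be inspected once the `validation.py` command below has been run (field names follow how `plot_accuracy.py` reads the file; the paths are illustrative for wmt21 en–de with COMET-22):

```python
import jsonlines

# Samples created by generate_samples.py: one list of 1024 candidates per source segment
with jsonlines.open("samples/samples.wmt21.en-de.n1024.epsilon0.02.seed0.jsonl") as f:
    samples = [line["samples"] for line in f]

# Validation output: one line per method ("n_by_s" or "aggregate") and per value of s
with jsonlines.open("validation_output/validation.wmt21.en-de.n1024.epsilon0.02.seed0.comet22.top20.jsonl") as f:
    for line in f:
        method, s, rankings = line["method"], line["s"], line["rankings"]
        # rankings[i] lists the indices of the top-ranked samples for segment i
        best_for_first_segment = samples[0][rankings[0][0]]
        print(method, s, best_for_first_segment)
```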
27 | ```bash 28 | python validation.py --testset wmt21 --language-pair en-de --seed 0 --utility comet22 --topk 20 29 | ``` 30 | #### Calculating accuracy 31 | - After the script has run, the series for Figure 1 (top-20 accuracy) can be printed as follows. 32 | - The method can be either `"n_by_s"` or `"aggregate"`. 33 | ```bash 34 | python plot_accuracy.py --testset wmt21 --language-pair en-de --seed 0 --utility comet22 --topk 20 --method aggregate 35 | ``` 36 | - To calculate top-1 accuracy instead: 37 | ```bash 38 | python plot_accuracy.py --testset wmt21 --language-pair en-de --seed 0 --utility comet22 --topk 20 --method aggregate --accuracy-topk 1 39 | ``` 40 | 41 | ### Table 1: Test results 42 | 43 | #### Generating the translations 44 | - In the test results table, we compare the translation quality of beam search, epsilon sampling, standard (pairwise) MBR, and reference aggregation. We also experiment with aggregate-to-fine MBR. 45 | - The following scripts create the translations and store them in the `translations/` directory. 46 | ```bash 47 | # Beam search 48 | python baseline_beam_search.py --language-pair en-de --testset wmt22 49 | 50 | # MBR with ChrF metric – standard MBR 51 | python run_mbr.py --method pairwise --testset wmt22 --language-pair en-de --seed 0 --utility chrf 52 | # MBR with ChrF metric – reference aggregation 53 | python run_mbr.py --method aggregate --testset wmt22 --language-pair en-de --seed 0 --utility chrf 54 | # MBR with ChrF metric – aggregate-to-fine MBR 55 | python run_mbr.py --method aggregate_to_fine --topk 20 --testset wmt22 --language-pair en-de --seed 0 --utility chrf 56 | 57 | # MBR with Cometinho metric – standard MBR 58 | python run_mbr.py --method pairwise --testset wmt22 --language-pair en-de --seed 0 --utility cometinho 59 | # MBR with Cometinho metric – reference aggregation 60 | python run_mbr.py --method aggregate --testset wmt22 --language-pair en-de --seed 0 --utility cometinho 61 | # MBR with Cometinho metric – aggregate-to-fine MBR 62 | python run_mbr.py --method aggregate_to_fine --topk 20 --testset wmt22 --language-pair en-de --seed 0 --utility cometinho 63 | 64 | # MBR with COMET-22 metric – standard MBR 65 | python run_mbr.py --method pairwise --testset wmt22 --language-pair en-de --seed 0 --utility comet22 66 | # MBR with COMET-22 metric – reference aggregation 67 | python run_mbr.py --method aggregate --testset wmt22 --language-pair en-de --seed 0 --utility comet22 68 | # MBR with COMET-22 metric – aggregate-to-fine MBR 69 | python run_mbr.py --method aggregate_to_fine --topk 20 --testset wmt22 --language-pair en-de --seed 0 --utility comet22 70 | 71 | # Coarse-to-fine MBR: ChrF to COMET-22 72 | python run_mbr.py --method coarse_to_fine --topk 20 --testset wmt22 --language-pair en-de --seed 0 --coarse-utility chrf --utility comet22 73 | # Aggregate-to-fine MBR: Aggregate ChrF to COMET-22 74 | python run_mbr.py --method aggregate_to_fine --topk 20 --testset wmt22 --language-pair en-de --seed 0 --coarse-utility chrf --utility comet22 75 | ``` 76 | - For epsilon sampling, we simply read the JSON lines file created by `generate_samples.py` and extract the first sample for each segment, as sketched below.
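- A minimal sketch of this extraction step (it mirrors what `baseline_epsilon_sampling.py` does; the command below runs the provided script, and the file names are illustrative for wmt22 en–de):

```python
import jsonlines

# Take the first stored epsilon sample per segment as the baseline translation
with jsonlines.open("samples/samples.wmt22.en-de.n1024.epsilon0.02.seed0.jsonl") as f_in, \
        open("translations/wmt22.en-de.epsilon0.02.seed0.de", "w") as f_out:
    for line in f_in:
        f_out.write(line["samples"][0] + "\n")
```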
77 | ```bash 78 | python baseline_epsilon_sampling.py --testset wmt22 --language-pair en-de --seed 0 79 | ``` 80 | 81 | #### Saving the source sequences and references in a text file 82 | - The sequences will be stored in text files in the `translations/` directory 83 | ```bash 84 | python scripts/save_src_and_ref.py --testset wmt22 --language-pair en-de 85 | ``` 86 | 87 | #### Evaluating the translations 88 | - Use a tool of your choice (e.g., https://github.com/mjpost/sacrebleu) to perform the evaluation. 89 | 90 | 91 | ## Citation 92 | ```bibtex 93 | @misc{vamvas-sennrich-2024-linear, 94 | title={Linear-time Minimum Bayes Risk Decoding with Reference Aggregation}, 95 | author={Jannis Vamvas and Rico Sennrich}, 96 | year={2024}, 97 | eprint={2402.04251}, 98 | archivePrefix={arXiv}, 99 | primaryClass={cs.CL} 100 | } 101 | ``` 102 | -------------------------------------------------------------------------------- /experiments/reference_aggregation/baseline_beam_search.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | from experiments.reference_aggregation.experiment_utils import Testset 5 | from experiments.reference_aggregation.fairseq_utils import load_model 6 | 7 | 8 | def main(testset: str, language_pair: str, beam_size: int = 4, limit_segments: int = None, 9 | out_dir: Path = None) -> Path: 10 | if out_dir is None: 11 | out_dir = Path(__file__).parent 12 | 13 | dataset = Testset.from_wmt(testset, language_pair, limit_segments=limit_segments) 14 | 15 | model = load_model(language_pair) 16 | 17 | translations_dir = out_dir / "translations" 18 | translations_dir.mkdir(exist_ok=True) 19 | out_path = translations_dir / f"{dataset}.beam{beam_size}.{dataset.tgt_lang}" 20 | 21 | translations = model.translate(dataset.source_sentences, beam=beam_size) 22 | assert len(translations) == len(dataset.source_sentences) 23 | 24 | with open(out_path, "w") as f: 25 | for translation in translations: 26 | f.write(translation + "\n") 27 | 28 | return out_path 29 | 30 | 31 | if __name__ == "__main__": 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--testset', choices=['wmt21', 'wmt22'], required=True) 34 | parser.add_argument('--language-pair', choices=['de-en', 'en-de', 'en-ru', 'ru-en'], required=True) 35 | parser.add_argument("--beam-size", type=int, default=4) 36 | parser.add_argument('--limit-segments', type=int, default=None, 37 | help='Limit number of segments that are processed (used for testing)') 38 | args = parser.parse_args() 39 | 40 | out_path = main(testset=args.testset, language_pair=args.language_pair, beam_size=args.beam_size, 41 | limit_segments=args.limit_segments, ) 42 | assert out_path.exists() 43 | print(f"Saved translations to {out_path}") 44 | -------------------------------------------------------------------------------- /experiments/reference_aggregation/baseline_epsilon_sampling.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import jsonlines 5 | from tqdm import tqdm 6 | 7 | 8 | def main(testset: str, language_pair: str, num_samples: int, epsilon_cutoff: float, seed_no: int, 9 | out_dir: Path = None) -> Path: 10 | if out_dir is None: 11 | out_dir = Path(__file__).parent 12 | 13 | samples_dir = out_dir / "samples" 14 | assert samples_dir.exists() 15 | samples_path = samples_dir / f"samples.{testset}.{language_pair}.n{num_samples}.epsilon{epsilon_cutoff}.seed{seed_no}.jsonl" 16 
| assert samples_path.exists() 17 | 18 | translations_dir = out_dir / "translations" 19 | translations_dir.mkdir(exist_ok=True) 20 | out_path = translations_dir / f"{testset}.{language_pair}.epsilon{epsilon_cutoff}.seed{seed_no}.{language_pair.split('-')[1]}" 21 | 22 | with jsonlines.open(samples_path) as f_in, open(out_path, "w") as f_out: 23 | for line in tqdm(f_in): 24 | samples = line["samples"] 25 | assert len(samples) == num_samples 26 | f_out.write(samples[0] + "\n") 27 | 28 | return out_path 29 | 30 | 31 | if __name__ == '__main__': 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--testset', choices=['wmt21', 'wmt22'], required=True) 34 | parser.add_argument('--language-pair', choices=['de-en', 'en-de', 'en-ru', 'ru-en'], required=True) 35 | parser.add_argument('--seed', type=int, choices=range(10), required=True, 36 | help='Index of the random seed in the list of random seeds') 37 | parser.add_argument('--num-samples', type=int, default=1024) 38 | parser.add_argument('--epsilon-cutoff', type=float, default=0.02) 39 | args = parser.parse_args() 40 | 41 | out_path = main(testset=args.testset, language_pair=args.language_pair, num_samples=args.num_samples, 42 | epsilon_cutoff=args.epsilon_cutoff, seed_no=args.seed, ) 43 | assert out_path.exists() 44 | print(f"Saved translations to {out_path}") 45 | -------------------------------------------------------------------------------- /experiments/reference_aggregation/experiment_utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | from typing import List 4 | 5 | import sacrebleu 6 | 7 | 8 | SEEDS = [ 9 | 553589, 10 | 456178, 11 | 817304, 12 | 6277, 13 | 792418, 14 | 707983, 15 | 249859, 16 | 272618, 17 | 760402, 18 | 472974, 19 | ] 20 | 21 | 22 | @dataclass 23 | class Testset: 24 | testset: str 25 | language_pair: str 26 | source_sentences: List[str] 27 | references: List[str] 28 | 29 | @property 30 | def src_lang(self): 31 | return self.language_pair.split("-")[0] 32 | 33 | @property 34 | def tgt_lang(self): 35 | return self.language_pair.split("-")[1] 36 | 37 | @classmethod 38 | def from_wmt(cls, wmt: str, language_pair: str, limit_segments: int = None): 39 | assert wmt in {"wmt21", "wmt22"} 40 | src_path = sacrebleu.get_source_file(wmt, language_pair) 41 | ref_path = sacrebleu.get_reference_files(wmt, language_pair)[0] 42 | source_sequences = Path(src_path).read_text().splitlines() 43 | references = Path(ref_path).read_text().splitlines() 44 | assert len(source_sequences) == len(references) 45 | if limit_segments is not None: 46 | source_sequences = source_sequences[:limit_segments] 47 | references = references[:limit_segments] 48 | return cls(testset=wmt, language_pair=language_pair, source_sentences=source_sequences, references=references, ) 49 | 50 | def __str__(self): 51 | return f"{self.testset}.{self.language_pair}" 52 | -------------------------------------------------------------------------------- /experiments/reference_aggregation/fairseq_utils.py: -------------------------------------------------------------------------------- 1 | import urllib 2 | from collections import namedtuple 3 | from pathlib import Path 4 | from typing import Union, List 5 | 6 | import torch 7 | from fairseq import hub_utils 8 | from fairseq.data.encoders.fastbpe import fastBPE 9 | from fairseq.hub_utils import GeneratorHubInterface 10 | 11 | 12 | class FairseqTranslationModel: 13 | """ 14 | Adapted from 
https://github.com/ZurichNLP/contrastive-conditioning/blob/master/translation_models/fairseq_models.py 15 | """ 16 | 17 | def __init__(self, 18 | name: str, 19 | model: GeneratorHubInterface = None, 20 | model_name_or_path: Union[Path, str] = None, 21 | checkpoint_file: str = "checkpoint_best.pt", 22 | src_bpe_codes: Union[Path, str] = None, 23 | tgt_bpe_codes: Union[Path, str] = None, 24 | **kwargs, 25 | ): 26 | self.name = name 27 | self.model = model or hub_utils.GeneratorHubInterface(**hub_utils.from_pretrained( 28 | model_name_or_path=str(model_name_or_path), 29 | checkpoint_file=checkpoint_file, 30 | **kwargs, 31 | )) 32 | # self.model.args.max_tokens = max_tokens 33 | self.model.eval() 34 | if torch.cuda.is_available(): 35 | self.model.cuda() 36 | 37 | # EN-RU systems use separate vocabularies, which is not yet supported by torch hub 38 | bpe_args = namedtuple("bpe_args", ["bpe_codes"]) 39 | if src_bpe_codes is not None: 40 | bpe_args_src = bpe_args(bpe_codes=str(src_bpe_codes)) 41 | self.src_bpe = fastBPE(bpe_args_src) 42 | else: 43 | self.src_bpe = None 44 | if tgt_bpe_codes is not None: 45 | bpe_args_tgt = bpe_args(bpe_codes=str(tgt_bpe_codes)) 46 | self.tgt_bpe = fastBPE(bpe_args_tgt) 47 | else: 48 | self.tgt_bpe = None 49 | 50 | def translate(self, sentences: List[str], beam: int = 5, **kwargs) -> List[str]: 51 | return self.model.translate(sentences, beam, **kwargs) 52 | 53 | def sample(self, sentences: List[str], seed=None, **kwargs) -> List[str]: 54 | return self.model.sample(sentences, sampling=True, seed=seed, **kwargs) 55 | 56 | def __str__(self): 57 | return self.name 58 | 59 | 60 | def load_model(language_pair: str) -> FairseqTranslationModel: 61 | if language_pair in ["de-en", "en-de"]: 62 | hub_interface = torch.hub.load( 63 | repo_or_dir="jvamvas/fairseq:epsilon", 64 | model=f'transformer.wmt19.{language_pair}.single_model', 65 | tokenizer='moses', 66 | bpe='fastbpe', 67 | ) 68 | model_name = f"transformer.wmt19.{language_pair}.single_model" 69 | model = FairseqTranslationModel( 70 | name=model_name, 71 | model=hub_interface, 72 | ) 73 | elif language_pair in ["en-ru", "ru-en"]: 74 | hub_interface = torch.hub.load( 75 | repo_or_dir="jvamvas/fairseq:epsilon", 76 | model=f'transformer.wmt19.{language_pair}.single_model', 77 | tokenizer='moses', 78 | bpe='fastbpe', 79 | ) 80 | 81 | # Need to download correct vocab separately (https://github.com/pytorch/fairseq/issues/2928) 82 | hub_base_dir = Path(torch.hub.get_dir()) 83 | correct_en_vocab_path = hub_base_dir / "en24k.fastbpe.code" 84 | correct_ru_vocab_path = hub_base_dir / "ru24k.fastbpe.code" 85 | if not correct_en_vocab_path.exists(): 86 | with urllib.request.urlopen("https://dl.fbaipublicfiles.com/fairseq/en24k.fastbpe.code") as response, \ 87 | open(correct_en_vocab_path, 'wb') as out_file: 88 | data = response.read() 89 | out_file.write(data) 90 | if not correct_ru_vocab_path.exists(): 91 | with urllib.request.urlopen("https://dl.fbaipublicfiles.com/fairseq/ru24k.fastbpe.code") as response, \ 92 | open(correct_ru_vocab_path, 'wb') as out_file: 93 | data = response.read() 94 | out_file.write(data) 95 | 96 | evaluator_name = f"transformer.wmt19.{language_pair}.single_model" 97 | model = FairseqTranslationModel( 98 | name=evaluator_name, 99 | model=hub_interface, 100 | src_bpe_codes=correct_en_vocab_path if language_pair == "en-ru" else correct_ru_vocab_path, 101 | tgt_bpe_codes=correct_ru_vocab_path if language_pair == "en-ru" else correct_en_vocab_path, 102 | ) 103 | else: 104 | raise NotImplementedError 105 | 
return model 106 | -------------------------------------------------------------------------------- /experiments/reference_aggregation/generate_samples.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import jsonlines 5 | from tqdm import tqdm 6 | 7 | from experiments.reference_aggregation.experiment_utils import SEEDS, Testset 8 | from experiments.reference_aggregation.fairseq_utils import load_model 9 | 10 | 11 | def main(testset: str, language_pair: str, seed_no: int, num_samples: int, epsilon_cutoff: float, 12 | limit_segments: int = None, out_dir: Path = None) -> Path: 13 | if out_dir is None: 14 | out_dir = Path(__file__).parent 15 | 16 | seed = SEEDS[seed_no] 17 | dataset = Testset.from_wmt(testset, language_pair, limit_segments=limit_segments) 18 | 19 | model = load_model(language_pair) 20 | 21 | samples_dir = out_dir / "samples" 22 | samples_dir.mkdir(exist_ok=True) 23 | out_path = samples_dir / f"samples.{dataset}.n{num_samples}.epsilon{epsilon_cutoff}.seed{seed_no}.jsonl" 24 | 25 | with jsonlines.open(out_path, "w") as f: 26 | for source_sentence in tqdm(dataset.source_sentences): 27 | f.write({"samples": model.sample(num_samples * [source_sentence], seed=seed, 28 | sampling_epsilon_cutoff=epsilon_cutoff), }) 29 | 30 | return out_path 31 | 32 | 33 | if __name__ == '__main__': 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument('--testset', choices=['wmt21', 'wmt22'], required=True) 36 | parser.add_argument('--language-pair', choices=['de-en', 'en-de', 'en-ru', 'ru-en'], required=True) 37 | parser.add_argument('--seed', type=int, choices=range(10), required=True, 38 | help='Index of the random seed in the list of random seeds') 39 | parser.add_argument('--num-samples', type=int, default=1024) 40 | parser.add_argument('--epsilon-cutoff', type=float, default=0.02) 41 | parser.add_argument('--limit-segments', type=int, default=None, 42 | help='Limit number of segments that are processed (used for testing)') 43 | args = parser.parse_args() 44 | 45 | out_path = main(testset=args.testset, language_pair=args.language_pair, seed_no=args.seed, 46 | num_samples=args.num_samples, epsilon_cutoff=args.epsilon_cutoff, limit_segments=args.limit_segments, ) 47 | assert out_path.exists() 48 | print(f"Saved samples to {out_path}") 49 | -------------------------------------------------------------------------------- /experiments/reference_aggregation/mbr_utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import logging 3 | from collections import namedtuple 4 | from typing import Set, Dict, List, Tuple 5 | 6 | import evaluate 7 | import numpy as np 8 | import torch 9 | from fastchrf import pairwise_chrf, aggregate_chrf 10 | 11 | 12 | class ChrfUtility: 13 | 14 | def __init__(self, char_order: int = 6, beta: float = 2.0, remove_whitespace: bool = True, eps_smoothing: bool = False): 15 | self.char_order = char_order 16 | self.beta = beta 17 | self.remove_whitespace = remove_whitespace 18 | self.eps_smoothing = eps_smoothing 19 | 20 | def rank_samples_n_by_s(self, source_sequence: str, samples: List[str], references: List[str], s: int = None) -> np.ndarray: 21 | """ 22 | Returns the indices of the samples sorted by their utility score, in descending order. 
23 | :param s: The number of references to subsample from the list of references (default: all references) 24 | """ 25 | if s is None: 26 | s = len(references) 27 | assert s <= len(references) 28 | references = references[:s] 29 | 30 | metric_scores = pairwise_chrf( 31 | [samples], [references], 32 | char_order=self.char_order, 33 | beta=self.beta, 34 | remove_whitespace=self.remove_whitespace, 35 | eps_smoothing=self.eps_smoothing, 36 | )[0] 37 | metric_scores = np.array(metric_scores) # num_samples x s 38 | 39 | # Sort the samples by their average score 40 | sample_scores = metric_scores.mean(axis=1) 41 | sample_indices = sample_scores.argsort()[::-1] 42 | return sample_indices 43 | 44 | def rank_samples_aggregate(self, source_sequence: str, samples: List[str], references: List[str], s: int) -> np.ndarray: 45 | """ 46 | Returns the indices of the samples sorted by their utility score, in descending order. 47 | :param s: The number of aggregate references 48 | """ 49 | assert s <= len(references) 50 | 51 | num_partitions = s 52 | partition_size = len(references) // num_partitions 53 | reference_partitions = [references[i * partition_size:(i + 1) * partition_size] for i in range(num_partitions)] 54 | 55 | metric_scores = aggregate_chrf( 56 | num_partitions * [samples], reference_partitions, 57 | char_order=self.char_order, 58 | beta=self.beta, 59 | remove_whitespace=self.remove_whitespace, 60 | eps_smoothing=self.eps_smoothing, 61 | ) 62 | metric_scores = np.array(metric_scores).transpose() # num_samples x s 63 | 64 | # Sort the samples by their average score 65 | sample_scores = metric_scores.mean(axis=1) 66 | sample_indices = sample_scores.argsort()[::-1] 67 | return sample_indices 68 | 69 | 70 | CometInputTriple = namedtuple("CometInputTriple", ["src", "hyp", "ref"]) 71 | 72 | 73 | class CometUtility: 74 | 75 | def __init__(self, 76 | model_name: str, 77 | batch_size_embed: int = 1, 78 | batch_size_estimate: int = 1, 79 | ): 80 | self.model_name = model_name 81 | self.batch_size_embed = batch_size_embed 82 | self.batch_size_estimate = batch_size_estimate 83 | self.scorer = evaluate.load("comet", model_name).scorer 84 | if torch.cuda.is_available(): 85 | self.scorer = self.scorer.to("cuda:0") 86 | else: 87 | logging.warning("CUDA not available, using CPU") 88 | self.scorer.eval() 89 | self.device = self.scorer.device 90 | self.embeddings: Dict[str, torch.FloatTensor] = {} 91 | 92 | @torch.no_grad() 93 | def compute_features(self, input_sequences: Set[str]): 94 | assert not self.scorer.training 95 | input_sequences = list(input_sequences) 96 | encodings = self.scorer.encoder.prepare_sample(input_sequences).to(self.device) 97 | batches = itertools.zip_longest(range(0, len(input_sequences), self.batch_size_embed), 98 | range(self.batch_size_embed, len(input_sequences), self.batch_size_embed)) 99 | for start_idx, end_idx in batches: 100 | embeddings = self.scorer.get_sentence_embedding( 101 | input_ids=encodings["input_ids"][start_idx:end_idx], 102 | attention_mask=encodings["attention_mask"][start_idx:end_idx], 103 | ) 104 | for j in range(start_idx, end_idx if end_idx is not None else len(input_sequences)): 105 | embedding = embeddings[j - start_idx] 106 | self.embeddings[input_sequences[j]] = embedding 107 | 108 | def clear_features(self): 109 | self.embeddings = {} 110 | 111 | @torch.no_grad() 112 | def rank_samples_n_by_s(self, source_sequence: str, samples: List[str], references: List[str], s: int = None) -> np.ndarray: 113 | """ 114 | Returns the indices of the samples sorted by their 
utility score, in descending order. 115 | :param s: The number of references to subsample from the list of references (default: all references) 116 | """ 117 | if s is None: 118 | s = len(references) 119 | assert s <= len(references) 120 | references = references[:s] 121 | assert not self.scorer.training 122 | 123 | # Collect all unique input triples 124 | input_triples: Set[Tuple[str, str, str]] = set() 125 | for sample in samples: 126 | for reference in references: 127 | input_triples.add(CometInputTriple(src=source_sequence, hyp=sample, ref=reference)) 128 | input_triples: List = list(input_triples) 129 | 130 | # Compute scores for all input triples 131 | triple_scores: Dict[CometInputTriple, torch.tensor] = {} 132 | batches = itertools.zip_longest(range(0, len(input_triples), self.batch_size_estimate), 133 | range(self.batch_size_estimate, len(input_triples), self.batch_size_estimate)) 134 | for start_idx, end_idx in batches: 135 | batch = input_triples[start_idx:end_idx] 136 | batch_scores = self.scorer.estimate( 137 | src_sentemb=torch.stack([self.embeddings[input.src] for input in batch]), 138 | mt_sentemb=torch.stack([self.embeddings[input.hyp] for input in batch]), 139 | ref_sentemb=torch.stack([self.embeddings[input.ref] for input in batch]), 140 | ) 141 | for i in range(start_idx, end_idx if end_idx is not None else len(input_triples)): 142 | triple = batch[i - start_idx] 143 | score = batch_scores.score[i - start_idx] 144 | triple_scores[triple] = score 145 | 146 | # Fill in the metric scores matrix 147 | metric_scores = torch.zeros((len(samples), len(references))) 148 | for i, sample in enumerate(samples): 149 | for j, reference in enumerate(references): 150 | metric_scores[i, j] = triple_scores[CometInputTriple(src=source_sequence, hyp=sample, ref=reference)] 151 | 152 | # Sort the samples by their average score 153 | sample_scores = metric_scores.mean(dim=1) 154 | sample_indices = sample_scores.argsort(descending=True) 155 | return sample_indices.cpu().numpy() 156 | 157 | @torch.no_grad() 158 | def rank_samples_aggregate(self, source_sequence: str, samples: List[str], references: List[str], s: int) -> np.ndarray: 159 | """ 160 | Returns the indices of the samples sorted by their utility score, in descending order. 
161 | :param s: The number of aggregate referencesq 162 | """ 163 | assert s <= len(references) 164 | assert not self.scorer.training 165 | 166 | num_partitions = s 167 | partition_size = len(references) // num_partitions 168 | 169 | # Add aggregate reference embeddings to the embeddings cache 170 | reference_embeddings = torch.stack([self.embeddings[reference] for reference in references]) 171 | avg_reference_embeddings = reference_embeddings.view(num_partitions, partition_size, -1).mean(dim=1) 172 | for partition_id in range(num_partitions): 173 | self.embeddings[f"aggregate_{partition_id}"] = avg_reference_embeddings[partition_id] 174 | 175 | # Collect all unique input triples 176 | input_triples: Set[Tuple[str, str, str]] = set() 177 | for sample in samples: 178 | for partition_id in range(s): 179 | input_triples.add(CometInputTriple(src=source_sequence, hyp=sample, ref=f"aggregate_{partition_id}")) 180 | input_triples: List = list(input_triples) 181 | 182 | # Compute scores for all input triples 183 | triple_scores: Dict[CometInputTriple, torch.tensor] = {} 184 | batches = itertools.zip_longest(range(0, len(input_triples), self.batch_size_estimate), 185 | range(self.batch_size_estimate, len(input_triples), self.batch_size_estimate)) 186 | for start_idx, end_idx in batches: 187 | batch = input_triples[start_idx:end_idx] 188 | batch_scores = self.scorer.estimate( 189 | src_sentemb=torch.stack([self.embeddings[input.src] for input in batch]), 190 | mt_sentemb=torch.stack([self.embeddings[input.hyp] for input in batch]), 191 | ref_sentemb=torch.stack([self.embeddings[input.ref] for input in batch]), 192 | ) 193 | for i in range(start_idx, end_idx if end_idx is not None else len(input_triples)): 194 | triple = batch[i - start_idx] 195 | score = batch_scores.score[i - start_idx] 196 | triple_scores[triple] = score 197 | 198 | # Fill in the metric scores matrix 199 | metric_scores = torch.zeros((len(samples), num_partitions)) 200 | for i, sample in enumerate(samples): 201 | for partition_id in range(s): 202 | metric_scores[i, partition_id] = triple_scores[CometInputTriple(src=source_sequence, hyp=sample, ref=f"aggregate_{partition_id}")] 203 | 204 | # Sort the samples by their average score 205 | sample_scores = metric_scores.mean(dim=1) 206 | sample_indices = sample_scores.argsort(descending=True) 207 | return sample_indices.cpu().numpy() 208 | 209 | 210 | def load_utility(utility_name: str): 211 | if utility_name == "chrf": 212 | return ChrfUtility() 213 | elif utility_name.startswith("comet22"): 214 | return CometUtility("Unbabel/wmt22-comet-da", batch_size_embed=128, batch_size_estimate=128) 215 | elif utility_name.startswith("cometinho"): 216 | return CometUtility("eamt22-cometinho-da", batch_size_embed=512, batch_size_estimate=512) 217 | else: 218 | raise ValueError(f"Unknown utility {utility_name}") 219 | -------------------------------------------------------------------------------- /experiments/reference_aggregation/plot_accuracy.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | from typing import List, Tuple 4 | 5 | import jsonlines 6 | 7 | from experiments.reference_aggregation.experiment_utils import Testset 8 | 9 | 10 | def main(testset: str, language_pair: str, seed_no: int, fine_utility_name: str, topk: int, accuracy_topk: int, 11 | method: str, num_samples: int = 1024, epsilon_cutoff: float = 0.02, coarse_utility_name: str = None, 12 | limit_segments: int = None, out_dir: Path = None) -> 
List[Tuple[int, float]]: 13 | """ 14 | Returns a series of (s, accuracy) tuples, starting with the highest s 15 | """ 16 | if out_dir is None: 17 | out_dir = Path(__file__).parent 18 | 19 | if coarse_utility_name is None: 20 | coarse_utility_name = fine_utility_name 21 | 22 | assert topk <= num_samples 23 | assert accuracy_topk <= topk 24 | 25 | dataset = Testset.from_wmt(testset, language_pair, limit_segments=limit_segments) 26 | 27 | samples_dir = out_dir / "samples" 28 | assert samples_dir.exists() 29 | samples_path = samples_dir / f"samples.{dataset}.n{num_samples}.epsilon{epsilon_cutoff}.seed{seed_no}.jsonl" 30 | assert samples_path.exists() 31 | with jsonlines.open(samples_path) as f: 32 | samples = [line["samples"] for line in f] 33 | samples = [sample[:num_samples] for sample in samples] 34 | if limit_segments is not None: 35 | samples = samples[:limit_segments] 36 | 37 | output_dir = out_dir / "validation_output" 38 | assert output_dir.exists() 39 | fine_output_path = output_dir / f"validation.{dataset}.n{num_samples}.epsilon{epsilon_cutoff}.seed{seed_no}.{fine_utility_name}.top{topk}.jsonl" 40 | with jsonlines.open(fine_output_path) as f: 41 | fine_data = list(f) 42 | coarse_output_path = output_dir / f"validation.{dataset}.n{num_samples}.epsilon{epsilon_cutoff}.seed{seed_no}.{coarse_utility_name}.top{topk}.jsonl" 43 | with jsonlines.open(coarse_output_path) as f: 44 | coarse_data = list(f) 45 | 46 | # Get n-by-n top-1 samples – should not matter which method 47 | n_by_n_lines = [line for line in fine_data if line["s"] == num_samples] 48 | assert len(n_by_n_lines) == 2 49 | for ranking in zip(n_by_n_lines[0]["rankings"], n_by_n_lines[1]["rankings"]): 50 | assert ranking[0] == ranking[1] 51 | n_by_n_rankings = n_by_n_lines[0]["rankings"] 52 | n_by_n_top1_samples = [samples[i][n_by_n_rankings[i][0]].strip() for i in range(len(samples))] 53 | 54 | # Get top-k accuracies for efficiency method 55 | method_lines = [line for line in coarse_data if line["method"] == method] 56 | assert len(method_lines) == len(coarse_data) / 2 57 | s_values = list(sorted([line["s"] for line in method_lines], reverse=True)) 58 | accuracies = [] # for each s 59 | for s in s_values: 60 | s_lines = [line for line in method_lines if line["s"] == s] 61 | assert len(s_lines) == 1 62 | s_rankings = s_lines[0]["rankings"] 63 | s_topk_samples = [{samples[i][ranking].strip() for ranking in s_rankings[i][:accuracy_topk]} for i in 64 | range(len(samples))] 65 | s_num_correct = sum([1 if n_by_n_top1_samples[i] in s_topk_samples[i] else 0 for i in range(len(samples))]) 66 | s_accuracy = s_num_correct / len(samples) 67 | accuracies.append(s_accuracy) 68 | 69 | # Format: (1,-0.4)(2,-0.6)(4,-0.5)(8,0.1)(16,0.1)(32,0.2)(64,0.1)(128,-0.0)(256,-0.0) 70 | series = [(s, accuracy) for s, accuracy in zip(s_values, accuracies)] 71 | return series 72 | 73 | 74 | if __name__ == '__main__': 75 | parser = argparse.ArgumentParser() 76 | parser.add_argument('--testset', choices=['wmt21', 'wmt22'], required=True) 77 | parser.add_argument('--language-pair', choices=['de-en', 'en-de', 'en-ru', 'ru-en'], required=True) 78 | parser.add_argument('--seed', type=int, choices=range(10), required=True, 79 | help='Index of the random seed in the list of random seeds') 80 | parser.add_argument('--utility', choices=['chrf', 'cometinho', 'comet22'], required=True) 81 | parser.add_argument('--coarse-utility', choices=['chrf', 'cometinho', 'comet22'], default=None, 82 | help='Utility used for coarse-grained method (default: same as fine-grained)') 83 
| parser.add_argument('--topk', type=int, default=20, 84 | help='Number of top translations that have been saved in the jsonl file') 85 | parser.add_argument('--method', choices=['n_by_s', 'aggregate'], required=True) 86 | parser.add_argument('--num-samples', type=int, default=1024) 87 | parser.add_argument('--epsilon-cutoff', type=float, default=0.02) 88 | parser.add_argument('--accuracy-topk', type=int, default=None, 89 | help='Number of top translations that are used to compute the accuracy (default: same as data-topk)') 90 | parser.add_argument('--limit-segments', type=int, default=None, 91 | help='Limit number of segments that are processed (used for testing)') 92 | args = parser.parse_args() 93 | 94 | if args.coarse_utility is None: 95 | args.coarse_utility = args.utility 96 | if args.accuracy_topk is None: 97 | args.accuracy_topk = args.topk 98 | 99 | series = main(testset=args.testset, language_pair=args.language_pair, seed_no=args.seed, 100 | fine_utility_name=args.utility, coarse_utility_name=args.coarse_utility, topk=args.topk, method=args.method, 101 | num_samples=args.num_samples, epsilon_cutoff=args.epsilon_cutoff, accuracy_topk=args.accuracy_topk, 102 | limit_segments=args.limit_segments, ) 103 | 104 | # Format: (1,-0.4)(2,-0.6)(4,-0.5)(8,0.1)(16,0.1)(32,0.2)(64,0.1)(128,-0.0)(256,-0.0) 105 | series_str = "".join([f"({s},{accuracy:.5f})" for s, accuracy in series]) 106 | print( 107 | f"Testset: {args.testset}, language pair: {args.language_pair}, seed: {args.seed}, fine utility: {args.utility}, coarse utility: {args.coarse_utility}, topk: {args.topk}, method: {args.method}") 108 | print(f"Top-{args.accuracy_topk} accuracy:") 109 | print(series_str) 110 | print() 111 | -------------------------------------------------------------------------------- /experiments/reference_aggregation/requirements.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/jvamvas/fairseq.git@epsilon # Adds epsilon sampling 2 | scikit_learn==1.3.2 3 | sacremoses==0.1.1 4 | fastBPE==0.1.0 5 | requests 6 | jsonlines 7 | unbabel-comet==2.1.1 8 | evaluate==0.4.1 9 | -------------------------------------------------------------------------------- /experiments/reference_aggregation/results/chrf.log: -------------------------------------------------------------------------------- 1 | en-de 2 | baselines 3 | 61.4 4 | 55.1 5 | chrf 6 | 62.0 7 | 62.0 8 | 62.0 9 | cometinho 10 | 59.8 11 | 60.2 12 | 59.8 13 | comet22 14 | 60.0 15 | 60.3 16 | 60.0 17 | coarse-to-fine 18 | 61.8 19 | 61.8 20 | 21 | de-en 22 | baselines 23 | 56.3 24 | 50.7 25 | chrf 26 | 57.5 27 | 57.6 28 | 57.6 29 | cometinho 30 | 55.5 31 | 55.5 32 | 55.5 33 | comet22 34 | 55.3 35 | 55.6 36 | 55.4 37 | coarse-to-fine 38 | 57.0 39 | 57.1 40 | 41 | en-ru 42 | baselines 43 | 52.3 44 | 46.8 45 | chrf 46 | 54.0 47 | 53.9 48 | 54.0 49 | cometinho 50 | 51.4 51 | 51.9 52 | 51.4 53 | comet22 54 | 51.1 55 | 51.7 56 | 51.3 57 | coarse-to-fine 58 | 53.6 59 | 53.6 60 | 61 | ru-en 62 | baselines 63 | 64.4 64 | 57.6 65 | chrf 66 | 65.5 67 | 65.5 68 | 65.5 69 | cometinho 70 | 63.4 71 | 63.4 72 | 63.4 73 | comet22 74 | 62.7 75 | 63.3 76 | 62.8 77 | coarse-to-fine 78 | 64.9 79 | 64.9 80 | 81 | -------------------------------------------------------------------------------- /experiments/reference_aggregation/results/comet-xl.log: -------------------------------------------------------------------------------- 1 | en-de 2 | baselines 3 | wmt22.en-de.beam4.de score: 0.9599 4 | wmt22.en-de.epsilon0.02.seed0.de 
score: 0.9426 5 | chrf 6 | mbr.wmt22.en-de.pairwise.n1024.epsilon0.02.seed0.chrf.de score: 0.9570 7 | mbr.wmt22.en-de.aggregate.n1024.epsilon0.02.seed0.chrf.de score: 0.9569 8 | mbr.wmt22.en-de.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf.de score: 0.9571 9 | cometinho 10 | mbr.wmt22.en-de.pairwise.n1024.epsilon0.02.seed0.cometinho.de score: 0.9599 11 | mbr.wmt22.en-de.aggregate.n1024.epsilon0.02.seed0.cometinho.de score: 0.9592 12 | mbr.wmt22.en-de.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.cometinho.de score: 0.9599 13 | comet22 14 | mbr.wmt22.en-de.pairwise.n1024.epsilon0.02.seed0.comet22.de score: 0.9656 15 | mbr.wmt22.en-de.aggregate.n1024.epsilon0.02.seed0.comet22.de score: 0.9643 16 | mbr.wmt22.en-de.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.comet22.de score: 0.9656 17 | coarse-to-fine 18 | mbr.wmt22.en-de.coarse_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.de score: 0.9626 19 | mbr.wmt22.en-de.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.de score: 0.9625 20 | 21 | de-en 22 | baselines 23 | wmt22.de-en.beam4.en score: 0.9411 24 | wmt22.de-en.epsilon0.02.seed0.en score: 0.9163 25 | chrf 26 | mbr.wmt22.de-en.pairwise.n1024.epsilon0.02.seed0.chrf.en score: 0.9397 27 | mbr.wmt22.de-en.aggregate.n1024.epsilon0.02.seed0.chrf.en score: 0.9397 28 | mbr.wmt22.de-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf.en score: 0.9397 29 | cometinho 30 | mbr.wmt22.de-en.pairwise.n1024.epsilon0.02.seed0.cometinho.en score: 0.9434 31 | mbr.wmt22.de-en.aggregate.n1024.epsilon0.02.seed0.cometinho.en score: 0.9420 32 | mbr.wmt22.de-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.cometinho.en score: 0.9432 33 | comet22 34 | mbr.wmt22.de-en.pairwise.n1024.epsilon0.02.seed0.comet22.en score: 0.9510 35 | mbr.wmt22.de-en.aggregate.n1024.epsilon0.02.seed0.comet22.en score: 0.9491 36 | mbr.wmt22.de-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.comet22.en score: 0.9505 37 | coarse-to-fine 38 | mbr.wmt22.de-en.coarse_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.en score: 0.9471 39 | mbr.wmt22.de-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.en score: 0.9473 40 | 41 | en-ru 42 | baselines 43 | wmt22.en-ru.beam4.ru score: 0.8575 44 | wmt22.en-ru.epsilon0.02.seed0.ru score: 0.8157 45 | chrf 46 | mbr.wmt22.en-ru.pairwise.n1024.epsilon0.02.seed0.chrf.ru score: 0.8474 47 | mbr.wmt22.en-ru.aggregate.n1024.epsilon0.02.seed0.chrf.ru score: 0.8469 48 | mbr.wmt22.en-ru.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf.ru score: 0.8476 49 | cometinho 50 | mbr.wmt22.en-ru.pairwise.n1024.epsilon0.02.seed0.cometinho.ru score: 0.8647 51 | mbr.wmt22.en-ru.aggregate.n1024.epsilon0.02.seed0.cometinho.ru score: 0.8622 52 | mbr.wmt22.en-ru.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.cometinho.ru score: 0.8648 53 | comet22 54 | mbr.wmt22.en-ru.pairwise.n1024.epsilon0.02.seed0.comet22.ru score: 0.8914 55 | mbr.wmt22.en-ru.aggregate.n1024.epsilon0.02.seed0.comet22.ru score: 0.8867 56 | mbr.wmt22.en-ru.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.comet22.ru score: 0.8915 57 | coarse-to-fine 58 | mbr.wmt22.en-ru.coarse_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.ru score: 0.8740 59 | mbr.wmt22.en-ru.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.ru score: 0.8744 60 | 61 | ru-en 62 | baselines 63 | wmt22.ru-en.beam4.en score: 0.9290 64 | wmt22.ru-en.epsilon0.02.seed0.en score: 0.8996 65 | chrf 66 | mbr.wmt22.ru-en.pairwise.n1024.epsilon0.02.seed0.chrf.en score: 0.9269 67 | 
mbr.wmt22.ru-en.aggregate.n1024.epsilon0.02.seed0.chrf.en score: 0.9263 68 | mbr.wmt22.ru-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf.en score: 0.9270 69 | cometinho 70 | mbr.wmt22.ru-en.pairwise.n1024.epsilon0.02.seed0.cometinho.en score: 0.9318 71 | mbr.wmt22.ru-en.aggregate.n1024.epsilon0.02.seed0.cometinho.en score: 0.9307 72 | mbr.wmt22.ru-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.cometinho.en score: 0.9318 73 | comet22 74 | mbr.wmt22.ru-en.pairwise.n1024.epsilon0.02.seed0.comet22.en score: 0.9393 75 | mbr.wmt22.ru-en.aggregate.n1024.epsilon0.02.seed0.comet22.en score: 0.9374 76 | mbr.wmt22.ru-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.comet22.en score: 0.9391 77 | coarse-to-fine 78 | mbr.wmt22.ru-en.coarse_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.en score: 0.9351 79 | mbr.wmt22.ru-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.en score: 0.9351 80 | 81 | -------------------------------------------------------------------------------- /experiments/reference_aggregation/results/comet22.log: -------------------------------------------------------------------------------- 1 | en-de 2 | baselines 3 | wmt22.en-de.beam4.de score: 0.8589 4 | wmt22.en-de.epsilon0.02.seed0.de score: 0.8343 5 | chrf 6 | mbr.wmt22.en-de.pairwise.n1024.epsilon0.02.seed0.chrf.de score: 0.8582 7 | mbr.wmt22.en-de.aggregate.n1024.epsilon0.02.seed0.chrf.de score: 0.8586 8 | mbr.wmt22.en-de.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf.de score: 0.8586 9 | cometinho 10 | mbr.wmt22.en-de.pairwise.n1024.epsilon0.02.seed0.cometinho.de score: 0.8636 11 | mbr.wmt22.en-de.aggregate.n1024.epsilon0.02.seed0.cometinho.de score: 0.8632 12 | mbr.wmt22.en-de.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.cometinho.de score: 0.8636 13 | comet22 14 | mbr.wmt22.en-de.pairwise.n1024.epsilon0.02.seed0.comet22.de score: 0.8840 15 | mbr.wmt22.en-de.aggregate.n1024.epsilon0.02.seed0.comet22.de score: 0.8808 16 | mbr.wmt22.en-de.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.comet22.de score: 0.8829 17 | coarse-to-fine 18 | mbr.wmt22.en-de.coarse_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.de score: 0.8704 19 | mbr.wmt22.en-de.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.de score: 0.8709 20 | 21 | de-en 22 | baselines 23 | wmt22.de-en.beam4.en score: 0.8391 24 | wmt22.de-en.epsilon0.02.seed0.en score: 0.8159 25 | chrf 26 | mbr.wmt22.de-en.pairwise.n1024.epsilon0.02.seed0.chrf.en score: 0.8410 27 | mbr.wmt22.de-en.aggregate.n1024.epsilon0.02.seed0.chrf.en score: 0.8413 28 | mbr.wmt22.de-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf.en score: 0.8412 29 | cometinho 30 | mbr.wmt22.de-en.pairwise.n1024.epsilon0.02.seed0.cometinho.en score: 0.8444 31 | mbr.wmt22.de-en.aggregate.n1024.epsilon0.02.seed0.cometinho.en score: 0.8436 32 | mbr.wmt22.de-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.cometinho.en score: 0.8442 33 | comet22 34 | mbr.wmt22.de-en.pairwise.n1024.epsilon0.02.seed0.comet22.en score: 0.8600 35 | mbr.wmt22.de-en.aggregate.n1024.epsilon0.02.seed0.comet22.en score: 0.8572 36 | mbr.wmt22.de-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.comet22.en score: 0.8591 37 | coarse-to-fine 38 | mbr.wmt22.de-en.coarse_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.en score: 0.8497 39 | mbr.wmt22.de-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.en score: 0.8500 40 | 41 | en-ru 42 | baselines 43 | wmt22.en-ru.beam4.ru score: 0.8288 44 | wmt22.en-ru.epsilon0.02.seed0.ru score: 0.8072 45 | chrf 46 | 
mbr.wmt22.en-ru.pairwise.n1024.epsilon0.02.seed0.chrf.ru score: 0.8354 47 | mbr.wmt22.en-ru.aggregate.n1024.epsilon0.02.seed0.chrf.ru score: 0.8345 48 | mbr.wmt22.en-ru.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf.ru score: 0.8352 49 | cometinho 50 | mbr.wmt22.en-ru.pairwise.n1024.epsilon0.02.seed0.cometinho.ru score: 0.8464 51 | mbr.wmt22.en-ru.aggregate.n1024.epsilon0.02.seed0.cometinho.ru score: 0.8440 52 | mbr.wmt22.en-ru.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.cometinho.ru score: 0.8456 53 | comet22 54 | mbr.wmt22.en-ru.pairwise.n1024.epsilon0.02.seed0.comet22.ru score: 0.8782 55 | mbr.wmt22.en-ru.aggregate.n1024.epsilon0.02.seed0.comet22.ru score: 0.8739 56 | mbr.wmt22.en-ru.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.comet22.ru score: 0.8773 57 | coarse-to-fine 58 | mbr.wmt22.en-ru.coarse_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.ru score: 0.8570 59 | mbr.wmt22.en-ru.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.ru score: 0.8573 60 | 61 | ru-en 62 | baselines 63 | wmt22.ru-en.beam4.en score: 0.8434 64 | wmt22.ru-en.epsilon0.02.seed0.en score: 0.8175 65 | chrf 66 | mbr.wmt22.ru-en.pairwise.n1024.epsilon0.02.seed0.chrf.en score: 0.8440 67 | mbr.wmt22.ru-en.aggregate.n1024.epsilon0.02.seed0.chrf.en score: 0.8437 68 | mbr.wmt22.ru-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf.en score: 0.8437 69 | cometinho 70 | mbr.wmt22.ru-en.pairwise.n1024.epsilon0.02.seed0.cometinho.en score: 0.8488 71 | mbr.wmt22.ru-en.aggregate.n1024.epsilon0.02.seed0.cometinho.en score: 0.8482 72 | mbr.wmt22.ru-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.cometinho.en score: 0.8487 73 | comet22 74 | mbr.wmt22.ru-en.pairwise.n1024.epsilon0.02.seed0.comet22.en score: 0.8623 75 | mbr.wmt22.ru-en.aggregate.n1024.epsilon0.02.seed0.comet22.en score: 0.8601 76 | mbr.wmt22.ru-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.comet22.en score: 0.8617 77 | coarse-to-fine 78 | mbr.wmt22.ru-en.coarse_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.en score: 0.8529 79 | mbr.wmt22.ru-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.en score: 0.8526 80 | 81 | -------------------------------------------------------------------------------- /experiments/reference_aggregation/results/cometinho.log: -------------------------------------------------------------------------------- 1 | en-de 2 | baselines 3 | wmt22.en-de.beam4.de score: 0.5536 4 | wmt22.en-de.epsilon0.02.seed0.de score: 0.4571 5 | chrf 6 | mbr.wmt22.en-de.pairwise.n1024.epsilon0.02.seed0.chrf.de score: 0.5560 7 | mbr.wmt22.en-de.aggregate.n1024.epsilon0.02.seed0.chrf.de score: 0.5581 8 | mbr.wmt22.en-de.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf.de score: 0.5584 9 | cometinho 10 | mbr.wmt22.en-de.pairwise.n1024.epsilon0.02.seed0.cometinho.de score: 0.6131 11 | mbr.wmt22.en-de.aggregate.n1024.epsilon0.02.seed0.cometinho.de score: 0.6110 12 | mbr.wmt22.en-de.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.cometinho.de score: 0.6138 13 | comet22 14 | mbr.wmt22.en-de.pairwise.n1024.epsilon0.02.seed0.comet22.de score: 0.5791 15 | mbr.wmt22.en-de.aggregate.n1024.epsilon0.02.seed0.comet22.de score: 0.5783 16 | mbr.wmt22.en-de.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.comet22.de score: 0.5796 17 | coarse-to-fine 18 | mbr.wmt22.en-de.coarse_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.de score: 0.5709 19 | mbr.wmt22.en-de.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.de score: 0.5729 20 | 21 | de-en 22 | baselines 23 | wmt22.de-en.beam4.en 
score: 0.5280 24 | wmt22.de-en.epsilon0.02.seed0.en score: 0.4226 25 | chrf 26 | mbr.wmt22.de-en.pairwise.n1024.epsilon0.02.seed0.chrf.en score: 0.5491 27 | mbr.wmt22.de-en.aggregate.n1024.epsilon0.02.seed0.chrf.en score: 0.5500 28 | mbr.wmt22.de-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf.en score: 0.5491 29 | cometinho 30 | mbr.wmt22.de-en.pairwise.n1024.epsilon0.02.seed0.cometinho.en score: 0.5933 31 | mbr.wmt22.de-en.aggregate.n1024.epsilon0.02.seed0.cometinho.en score: 0.5878 32 | mbr.wmt22.de-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.cometinho.en score: 0.5916 33 | comet22 34 | mbr.wmt22.de-en.pairwise.n1024.epsilon0.02.seed0.comet22.en score: 0.5613 35 | mbr.wmt22.de-en.aggregate.n1024.epsilon0.02.seed0.comet22.en score: 0.5621 36 | mbr.wmt22.de-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.comet22.en score: 0.5624 37 | coarse-to-fine 38 | mbr.wmt22.de-en.coarse_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.en score: 0.5611 39 | mbr.wmt22.de-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.en score: 0.5619 40 | 41 | en-ru 42 | baselines 43 | wmt22.en-ru.beam4.ru score: 0.4989 44 | wmt22.en-ru.epsilon0.02.seed0.ru score: 0.4081 45 | chrf 46 | mbr.wmt22.en-ru.pairwise.n1024.epsilon0.02.seed0.chrf.ru score: 0.5538 47 | mbr.wmt22.en-ru.aggregate.n1024.epsilon0.02.seed0.chrf.ru score: 0.5521 48 | mbr.wmt22.en-ru.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf.ru score: 0.5541 49 | cometinho 50 | mbr.wmt22.en-ru.pairwise.n1024.epsilon0.02.seed0.cometinho.ru score: 0.6693 51 | mbr.wmt22.en-ru.aggregate.n1024.epsilon0.02.seed0.cometinho.ru score: 0.6575 52 | mbr.wmt22.en-ru.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.cometinho.ru score: 0.6669 53 | comet22 54 | mbr.wmt22.en-ru.pairwise.n1024.epsilon0.02.seed0.comet22.ru score: 0.6046 55 | mbr.wmt22.en-ru.aggregate.n1024.epsilon0.02.seed0.comet22.ru score: 0.5991 56 | mbr.wmt22.en-ru.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.comet22.ru score: 0.6030 57 | coarse-to-fine 58 | mbr.wmt22.en-ru.coarse_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.ru score: 0.5861 59 | mbr.wmt22.en-ru.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.ru score: 0.5878 60 | 61 | ru-en 62 | baselines 63 | wmt22.ru-en.beam4.en score: 0.6586 64 | wmt22.ru-en.epsilon0.02.seed0.en score: 0.5255 65 | chrf 66 | mbr.wmt22.ru-en.pairwise.n1024.epsilon0.02.seed0.chrf.en score: 0.6732 67 | mbr.wmt22.ru-en.aggregate.n1024.epsilon0.02.seed0.chrf.en score: 0.6695 68 | mbr.wmt22.ru-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf.en score: 0.6712 69 | cometinho 70 | mbr.wmt22.ru-en.pairwise.n1024.epsilon0.02.seed0.cometinho.en score: 0.7288 71 | mbr.wmt22.ru-en.aggregate.n1024.epsilon0.02.seed0.cometinho.en score: 0.7250 72 | mbr.wmt22.ru-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.cometinho.en score: 0.7286 73 | comet22 74 | mbr.wmt22.ru-en.pairwise.n1024.epsilon0.02.seed0.comet22.en score: 0.6856 75 | mbr.wmt22.ru-en.aggregate.n1024.epsilon0.02.seed0.comet22.en score: 0.6910 76 | mbr.wmt22.ru-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.comet22.en score: 0.6871 77 | coarse-to-fine 78 | mbr.wmt22.ru-en.coarse_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.en score: 0.6870 79 | mbr.wmt22.ru-en.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.en score: 0.6846 80 | 81 | -------------------------------------------------------------------------------- /experiments/reference_aggregation/run_mbr.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from pathlib import Path 4 | from typing import List, Optional 5 | 6 | import jsonlines 7 | from tqdm import tqdm 8 | 9 | from experiments.reference_aggregation.experiment_utils import Testset 10 | from experiments.reference_aggregation.mbr_utils import load_utility 11 | 12 | 13 | def main(method: str, topk: Optional[int], testset: str, language_pair: str, seed_no: int, fine_utility_name: str, 14 | num_samples: int = 1024, epsilon_cutoff: float = 0.02, coarse_utility_name: str = None, 15 | limit_segments: int = None, log_time: bool = False, out_dir: Path = None) -> Path: 16 | if out_dir is None: 17 | out_dir = Path(__file__).parent 18 | 19 | if coarse_utility_name is None: 20 | coarse_utility_name = fine_utility_name 21 | 22 | if method in {'aggregate_to_fine', 'coarse_to_fine'}: 23 | assert topk <= num_samples 24 | 25 | dataset = Testset.from_wmt(testset, language_pair, limit_segments=limit_segments) 26 | 27 | samples_dir = out_dir / "samples" 28 | assert samples_dir.exists() 29 | samples_path = samples_dir / f"samples.{dataset}.n{num_samples}.epsilon{epsilon_cutoff}.seed{seed_no}.jsonl" 30 | assert samples_path.exists() 31 | with jsonlines.open(samples_path) as f: 32 | samples = [line["samples"] for line in f] 33 | samples = [sample[:num_samples] for sample in samples] 34 | if limit_segments is not None: 35 | samples = samples[:limit_segments] 36 | 37 | assert len(samples) == len(dataset.source_sentences) 38 | assert all(len(sample) == num_samples for sample in samples) 39 | 40 | references = samples 41 | 42 | utility = load_utility(fine_utility_name) 43 | if coarse_utility_name == fine_utility_name: 44 | coarse_utility = utility 45 | else: 46 | coarse_utility = load_utility(coarse_utility_name) 47 | 48 | translations: List[str] = [] 49 | 50 | if log_time: 51 | start_time = time.time() 52 | 53 | for i in tqdm(list(range(len(dataset.source_sentences))), desc="segments"): 54 | 55 | # For COMET: compute embeddings 56 | if hasattr(coarse_utility, "compute_features"): 57 | coarse_utility.clear_features() 58 | input_sequences = {dataset.source_sentences[i]} | set(samples[i]) | set(references[i]) 59 | coarse_utility.compute_features(input_sequences) 60 | 61 | if method == 'pairwise': 62 | n_by_n_ranking = utility.rank_samples_n_by_s(dataset.source_sentences[i], samples[i], references[i], 63 | s=num_samples) 64 | translation = samples[i][n_by_n_ranking[0]] 65 | elif method == 'aggregate': 66 | aggregate_ranking = utility.rank_samples_aggregate(dataset.source_sentences[i], samples[i], references[i], 67 | s=1) 68 | translation = samples[i][aggregate_ranking[0]] 69 | elif method == 'aggregate_to_fine': 70 | aggregate_ranking = coarse_utility.rank_samples_aggregate(dataset.source_sentences[i], samples[i], 71 | references[i], s=1) 72 | topk_samples = [samples[i][aggregate_ranking[j]] for j in range(topk)] 73 | 74 | if fine_utility_name != coarse_utility_name and hasattr(utility, "compute_features"): 75 | utility.clear_features() 76 | input_sequences = {dataset.source_sentences[i]} | set(topk_samples) | set(references[i]) 77 | utility.compute_features(input_sequences) 78 | 79 | fine_ranking = utility.rank_samples_n_by_s(dataset.source_sentences[i], topk_samples, references[i], 80 | s=num_samples) 81 | translation = topk_samples[fine_ranking[0]] 82 | elif method == 'coarse_to_fine': 83 | coarse_ranking = coarse_utility.rank_samples_n_by_s(dataset.source_sentences[i], samples[i], 
references[i], 84 | s=num_samples) 85 | topk_samples = [samples[i][coarse_ranking[j]] for j in range(topk)] 86 | 87 | if fine_utility_name != coarse_utility_name and hasattr(utility, "compute_features"): 88 | utility.clear_features() 89 | input_sequences = {dataset.source_sentences[i]} | set(topk_samples) | set(references[i]) 90 | utility.compute_features(input_sequences) 91 | 92 | fine_ranking = utility.rank_samples_n_by_s(dataset.source_sentences[i], topk_samples, references[i], 93 | s=num_samples) 94 | translation = topk_samples[fine_ranking[0]] 95 | else: 96 | raise ValueError(f"Unknown method: {method}") 97 | translations.append(translation) 98 | 99 | if log_time: 100 | print(f"Average time per segment: {(time.time() - start_time) / len(dataset.source_sentences):.5f} seconds") 101 | 102 | assert len(translations) == len(dataset.source_sentences) 103 | 104 | translations_dir = out_dir / "translations" 105 | translations_dir.mkdir(exist_ok=True) 106 | out_path = translations_dir / f"mbr.{dataset}.{method}{'.top' + str(topk) if method in {'aggregate_to_fine', 'coarse_to_fine'} else ''}.n{num_samples}.epsilon{epsilon_cutoff}.seed{seed_no}.{coarse_utility_name + '-to-' if coarse_utility_name != fine_utility_name else ''}{fine_utility_name}.{dataset.tgt_lang}" 107 | with open(out_path, "w") as f: 108 | for translation in translations: 109 | f.write(translation + "\n") 110 | 111 | return out_path 112 | 113 | 114 | if __name__ == '__main__': 115 | parser = argparse.ArgumentParser() 116 | parser.add_argument('--method', choices=['pairwise', 'aggregate', 'aggregate_to_fine', 'coarse_to_fine'], 117 | required=True) 118 | parser.add_argument('--topk', type=int, default=20, 119 | help='Number of samples to prune to in aggregate_to_fine method') 120 | parser.add_argument('--testset', choices=['wmt21', 'wmt22'], required=True) 121 | parser.add_argument('--language-pair', choices=['de-en', 'en-de', 'en-ru', 'ru-en'], required=True) 122 | parser.add_argument('--seed', type=int, choices=range(10), required=True, 123 | help='Index of the random seed in the list of random seeds') 124 | parser.add_argument('--utility', choices=['chrf', 'cometinho', 'comet22'], required=True) 125 | parser.add_argument('--coarse-utility', choices=['chrf', 'cometinho', 'comet22'], default=None, 126 | help='Utility used for coarse-grained method (default: same as fine-grained)') 127 | parser.add_argument('--num-samples', type=int, default=1024) 128 | parser.add_argument('--epsilon-cutoff', type=float, default=0.02) 129 | parser.add_argument('--limit-segments', type=int, default=None, 130 | help='Limit number of segments that are processed (used for testing)') 131 | parser.add_argument('--log-time', action='store_true', 132 | help='Print average wall-clock time per segment (used for benchmarking)') 133 | args = parser.parse_args() 134 | 135 | if args.coarse_utility is None: 136 | args.coarse_utility = args.utility 137 | 138 | out_path = main(method=args.method, topk=args.topk, testset=args.testset, language_pair=args.language_pair, 139 | seed_no=args.seed, fine_utility_name=args.utility, coarse_utility_name=args.coarse_utility, 140 | num_samples=args.num_samples, epsilon_cutoff=args.epsilon_cutoff, limit_segments=args.limit_segments, 141 | log_time=args.log_time, ) 142 | assert out_path.exists() 143 | print(f"Saved translations to {out_path}") 144 | -------------------------------------------------------------------------------- /experiments/reference_aggregation/scripts/benchmark_time.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | num_segments_per_lp=32 4 | 5 | for lp in en-de de-en en-ru ru-en; do 6 | echo $lp 7 | 8 | # MBR with ChrF metric – standard MBR 9 | taskset --cpu-list 0-63 python -m experiments.reference_aggregation.run_mbr --method pairwise --testset wmt22 --language-pair $lp --seed 0 --utility chrf --limit-segments $num_segments_per_lp --log-time 10 | # MBR with ChrF metric – reference aggregation 11 | taskset --cpu-list 0-63 python -m experiments.reference_aggregation.run_mbr --method aggregate --testset wmt22 --language-pair $lp --seed 0 --utility chrf --limit-segments $num_segments_per_lp --log-time 12 | # MBR with ChrF metric – aggregate-to-fine MBR 13 | taskset --cpu-list 0-63 python -m experiments.reference_aggregation.run_mbr --method aggregate_to_fine --topk 20 --testset wmt22 --language-pair $lp --seed 0 --utility chrf --limit-segments $num_segments_per_lp --log-time 14 | 15 | # MBR with Cometinho metric – standard MBR 16 | taskset --cpu-list 0-63 python -m experiments.reference_aggregation.run_mbr --method pairwise --testset wmt22 --language-pair $lp --seed 0 --utility cometinho --limit-segments $num_segments_per_lp --log-time 17 | # MBR with Cometinho metric – reference aggregation 18 | taskset --cpu-list 0-63 python -m experiments.reference_aggregation.run_mbr --method aggregate --testset wmt22 --language-pair $lp --seed 0 --utility cometinho --limit-segments $num_segments_per_lp --log-time 19 | # MBR with Cometinho metric – aggregate-to-fine MBR 20 | taskset --cpu-list 0-63 python -m experiments.reference_aggregation.run_mbr --method aggregate_to_fine --topk 20 --testset wmt22 --language-pair $lp --seed 0 --utility cometinho --limit-segments $num_segments_per_lp --log-time 21 | 22 | # MBR with COMET-22 metric – standard MBR 23 | taskset --cpu-list 0-63 python -m experiments.reference_aggregation.run_mbr --method pairwise --testset wmt22 --language-pair $lp --seed 0 --utility comet22 --limit-segments $num_segments_per_lp --log-time 24 | # MBR with COMET-22 metric – reference aggregation 25 | taskset --cpu-list 0-63 python -m experiments.reference_aggregation.run_mbr --method aggregate --testset wmt22 --language-pair $lp --seed 0 --utility comet22 --limit-segments $num_segments_per_lp --log-time 26 | # MBR with COMET-22 metric – aggregate-to-fine MBR 27 | taskset --cpu-list 0-63 python -m experiments.reference_aggregation.run_mbr --method aggregate_to_fine --topk 20 --testset wmt22 --language-pair $lp --seed 0 --utility comet22 --limit-segments $num_segments_per_lp --log-time 28 | 29 | # Coarse-to-fine MBR: ChrF to COMET-22 30 | taskset --cpu-list 0-63 python -m experiments.reference_aggregation.run_mbr --method coarse_to_fine --topk 20 --testset wmt22 --language-pair $lp --seed 0 --coarse-utility chrf --utility comet22 --limit-segments $num_segments_per_lp --log-time 31 | # Aggregate-to-fine MBR: Aggregate ChrF to COMET-22 32 | taskset --cpu-list 0-63 python -m experiments.reference_aggregation.run_mbr --method aggregate_to_fine --topk 20 --testset wmt22 --language-pair $lp --seed 0 --coarse-utility chrf --utility comet22 --limit-segments $num_segments_per_lp --log-time 33 | 34 | echo 35 | 36 | done 37 | -------------------------------------------------------------------------------- /experiments/reference_aggregation/scripts/evaluate-chrf.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd translations 4 | 5 | declare -a
language_pairs=("en-de" "de-en" "en-ru" "ru-en") 6 | 7 | for lp in "${language_pairs[@]}" 8 | do 9 | echo $lp 10 | 11 | IFS='-' read -ra ADDR <<< "$lp" 12 | src=${ADDR[0]} 13 | tgt=${ADDR[1]} 14 | 15 | echo "baselines" 16 | sacrebleu wmt22.${lp}.ref.${tgt} -i wmt22.${lp}.beam4.${tgt} -m chrf -b 17 | sacrebleu wmt22.${lp}.ref.${tgt} -i wmt22.${lp}.epsilon0.02.seed0.${tgt} -m chrf -b 18 | 19 | echo "chrf" 20 | sacrebleu wmt22.${lp}.ref.${tgt} -i mbr.wmt22.${lp}.pairwise.n1024.epsilon0.02.seed0.chrf.${tgt} -m chrf -b 21 | sacrebleu wmt22.${lp}.ref.${tgt} -i mbr.wmt22.${lp}.aggregate.n1024.epsilon0.02.seed0.chrf.${tgt} -m chrf -b 22 | sacrebleu wmt22.${lp}.ref.${tgt} -i mbr.wmt22.${lp}.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf.${tgt} -m chrf -b 23 | 24 | echo "cometinho" 25 | sacrebleu wmt22.${lp}.ref.${tgt} -i mbr.wmt22.${lp}.pairwise.n1024.epsilon0.02.seed0.cometinho.${tgt} -m chrf -b 26 | sacrebleu wmt22.${lp}.ref.${tgt} -i mbr.wmt22.${lp}.aggregate.n1024.epsilon0.02.seed0.cometinho.${tgt} -m chrf -b 27 | sacrebleu wmt22.${lp}.ref.${tgt} -i mbr.wmt22.${lp}.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.cometinho.${tgt} -m chrf -b 28 | 29 | echo "comet22" 30 | sacrebleu wmt22.${lp}.ref.${tgt} -i mbr.wmt22.${lp}.pairwise.n1024.epsilon0.02.seed0.comet22.${tgt} -m chrf -b 31 | sacrebleu wmt22.${lp}.ref.${tgt} -i mbr.wmt22.${lp}.aggregate.n1024.epsilon0.02.seed0.comet22.${tgt} -m chrf -b 32 | sacrebleu wmt22.${lp}.ref.${tgt} -i mbr.wmt22.${lp}.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.comet22.${tgt} -m chrf -b 33 | 34 | echo "coarse-to-fine" 35 | sacrebleu wmt22.${lp}.ref.${tgt} -i mbr.wmt22.${lp}.coarse_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.${tgt} -m chrf -b 36 | sacrebleu wmt22.${lp}.ref.${tgt} -i mbr.wmt22.${lp}.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.${tgt} -m chrf -b 37 | 38 | echo 39 | done 40 | -------------------------------------------------------------------------------- /experiments/reference_aggregation/scripts/evaluate-comet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model=$1 4 | if [ "$model" != "Unbabel/wmt22-comet-da" ] && [ "$model" != "Unbabel/eamt22-cometinho-da" ] && [ "$model" != "Unbabel/XCOMET-XL" ]; then 5 | echo "Invalid model. 
Please choose one of Unbabel/wmt22-comet-da, Unbabel/eamt22-cometinho-da, Unbabel/XCOMET-XL" 6 | exit 1 7 | fi 8 | 9 | cd translations 10 | 11 | declare -a language_pairs=("en-de" "de-en" "en-ru" "ru-en") 12 | 13 | for lp in "${language_pairs[@]}" 14 | do 15 | echo $lp 16 | 17 | IFS='-' read -ra ADDR <<< "$lp" 18 | src=${ADDR[0]} 19 | tgt=${ADDR[1]} 20 | 21 | echo "baselines" 22 | comet-score -s wmt22.${lp}.src.${src} -r wmt22.${lp}.ref.${tgt} -t wmt22.${lp}.beam4.${tgt} --only_system --model $1 23 | comet-score -s wmt22.${lp}.src.${src} -r wmt22.${lp}.ref.${tgt} -t wmt22.${lp}.epsilon0.02.seed0.${tgt} --only_system --model $1 24 | 25 | echo "chrf" 26 | comet-score -s wmt22.${lp}.src.${src} -r wmt22.${lp}.ref.${tgt} -t mbr.wmt22.${lp}.pairwise.n1024.epsilon0.02.seed0.chrf.${tgt} --only_system --model $1 27 | comet-score -s wmt22.${lp}.src.${src} -r wmt22.${lp}.ref.${tgt} -t mbr.wmt22.${lp}.aggregate.n1024.epsilon0.02.seed0.chrf.${tgt} --only_system --model $1 28 | comet-score -s wmt22.${lp}.src.${src} -r wmt22.${lp}.ref.${tgt} -t mbr.wmt22.${lp}.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf.${tgt} --only_system --model $1 29 | 30 | echo "cometinho" 31 | comet-score -s wmt22.${lp}.src.${src} -r wmt22.${lp}.ref.${tgt} -t mbr.wmt22.${lp}.pairwise.n1024.epsilon0.02.seed0.cometinho.${tgt} --only_system --model $1 32 | comet-score -s wmt22.${lp}.src.${src} -r wmt22.${lp}.ref.${tgt} -t mbr.wmt22.${lp}.aggregate.n1024.epsilon0.02.seed0.cometinho.${tgt} --only_system --model $1 33 | comet-score -s wmt22.${lp}.src.${src} -r wmt22.${lp}.ref.${tgt} -t mbr.wmt22.${lp}.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.cometinho.${tgt} --only_system --model $1 34 | 35 | echo "comet22" 36 | comet-score -s wmt22.${lp}.src.${src} -r wmt22.${lp}.ref.${tgt} -t mbr.wmt22.${lp}.pairwise.n1024.epsilon0.02.seed0.comet22.${tgt} --only_system --model $1 37 | comet-score -s wmt22.${lp}.src.${src} -r wmt22.${lp}.ref.${tgt} -t mbr.wmt22.${lp}.aggregate.n1024.epsilon0.02.seed0.comet22.${tgt} --only_system --model $1 38 | comet-score -s wmt22.${lp}.src.${src} -r wmt22.${lp}.ref.${tgt} -t mbr.wmt22.${lp}.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.comet22.${tgt} --only_system --model $1 39 | 40 | echo "coarse-to-fine" 41 | comet-score -s wmt22.${lp}.src.${src} -r wmt22.${lp}.ref.${tgt} -t mbr.wmt22.${lp}.coarse_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.${tgt} --only_system --model $1 42 | comet-score -s wmt22.${lp}.src.${src} -r wmt22.${lp}.ref.${tgt} -t mbr.wmt22.${lp}.aggregate_to_fine.top20.n1024.epsilon0.02.seed0.chrf-to-comet22.${tgt} --only_system --model $1 43 | 44 | echo 45 | done 46 | -------------------------------------------------------------------------------- /experiments/reference_aggregation/scripts/plot_accuracy_reverse.py: -------------------------------------------------------------------------------- 1 | """ 2 | Compared to plot_accuracy.py, makes plotting easier by reversing the labels (1024->1 etc.) 
3 | """ 4 | import argparse 5 | 6 | from experiments.reference_aggregation.plot_accuracy import main 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--testset', choices=['wmt21', 'wmt22'], required=True) 10 | parser.add_argument('--language-pair', choices=['de-en', 'en-de', 'en-ru', 'ru-en'], required=True) 11 | parser.add_argument('--seed', type=int, choices=range(10), required=True, 12 | help='Index of the random seed in the list of random seeds') 13 | parser.add_argument('--utility', choices=['chrf', 'cometinho', 'comet22'], required=True) 14 | parser.add_argument('--coarse-utility', choices=['chrf', 'cometinho', 'comet22'], default=None, 15 | help='Utility used for coarse-grained method (default: same as fine-grained)') 16 | parser.add_argument('--topk', type=int, default=20, 17 | help='Number of top translations that have been saved in the jsonl file') 18 | parser.add_argument('--method', choices=['n_by_s', 'aggregate'], required=True) 19 | parser.add_argument('--num-samples', type=int, default=1024) 20 | parser.add_argument('--epsilon-cutoff', type=float, default=0.02) 21 | parser.add_argument('--accuracy-topk', type=int, default=None, 22 | help='Number of top translations that are used to compute the accuracy (default: same as data-topk)') 23 | parser.add_argument('--limit-segments', type=int, default=None, 24 | help='Limit number of segments that are processed (used for testing)') 25 | args = parser.parse_args() 26 | 27 | if args.coarse_utility is None: 28 | args.coarse_utility = args.utility 29 | if args.accuracy_topk is None: 30 | args.accuracy_topk = args.topk 31 | 32 | series = main(testset=args.testset, language_pair=args.language_pair, seed_no=args.seed, fine_utility_name=args.utility, 33 | coarse_utility_name=args.coarse_utility, topk=args.topk, method=args.method, num_samples=args.num_samples, 34 | epsilon_cutoff=args.epsilon_cutoff, accuracy_topk=args.accuracy_topk, limit_segments=args.limit_segments, ) 35 | 36 | s_values = [s for s, _ in series] 37 | reversed_s_values = list(reversed(s_values)) 38 | series_str = "".join( 39 | [f"({s},{accuracy:.5f})" for s, accuracy in zip(reversed_s_values, [accuracy for _, accuracy in series])]) 40 | print( 41 | f"Testset: {args.testset}, language pair: {args.language_pair}, seed: {args.seed}, fine utility: {args.utility}, coarse utility: {args.coarse_utility}, topk: {args.topk}, method: {args.method}") 42 | print(f"Top-{args.accuracy_topk} accuracy:") 43 | print(series_str) 44 | print() 45 | -------------------------------------------------------------------------------- /experiments/reference_aggregation/scripts/print_data_stats_table.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import jsonlines 4 | 5 | header = "\\begin{tabularx}{\\textwidth}{Xrrr}\n\\toprule\n" 6 | header += "& \\# Segments & \\# Samples per segment & \\# Unique samples per segment \\\\\n\\midrule\n" 7 | footer = "\\bottomrule\n\\end{tabularx}" 8 | 9 | samples_dir = Path(__file__).parent.parent / "samples" 10 | 11 | body = "" 12 | body += "\\textit{newstest21} & & & \\\\\n" 13 | for lang_pair in ["en-de", "de-en", "en-ru", "ru-en"]: 14 | path = samples_dir / f"samples.wmt21.{lang_pair}.n1024.epsilon0.02.seed0.jsonl" 15 | assert path.exists(), f"Path {path} does not exist" 16 | with jsonlines.open(path) as reader: 17 | data = list(reader) 18 | num_segments = len(data) 19 | num_samples = len(data[0]["samples"]) 20 | avg_num_unique_samples = sum( 21 | [len(set([sample for 
sample in segment["samples"]])) for segment in data]) / num_segments 22 | body += "\\textsc{" + lang_pair.replace('-', '–') + "} & " + str(num_segments) + " & " + str( 23 | num_samples) + " & " + "{:.1f}".format(avg_num_unique_samples) + " \\\\\n" 24 | body += "\\addlinespace\n" 25 | body += "\\textit{newstest22} & & & \\\\\n" 26 | for lang_pair in ["en-de", "de-en", "en-ru", "ru-en"]: 27 | path = samples_dir / f"samples.wmt22.{lang_pair}.n1024.epsilon0.02.seed0.jsonl" 28 | assert path.exists(), f"Path {path} does not exist" 29 | with jsonlines.open(path) as reader: 30 | data = list(reader) 31 | num_segments = len(data) 32 | num_samples = len(data[0]["samples"]) 33 | avg_num_unique_samples = sum( 34 | [len(set([sample for sample in segment["samples"]])) for segment in data]) / num_segments 35 | body += "\\textsc{" + lang_pair.replace('-', '–') + "} & " + str(num_segments) + " & " + str( 36 | num_samples) + " & " + "{:.1f}".format(avg_num_unique_samples) + " \\\\\n" 37 | 38 | print(header + body + footer) 39 | -------------------------------------------------------------------------------- /experiments/reference_aggregation/scripts/save_src_and_ref.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | from experiments.reference_aggregation.experiment_utils import Testset 5 | 6 | 7 | def main(testset: str, language_pair: str, limit_segments: int = None, out_dir: Path = None) -> (Path, Path): 8 | if out_dir is None: 9 | out_dir = Path(__file__).parent 10 | 11 | translations_dir = out_dir / "translations" 12 | translations_dir.mkdir(exist_ok=True) 13 | 14 | dataset = Testset.from_wmt(testset, language_pair, limit_segments=limit_segments) 15 | 16 | src_out_path = translations_dir / f"{dataset}.src.{dataset.src_lang}" 17 | 18 | with open(src_out_path, "w") as f: 19 | for src in dataset.source_sentences: 20 | f.write(src + "\n") 21 | 22 | ref_out_path = translations_dir / f"{dataset}.ref.{dataset.tgt_lang}" 23 | with open(ref_out_path, "w") as f: 24 | for ref in dataset.references: 25 | f.write(ref + "\n") 26 | 27 | return src_out_path, ref_out_path 28 | 29 | 30 | if __name__ == '__main__': 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--testset', choices=['wmt21', 'wmt22'], required=True) 33 | parser.add_argument('--language-pair', choices=['de-en', 'en-de', 'en-ru', 'ru-en'], required=True) 34 | parser.add_argument('--limit-segments', type=int, default=None, 35 | help='Limit number of segments that are processed (used for testing)') 36 | args = parser.parse_args() 37 | 38 | src_path, ref_path = main(testset=args.testset, language_pair=args.language_pair, 39 | limit_segments=args.limit_segments, ) 40 | print(f"Source sentences saved to {src_path}") 41 | print(f"References saved to {ref_path}") 42 | -------------------------------------------------------------------------------- /experiments/reference_aggregation/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/experiments/reference_aggregation/tests/__init__.py -------------------------------------------------------------------------------- /experiments/reference_aggregation/tests/test_beam_search.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from unittest import TestCase 3 | 4 | 5 | class BeamSearchTestCase(TestCase): 6 | 7 | def 
setUp(self): 8 | self.testset = "wmt21" 9 | self.language_pair = "en-de" 10 | self.test_dir = Path(__file__).parent / "out" 11 | self.test_dir.mkdir(exist_ok=True) 12 | 13 | def test_beam_search(self): 14 | from experiments.reference_aggregation.baseline_beam_search import main 15 | out_path = main(self.testset, self.language_pair, limit_segments=4, out_dir=self.test_dir) 16 | self.assertTrue(out_path.exists()) 17 | self.assertIn(self.test_dir, out_path.parents) 18 | self.assertTrue(out_path.name.endswith(".de")) 19 | translations = out_path.read_text().splitlines() 20 | self.assertEqual(len(translations), 4) 21 | print(translations[0]) 22 | -------------------------------------------------------------------------------- /experiments/reference_aggregation/tests/test_chrf.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from experiments.reference_aggregation.mbr_utils import ChrfUtility 4 | 5 | 6 | class ChrfTestCase(TestCase): 7 | 8 | def setUp(self): 9 | self.chrf = ChrfUtility() 10 | 11 | def test_rank_samples_n_by_n(self): 12 | source_sequence = "This is a sample sentence" 13 | samples = ["Dies ist ein Beispiel.", "Dies ist ein Beispielsatz", "Dieser Satz macht keinen Sinn.", 14 | "Dies ist ein Test."] 15 | references = samples 16 | 17 | indices = self.chrf.rank_samples_n_by_s(source_sequence, samples, references, s=4) 18 | self.assertEqual(len(samples), len(indices)) 19 | self.assertListEqual([1, 0, 3, 2], indices.tolist()) 20 | 21 | # Test sample order invariance 22 | indices = self.chrf.rank_samples_n_by_s(source_sequence, samples[::-1], references, s=4) 23 | self.assertListEqual([2, 3, 0, 1], indices.tolist()) 24 | 25 | # Test reference order invariance 26 | indices = self.chrf.rank_samples_n_by_s(source_sequence, samples, references[::-1], s=4) 27 | self.assertListEqual([1, 0, 3, 2], indices.tolist()) 28 | 29 | def test_rank_samples_n_by_1(self): 30 | source_sequence = "This is a sample sentence" 31 | samples = ["Dies ist ein Beispiel.", "Dies ist ein Beispielsatz", "Dieser Satz macht keinen Sinn.", 32 | "Dies ist ein Test."] 33 | references = samples 34 | 35 | indices = self.chrf.rank_samples_n_by_s(source_sequence, samples, references, s=1) 36 | self.assertEqual(0, indices[0]) # Perfect match with itself 37 | self.assertListEqual([0, 1, 3, 2], indices.tolist()) 38 | 39 | # Test sample order invariance 40 | indices = self.chrf.rank_samples_n_by_s(source_sequence, samples[::-1], references, s=1) 41 | self.assertListEqual([3, 2, 0, 1], indices.tolist()) 42 | 43 | def test_rank_samples_aggregate(self): 44 | source_sequence = "This is a sample sentence" 45 | samples = ["Dies ist ein Beispiel.", "Dies ist ein Beispielsatz", "Dieser Satz macht keinen Sinn.", 46 | "Dies ist ein Test."] 47 | references = samples 48 | 49 | indices = self.chrf.rank_samples_aggregate(source_sequence, samples, references, s=1) 50 | self.assertEqual(len(samples), len(indices)) 51 | self.assertListEqual([1, 0, 3, 2], indices.tolist()) 52 | 53 | # Test sample order invariance 54 | indices = self.chrf.rank_samples_aggregate(source_sequence, samples[::-1], references, s=1) 55 | self.assertListEqual([2, 3, 0, 1], indices.tolist()) 56 | 57 | # Test reference order invariance 58 | indices = self.chrf.rank_samples_aggregate(source_sequence, samples, references[::-1], s=1) 59 | self.assertListEqual([1, 0, 3, 2], indices.tolist()) 60 | 61 | def test_rank_samples_aggregate_partial(self): 62 | source_sequence = "This is a sample sentence" 63 |
samples = ["Dies ist ein Beispiel.", "Dies ist ein Beispielsatz", "Dieser Satz macht keinen Sinn.", 64 | "Dies ist ein Test."] 65 | references = samples 66 | 67 | indices = self.chrf.rank_samples_aggregate(source_sequence, samples, references, s=2) 68 | self.assertEqual(len(samples), len(indices)) 69 | self.assertListEqual([1, 0, 3, 2], indices.tolist()) 70 | 71 | # Test sample order invariance 72 | indices = self.chrf.rank_samples_aggregate(source_sequence, samples[::-1], references, s=2) 73 | self.assertListEqual([2, 3, 0, 1], indices.tolist()) 74 | 75 | # Test (partial) reference order invariance: change order of references within partitions 76 | indices = self.chrf.rank_samples_aggregate(source_sequence, samples, 77 | references[:2][::-1] + references[2:][::-1], s=2) 78 | self.assertListEqual([1, 0, 3, 2], indices.tolist()) 79 | 80 | def test_rank_samples_disaggregated_is_equivalent_to_n_by_n(self): 81 | source_sequence = "This is a sample sentence" 82 | samples = ["Dies ist ein Beispiel.", "Dies ist ein Beispielsatz", "Dieser Satz macht keinen Sinn.", 83 | "Dies ist ein Test."] 84 | references = samples 85 | 86 | n_by_n_indices = self.chrf.rank_samples_n_by_s(source_sequence, samples, references, s=4) 87 | aggregate_indices = self.chrf.rank_samples_aggregate(source_sequence, samples, references, s=4) 88 | self.assertListEqual(n_by_n_indices.tolist(), aggregate_indices.tolist()) 89 | -------------------------------------------------------------------------------- /experiments/reference_aggregation/tests/test_comet.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from experiments.reference_aggregation.mbr_utils import CometUtility 4 | 5 | 6 | class CometTestCase(TestCase): 7 | 8 | def setUp(self): 9 | self.comet = CometUtility("eamt22-cometinho-da") 10 | 11 | def test_compute_features(self): 12 | self.assertEqual(0, len(self.comet.embeddings)) 13 | self.comet.compute_features({"This is a test.", "Dies ist ein Test."}) 14 | self.assertEqual(2, len(self.comet.embeddings)) 15 | self.comet.clear_features() 16 | self.assertEqual(0, len(self.comet.embeddings)) 17 | 18 | def test_rank_samples_n_by_n(self): 19 | source_sequence = "This is a sample sentence" 20 | samples = ["Dies ist ein Beispiel.", "Dies ist ein Beispielsatz", "Dieser Satz macht keinen Sinn.", 21 | "Dies ist ein Test."] 22 | references = samples 23 | self.comet.compute_features({source_sequence} | set(samples) | set(references)) 24 | 25 | indices = self.comet.rank_samples_n_by_s(source_sequence, samples, references, s=4) 26 | self.assertEqual(len(samples), len(indices)) 27 | self.assertListEqual([1, 0, 3, 2], indices.tolist()) 28 | 29 | # Test sample order invariance 30 | indices = self.comet.rank_samples_n_by_s(source_sequence, samples[::-1], references, s=4) 31 | self.assertListEqual([2, 3, 0, 1], indices.tolist()) 32 | 33 | # Test reference order invariance 34 | indices = self.comet.rank_samples_n_by_s(source_sequence, samples, references[::-1], s=4) 35 | self.assertListEqual([1, 0, 3, 2], indices.tolist()) 36 | 37 | def test_rank_samples_n_by_1(self): 38 | source_sequence = "This is a sample sentence" 39 | samples = ["Dies ist ein Beispiel.", "Dies ist ein Beispielsatz", "Dieser Satz macht keinen Sinn.", 40 | "Dies ist ein Test."] 41 | references = samples 42 | self.comet.compute_features({source_sequence} | set(samples) | set(references)) 43 | 44 | indices = self.comet.rank_samples_n_by_s(source_sequence, samples, references, s=1) 45 | 
self.assertEqual(0, indices[0]) # Perfect match with itself 46 | self.assertListEqual([0, 1, 3, 2], indices.tolist()) 47 | 48 | # Test sample order invariance 49 | indices = self.comet.rank_samples_n_by_s(source_sequence, samples[::-1], references, s=1) 50 | self.assertListEqual([3, 2, 0, 1], indices.tolist()) 51 | 52 | def test_rank_samples_aggregate(self): 53 | source_sequence = "This is a sample sentence" 54 | samples = ["Dies ist ein Beispiel.", "Dies ist ein Beispielsatz", "Dieser Satz macht keinen Sinn.", 55 | "Dies ist ein Test."] 56 | references = samples 57 | self.comet.compute_features({source_sequence} | set(samples) | set(references)) 58 | 59 | indices = self.comet.rank_samples_aggregate(source_sequence, samples, references, s=1) 60 | self.assertEqual(len(samples), len(indices)) 61 | self.assertListEqual([1, 0, 3, 2], indices.tolist()) 62 | 63 | # Test sample order invariance 64 | indices = self.comet.rank_samples_aggregate(source_sequence, samples[::-1], references, s=1) 65 | self.assertListEqual([2, 3, 0, 1], indices.tolist()) 66 | 67 | # Test reference order invariance 68 | indices = self.comet.rank_samples_aggregate(source_sequence, samples, references[::-1], s=1) 69 | self.assertListEqual([1, 0, 3, 2], indices.tolist()) 70 | 71 | def test_rank_samples_aggregate_partial(self): 72 | source_sequence = "This is a sample sentence" 73 | samples = ["Dies ist ein Beispiel.", "Dies ist ein Beispielsatz", "Dieser Satz macht keinen Sinn.", 74 | "Dies ist ein Test."] 75 | references = samples 76 | self.comet.compute_features({source_sequence} | set(samples) | set(references)) 77 | 78 | indices = self.comet.rank_samples_aggregate(source_sequence, samples, references, s=2) 79 | self.assertEqual(len(samples), len(indices)) 80 | self.assertListEqual([1, 0, 3, 2], indices.tolist()) 81 | 82 | # Test sample order invariance 83 | indices = self.comet.rank_samples_aggregate(source_sequence, samples[::-1], references, s=2) 84 | self.assertListEqual([2, 3, 0, 1], indices.tolist()) 85 | 86 | # Test (partial) reference order invariance: change order of references within partitions 87 | indices = self.comet.rank_samples_aggregate(source_sequence, samples, 88 | references[:2][::-1] + references[2:][::-1], s=2) 89 | self.assertListEqual([1, 0, 3, 2], indices.tolist()) 90 | 91 | def test_rank_samples_disaggregated_is_equivalent_to_n_by_n(self): 92 | source_sequence = "This is a sample sentence" 93 | samples = ["Dies ist ein Beispiel.", "Dies ist ein Beispielsatz", "Dieser Satz macht keinen Sinn.", 94 | "Dies ist ein Test."] 95 | references = samples 96 | self.comet.compute_features({source_sequence} | set(samples) | set(references)) 97 | 98 | n_by_n_indices = self.comet.rank_samples_n_by_s(source_sequence, samples, references, s=4) 99 | aggregate_indices = self.comet.rank_samples_aggregate(source_sequence, samples, references, s=4) 100 | self.assertListEqual(n_by_n_indices.tolist(), aggregate_indices.tolist()) 101 | -------------------------------------------------------------------------------- /experiments/reference_aggregation/tests/test_epsilon_sampling.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from unittest import TestCase 3 | 4 | from experiments.reference_aggregation.fairseq_utils import load_model 5 | 6 | 7 | class EpsilonSamplingTestCase(TestCase): 8 | 9 | def setUp(self): 10 | self.testset = "wmt21" 11 | self.language_pair = "en-de" 12 | self.test_dir = Path(__file__).parent / "out" 13 | 
self.test_dir.mkdir(exist_ok=True) 14 | 15 | def test_epsilon_sampling(self): 16 | model = load_model(self.language_pair) 17 | source_sentence = "This is a test." 18 | num_samples = 4 19 | # ε=0.02 20 | samples = model.sample(num_samples * [source_sentence], seed=42, sampling_epsilon_cutoff=0.02) 21 | self.assertEqual(len(samples), num_samples) 22 | self.assertIsInstance(samples[0], str) 23 | print(samples[0]) 24 | # ε=0 25 | samples = model.sample(num_samples * [source_sentence], seed=42, sampling_epsilon_cutoff=0) 26 | self.assertEqual(len(samples), num_samples) 27 | self.assertIsInstance(samples[0], str) 28 | 29 | def test_extract_translations(self): 30 | # Generate samples 31 | from experiments.reference_aggregation.generate_samples import main as generate_samples 32 | jsonl_path = generate_samples(self.testset, self.language_pair, seed_no=0, num_samples=8, epsilon_cutoff=0.02, 33 | limit_segments=4, out_dir=self.test_dir) 34 | self.assertTrue(jsonl_path.exists()) 35 | # Extract 36 | from experiments.reference_aggregation.baseline_epsilon_sampling import main 37 | out_path = main(self.testset, self.language_pair, num_samples=8, epsilon_cutoff=0.02, seed_no=0, 38 | out_dir=self.test_dir) 39 | self.assertTrue(out_path.exists()) 40 | self.assertIn(self.test_dir, out_path.parents) 41 | self.assertTrue(out_path.name.endswith(".de")) 42 | translations = out_path.read_text().splitlines() 43 | self.assertEqual(len(translations), 4) 44 | print(translations[0]) 45 | -------------------------------------------------------------------------------- /experiments/reference_aggregation/tests/test_generate_samples.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from unittest import TestCase 3 | 4 | import jsonlines 5 | 6 | 7 | class GenerateSamplesTestCase(TestCase): 8 | 9 | def setUp(self): 10 | self.testset = "wmt21" 11 | self.language_pair = "en-de" 12 | self.test_dir = Path(__file__).parent / "out" 13 | self.test_dir.mkdir(exist_ok=True) 14 | 15 | def test_generate_samples(self): 16 | from experiments.reference_aggregation.generate_samples import main 17 | out_path = main(self.testset, self.language_pair, seed_no=0, num_samples=8, epsilon_cutoff=0.02, 18 | limit_segments=4, out_dir=self.test_dir) 19 | self.assertTrue(out_path.exists()) 20 | self.assertIn(self.test_dir, out_path.parents) 21 | self.assertTrue(out_path.name.endswith(".jsonl")) 22 | with jsonlines.open(out_path) as f: 23 | data = list(f) 24 | self.assertEqual(len(data), 4) 25 | for line in data: 26 | self.assertIn("samples", line) 27 | self.assertEqual(len(line["samples"]), 8) 28 | self.assertIsInstance(line["samples"][0], str) 29 | print(data[0]["samples"][0]) 30 | -------------------------------------------------------------------------------- /experiments/reference_aggregation/tests/test_run_mbr.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from unittest import TestCase 3 | 4 | from experiments.reference_aggregation.run_mbr import main 5 | 6 | 7 | class MBRTestCase(TestCase): 8 | 9 | def setUp(self): 10 | self.testset = "wmt21" 11 | self.language_pair = "en-de" 12 | self.test_dir = Path(__file__).parent / "out" 13 | self.test_dir.mkdir(exist_ok=True) 14 | 15 | def test_run_mbr_pairwise_chrf(self): 16 | out_path = main(method="pairwise", topk=None, testset=self.testset, language_pair=self.language_pair, seed_no=0, 17 | fine_utility_name="chrf", num_samples=8, epsilon_cutoff=0.02, 
limit_segments=4, 18 | out_dir=self.test_dir) 19 | self.assertTrue(out_path.exists()) 20 | self.assertIn(self.test_dir, out_path.parents) 21 | self.assertTrue(out_path.name.endswith(".de")) 22 | translations = out_path.read_text().splitlines() 23 | self.assertEqual(len(translations), 4) 24 | print(translations[0]) 25 | 26 | def test_run_mbr_aggregate_chrf(self): 27 | out_path = main(method="aggregate", topk=None, testset=self.testset, language_pair=self.language_pair, 28 | seed_no=0, fine_utility_name="chrf", num_samples=8, epsilon_cutoff=0.02, limit_segments=4, 29 | out_dir=self.test_dir) 30 | translations = out_path.read_text().splitlines() 31 | self.assertEqual(len(translations), 4) 32 | print(translations[0]) 33 | 34 | def test_run_mbr_aggregate_to_fine_chrf(self): 35 | out_path = main(method="aggregate_to_fine", topk=2, testset=self.testset, language_pair=self.language_pair, 36 | seed_no=0, fine_utility_name="chrf", num_samples=8, epsilon_cutoff=0.02, limit_segments=4, 37 | out_dir=self.test_dir) 38 | translations = out_path.read_text().splitlines() 39 | self.assertEqual(len(translations), 4) 40 | print(translations[0]) 41 | 42 | def test_run_mbr_coarse_to_fine_chrf_to_comet22(self): 43 | out_path = main(method="coarse_to_fine", topk=2, testset=self.testset, language_pair=self.language_pair, 44 | seed_no=0, coarse_utility_name="chrf", fine_utility_name="cometinho", num_samples=8, 45 | epsilon_cutoff=0.02, limit_segments=4, out_dir=self.test_dir) 46 | translations = out_path.read_text().splitlines() 47 | self.assertEqual(len(translations), 4) 48 | print(translations[0]) 49 | 50 | def test_run_mbr_aggregate_to_fine_chrf_to_comet22(self): 51 | out_path = main(method="aggregate_to_fine", topk=2, testset=self.testset, language_pair=self.language_pair, 52 | seed_no=0, coarse_utility_name="chrf", fine_utility_name="cometinho", num_samples=8, 53 | epsilon_cutoff=0.02, limit_segments=4, out_dir=self.test_dir) 54 | translations = out_path.read_text().splitlines() 55 | self.assertEqual(len(translations), 4) 56 | print(translations[0]) 57 | -------------------------------------------------------------------------------- /experiments/reference_aggregation/tests/test_save_src_and_ref.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from unittest import TestCase 3 | 4 | 5 | class SaveSrcAndRefTestCase(TestCase): 6 | 7 | def setUp(self): 8 | self.testset = "wmt21" 9 | self.language_pair = "en-de" 10 | self.test_dir = Path(__file__).parent / "out" 11 | self.test_dir.mkdir(exist_ok=True) 12 | 13 | def test_save_src_and_ref(self): 14 | from experiments.reference_aggregation.scripts.save_src_and_ref import main 15 | src_path, ref_path = main(self.testset, self.language_pair, limit_segments=4, out_dir=self.test_dir) 16 | self.assertTrue(src_path.exists()) 17 | self.assertIn(self.test_dir, src_path.parents) 18 | self.assertTrue(src_path.name.endswith(".en")) 19 | self.assertTrue(ref_path.exists()) 20 | self.assertIn(self.test_dir, ref_path.parents) 21 | self.assertTrue(ref_path.name.endswith(".de")) 22 | source_sentences = src_path.read_text().splitlines() 23 | self.assertEqual(len(source_sentences), 4) 24 | print(source_sentences[0]) 25 | reference_sentences = ref_path.read_text().splitlines() 26 | self.assertEqual(len(reference_sentences), 4) 27 | print(reference_sentences[0]) 28 | -------------------------------------------------------------------------------- /experiments/reference_aggregation/tests/test_testset.py: 
-------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from experiments.reference_aggregation.experiment_utils import Testset 4 | 5 | 6 | class TestsetTestCase(TestCase): 7 | 8 | def setUp(self): 9 | self.testsets = ["wmt21", "wmt22"] 10 | self.language_pairs = ["en-de", "de-en", "en-ru", "ru-en"] 11 | 12 | def test_load_testsets(self): 13 | for testset in self.testsets: 14 | for language_pair in self.language_pairs: 15 | data = Testset.from_wmt(testset, language_pair) 16 | self.assertEqual(language_pair, f"{data.src_lang}-{data.tgt_lang}") 17 | self.assertEqual(len(data.source_sentences), len(data.references)) 18 | -------------------------------------------------------------------------------- /experiments/reference_aggregation/tests/test_validation.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | from unittest import TestCase 4 | 5 | import jsonlines 6 | 7 | logging.basicConfig(level=logging.INFO) 8 | 9 | 10 | class ValidationTestCase(TestCase): 11 | 12 | def setUp(self): 13 | self.testset = "wmt21" 14 | self.language_pair = "en-de" 15 | self.test_dir = Path(__file__).parent / "out" 16 | self.test_dir.mkdir(exist_ok=True) 17 | 18 | def test_run_validation_cometinho(self): 19 | from experiments.reference_aggregation.validation import main 20 | jsonl_path = main(self.testset, self.language_pair, seed_no=0, utility_name="cometinho", topk=4, num_samples=8, 21 | limit_segments=4, out_dir=self.test_dir) 22 | self.assertTrue(jsonl_path.exists()) 23 | with jsonlines.open(jsonl_path) as f: 24 | data = list(f) 25 | 26 | n_by_s_lines = [line for line in data if line["method"] == "n_by_s"] 27 | s_values = [line["s"] for line in n_by_s_lines] 28 | self.assertEqual([8, 4, 2, 1], s_values) 29 | 30 | aggregate_lines = [line for line in data if line["method"] == "aggregate"] 31 | s_values = [line["s"] for line in aggregate_lines] 32 | self.assertEqual([8, 4, 2, 1], s_values) 33 | 34 | for line in data: 35 | self.assertEqual(4, len(line["rankings"])) 36 | self.assertEqual(4, len(line["rankings"][0])) 37 | self.assertEqual(4, len(set(line["rankings"][0]))) 38 | 39 | test_translation_paths = [ 40 | self.test_dir / "translations" / f"validation.{self.testset}.{self.language_pair}.n8.epsilon0.02.seed0.cometinho.n_by_s.s8.{self.language_pair.split('-')[1]}", 41 | self.test_dir / "translations" / f"validation.{self.testset}.{self.language_pair}.n8.epsilon0.02.seed0.cometinho.n_by_s.s4.{self.language_pair.split('-')[1]}", 42 | self.test_dir / "translations" / f"validation.{self.testset}.{self.language_pair}.n8.epsilon0.02.seed0.cometinho.n_by_s.s2.{self.language_pair.split('-')[1]}", 43 | self.test_dir / "translations" / f"validation.{self.testset}.{self.language_pair}.n8.epsilon0.02.seed0.cometinho.n_by_s.s1.{self.language_pair.split('-')[1]}", 44 | self.test_dir / "translations" / f"validation.{self.testset}.{self.language_pair}.n8.epsilon0.02.seed0.cometinho.aggregate.s8.{self.language_pair.split('-')[1]}", 45 | self.test_dir / "translations" / f"validation.{self.testset}.{self.language_pair}.n8.epsilon0.02.seed0.cometinho.aggregate.s4.{self.language_pair.split('-')[1]}", 46 | self.test_dir / "translations" / f"validation.{self.testset}.{self.language_pair}.n8.epsilon0.02.seed0.cometinho.aggregate.s2.{self.language_pair.split('-')[1]}", 47 | self.test_dir / "translations" / 
f"validation.{self.testset}.{self.language_pair}.n8.epsilon0.02.seed0.cometinho.aggregate.s1.{self.language_pair.split('-')[1]}", ] 48 | for translation_path in test_translation_paths: 49 | self.assertTrue(translation_path.exists()) 50 | self.assertIn(self.test_dir, translation_path.parents) 51 | self.assertTrue(translation_path.name.endswith(".de")) 52 | translations = translation_path.read_text().splitlines() 53 | self.assertEqual(len(translations), 4) 54 | 55 | def test_run_validation_topk_first_identical(self): 56 | """ 57 | The top translation should be the same of any k 58 | """ 59 | from experiments.reference_aggregation.validation import main 60 | 61 | # cometinho 62 | top1_jsonl_path = main(self.testset, self.language_pair, seed_no=0, utility_name="cometinho", topk=1, 63 | num_samples=8, limit_segments=1, out_dir=self.test_dir) 64 | with jsonlines.open(top1_jsonl_path) as f: 65 | top1_best_indices = [line["rankings"][0][0] for line in f] 66 | for k in [1, 2, 4, 8]: 67 | jsonl_path = main(self.testset, self.language_pair, seed_no=0, utility_name="cometinho", topk=k, 68 | num_samples=8, limit_segments=1, out_dir=self.test_dir) 69 | with jsonlines.open(jsonl_path) as f: 70 | best_indices = [line["rankings"][0][0] for line in f] 71 | self.assertEqual(top1_best_indices, best_indices) 72 | 73 | # chrf 74 | top1_jsonl_path = main(self.testset, self.language_pair, seed_no=0, utility_name="chrf", topk=1, num_samples=8, 75 | limit_segments=1, out_dir=self.test_dir) 76 | with jsonlines.open(top1_jsonl_path) as f: 77 | top1_best_indices = [line["rankings"][0][0] for line in f] 78 | for k in [1, 2, 4, 8]: 79 | jsonl_path = main(self.testset, self.language_pair, seed_no=0, utility_name="chrf", topk=k, num_samples=8, 80 | limit_segments=1, out_dir=self.test_dir) 81 | with jsonlines.open(jsonl_path) as f: 82 | best_indices = [line["rankings"][0][0] for line in f] 83 | self.assertEqual(top1_best_indices, best_indices) 84 | 85 | def test_plot_accuracy(self): 86 | # Run validation.py 87 | from experiments.reference_aggregation.validation import main as validation 88 | jsonl_path = validation(self.testset, self.language_pair, seed_no=0, utility_name="cometinho", topk=8, 89 | num_samples=8, limit_segments=4, out_dir=self.test_dir) 90 | self.assertTrue(jsonl_path.exists()) 91 | 92 | # Top-8 93 | from experiments.reference_aggregation.plot_accuracy import main as plot_accuracy 94 | series_n_by_s_top8 = plot_accuracy(self.testset, self.language_pair, seed_no=0, fine_utility_name="cometinho", 95 | topk=8, method="n_by_s", num_samples=8, accuracy_topk=8, limit_segments=4, 96 | out_dir=self.test_dir) 97 | series_aggregate_top8 = plot_accuracy(self.testset, self.language_pair, seed_no=0, 98 | fine_utility_name="cometinho", topk=8, method="aggregate", num_samples=8, 99 | accuracy_topk=8, limit_segments=4, out_dir=self.test_dir) 100 | 101 | # Top-1 102 | series_n_by_s_top1 = plot_accuracy(self.testset, self.language_pair, seed_no=0, fine_utility_name="cometinho", 103 | topk=8, accuracy_topk=1, method="n_by_s", num_samples=8, limit_segments=4, 104 | out_dir=self.test_dir) 105 | series_aggregate_top1 = plot_accuracy(self.testset, self.language_pair, seed_no=0, 106 | fine_utility_name="cometinho", topk=8, accuracy_topk=1, 107 | method="aggregate", num_samples=8, limit_segments=4, 108 | out_dir=self.test_dir) 109 | 110 | # Assert that top-1 accuracy <= top-8 accuracy 111 | for (s, accuracy_top8), (_, accuracy_top1) in zip(series_n_by_s_top8, series_n_by_s_top1): 112 | self.assertLessEqual(accuracy_top1, 
accuracy_top8) 113 | for (s, accuracy_top8), (_, accuracy_top1) in zip(series_aggregate_top8, series_aggregate_top1): 114 | self.assertLessEqual(accuracy_top1, accuracy_top8) 115 | 116 | # Top-4 117 | series_n_by_s_top4 = plot_accuracy(self.testset, self.language_pair, seed_no=0, fine_utility_name="cometinho", 118 | topk=8, accuracy_topk=4, method="n_by_s", num_samples=8, limit_segments=4, 119 | out_dir=self.test_dir) 120 | series_aggregate_top4 = plot_accuracy(self.testset, self.language_pair, seed_no=0, 121 | fine_utility_name="cometinho", topk=8, accuracy_topk=4, 122 | method="aggregate", num_samples=8, limit_segments=4, 123 | out_dir=self.test_dir) 124 | 125 | # Assert that top-4 accuracy <= top-8 accuracy 126 | for (s, accuracy_top8), (_, accuracy_top4) in zip(series_n_by_s_top8, series_n_by_s_top4): 127 | self.assertLessEqual(accuracy_top4, accuracy_top8) 128 | for (s, accuracy_top8), (_, accuracy_top4) in zip(series_aggregate_top8, series_aggregate_top4): 129 | self.assertLessEqual(accuracy_top4, accuracy_top8) 130 | 131 | # Assert that top-1 accuracy <= top-4 accuracy 132 | for (s, accuracy_top4), (_, accuracy_top1) in zip(series_n_by_s_top4, series_n_by_s_top1): 133 | self.assertLessEqual(accuracy_top1, accuracy_top4) 134 | for (s, accuracy_top4), (_, accuracy_top1) in zip(series_aggregate_top4, series_aggregate_top1): 135 | self.assertLessEqual(accuracy_top1, accuracy_top4) 136 | -------------------------------------------------------------------------------- /experiments/reference_aggregation/validation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | from pathlib import Path 4 | from typing import List 5 | 6 | import jsonlines 7 | from tqdm import tqdm 8 | 9 | from experiments.reference_aggregation.experiment_utils import Testset 10 | from experiments.reference_aggregation.mbr_utils import load_utility 11 | 12 | 13 | def main(testset: str, language_pair: str, seed_no: int, utility_name: str, chrf_eps_smoothing: bool = False, 14 | topk: int = 20, num_samples: int = 1024, epsilon_cutoff: float = 0.02, limit_segments: int = None, 15 | out_dir: Path = None) -> Path: 16 | if out_dir is None: 17 | out_dir = Path(__file__).parent 18 | 19 | assert topk <= num_samples 20 | 21 | dataset = Testset.from_wmt(testset, language_pair, limit_segments=limit_segments) 22 | 23 | samples_dir = out_dir / "samples" 24 | assert samples_dir.exists() 25 | samples_path = samples_dir / f"samples.{dataset}.n{num_samples}.epsilon{epsilon_cutoff}.seed{seed_no}.jsonl" 26 | assert samples_path.exists() 27 | with jsonlines.open(samples_path) as f: 28 | samples = [line["samples"] for line in f] 29 | samples = [sample[:num_samples] for sample in samples] 30 | if limit_segments is not None: 31 | samples = samples[:limit_segments] 32 | 33 | assert len(samples) == len(dataset.source_sentences) 34 | assert all(len(sample) == num_samples for sample in samples) 35 | 36 | references = samples 37 | 38 | # s = n/1, n/2, n/4, n/8, ..., n/n 39 | s_values = [int(num_samples / 2 ** i) for i in range(int(math.log2(num_samples)) + 1)] 40 | assert s_values[0] == num_samples 41 | assert s_values[-1] == 1 42 | 43 | utility = load_utility(utility_name) 44 | 45 | if utility_name == "chrf" and chrf_eps_smoothing: 46 | utility.eps_smoothing = True 47 | 48 | # Compute rankings for n-by-s and aggregate, for each s 49 | n_by_s_rankings: List[List[List[int]]] = [] # segments x s_values x topk 50 | aggregate_rankings: List[List[List[int]]] = [] # segments x 
s_values x topk 51 | for i in tqdm(list(range(len(dataset.source_sentences))), desc="segments"): 52 | 53 | # For COMET: compute embeddings 54 | if hasattr(utility, "compute_features"): 55 | utility.clear_features() 56 | input_sequences = {dataset.source_sentences[i]} | set(samples[i]) | set(references[i]) 57 | utility.compute_features(input_sequences) 58 | 59 | n_by_s_rankings.append([]) 60 | for s in s_values: 61 | n_by_s_ranking = utility.rank_samples_n_by_s(dataset.source_sentences[i], samples[i], references[i], s=s) 62 | n_by_s_ranking = n_by_s_ranking[:topk] 63 | n_by_s_rankings[-1].append(n_by_s_ranking.tolist()) 64 | aggregate_rankings.append([]) 65 | for s in s_values: 66 | aggregate_ranking = utility.rank_samples_aggregate(dataset.source_sentences[i], samples[i], references[i], 67 | s=s) 68 | aggregate_ranking = aggregate_ranking[:topk] 69 | aggregate_rankings[-1].append(aggregate_ranking.tolist()) 70 | 71 | # Save top-k rankings to jsonl file 72 | output_dir = out_dir / "validation_output" 73 | output_dir.mkdir(exist_ok=True) 74 | output_path = output_dir / f"validation.{dataset}.n{num_samples}.epsilon{epsilon_cutoff}.seed{seed_no}.{utility_name}{'-eps' if chrf_eps_smoothing else ''}.top{topk}.jsonl" 75 | with jsonlines.open(output_path, mode="w") as f: 76 | for i, s in enumerate(s_values): 77 | f.write({"method": "n_by_s", "s": s, "rankings": [ranking[i] for ranking in n_by_s_rankings]}) 78 | for i, s in enumerate(s_values): 79 | f.write({"method": "aggregate", "s": s, "rankings": [ranking[i] for ranking in aggregate_rankings]}) 80 | 81 | translations_dir = out_dir / "translations" 82 | translations_dir.mkdir(exist_ok=True) 83 | translations_prefix = f"validation.{dataset}.n{num_samples}.epsilon{epsilon_cutoff}.seed{seed_no}.{utility_name}{'-eps' if chrf_eps_smoothing else ''}" 84 | 85 | # Save top-1 translations for n-by-s 86 | for j, s in enumerate(s_values): 87 | n_by_s_translations_path = translations_dir / f"{translations_prefix}.n_by_s.s{s}.{dataset.tgt_lang}" 88 | with open(n_by_s_translations_path, "w") as f: 89 | for i, rankings in enumerate(n_by_s_rankings): 90 | ranking = rankings[j] 91 | f.write(samples[i][ranking[0]] + "\n") 92 | 93 | # Save top-1 translations for aggregate 94 | for j, s in enumerate(s_values): 95 | aggregate_translations_path = translations_dir / f"{translations_prefix}.aggregate.s{s}.{dataset.tgt_lang}" 96 | with open(aggregate_translations_path, "w") as f: 97 | for i, rankings in enumerate(aggregate_rankings): 98 | ranking = rankings[j] 99 | f.write(samples[i][ranking[0]] + "\n") 100 | 101 | return output_path 102 | 103 | 104 | if __name__ == '__main__': 105 | parser = argparse.ArgumentParser() 106 | parser.add_argument('--testset', choices=['wmt21', 'wmt22'], required=True) 107 | parser.add_argument('--language-pair', choices=['de-en', 'en-de', 'en-ru', 'ru-en'], required=True) 108 | parser.add_argument('--seed', type=int, choices=range(10), required=True, 109 | help='Index of the random seed in the list of random seeds') 110 | parser.add_argument('--utility', choices=['chrf', 'cometinho', 'comet22'], required=True) 111 | parser.add_argument('--chrf-eps-smoothing', action='store_true', 112 | help='Use epsilon smoothing for ChrF (default: False = effective order smoothing)') 113 | parser.add_argument('--topk', type=int, default=20, help='Number of top translations to save in the jsonl file') 114 | parser.add_argument('--num-samples', type=int, default=1024) 115 | parser.add_argument('--epsilon-cutoff', type=float, default=0.02) 116 | 
parser.add_argument('--limit-segments', type=int, default=None, 117 | help='Limit number of segments that are processed (used for testing)') 118 | args = parser.parse_args() 119 | 120 | jsonl_path = main(testset=args.testset, language_pair=args.language_pair, seed_no=args.seed, 121 | utility_name=args.utility, chrf_eps_smoothing=args.chrf_eps_smoothing, topk=args.topk, 122 | num_samples=args.num_samples, epsilon_cutoff=args.epsilon_cutoff, 123 | limit_segments=args.limit_segments, ) 124 | print(f"Saved results file to {jsonl_path}") 125 | -------------------------------------------------------------------------------- /experiments/requirements.txt: -------------------------------------------------------------------------------- 1 | jsonlines==4.0.0 2 | datasets==2.14.6 3 | sacrebleu==2.3.1 4 | sacremoses==0.0.53 # For OpusMT 5 | nltk==3.8.1 6 | rouge_score==0.1.2 7 | -------------------------------------------------------------------------------- /minimum-bayes-risk-decoding.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/minimum-bayes-risk-decoding.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "mbr" 3 | version = "0.2.0" 4 | authors = [ 5 | { name="Jannis Vamvas", email="vamvas@cl.uzh.ch" }, 6 | ] 7 | description = "Minimum Bayes risk decoding for Hugging Face Transformers" 8 | readme = "README.md" 9 | requires-python = ">=3.9" 10 | dependencies = [ 11 | "transformers<4.39", 12 | "evaluate", 13 | "cachetools", 14 | "tqdm", 15 | "fastchrf", 16 | ] 17 | classifiers = [ 18 | "Programming Language :: Python :: 3", 19 | "License :: OSI Approved :: MIT License", 20 | "Operating System :: OS Independent", 21 | ] 22 | 23 | [project.urls] 24 | "Homepage" = "https://github.com/ZurichNLP/mbr" 25 | "Bug Tracker" = "https://github.com/ZurichNLP/mbr/issues" 26 | [build-system] 27 | requires = ["hatchling"] 28 | build-backend = "hatchling.build" 29 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | sacrebleu==2.4.0 2 | unbabel-comet==2.2.1 3 | git+https://github.com/google-research/bleurt.git 4 | sentencepiece==0.1.99 # M2M100 model 5 | -------------------------------------------------------------------------------- /requirements-test.txt: -------------------------------------------------------------------------------- 1 | sacrebleu==2.4.0 2 | unbabel-comet==2.2.1 3 | -------------------------------------------------------------------------------- /src/mbr/__init__.py: -------------------------------------------------------------------------------- 1 | from mbr.generation.configuration_utils import MBRConfig 2 | from mbr.generation.utils import MBROutput, MBRGenerationMixin 3 | from mbr.metrics.base import MetricOutput, MetricRunner 4 | from mbr.modeling import MBR 5 | 6 | 7 | __version__ = "0.2.0" 8 | -------------------------------------------------------------------------------- /src/mbr/generation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/src/mbr/generation/__init__.py 
-------------------------------------------------------------------------------- /src/mbr/generation/configuration_utils.py: -------------------------------------------------------------------------------- 1 | from transformers import __version__ as transformers_version 2 | from transformers.utils import logging 3 | 4 | logger = logging.get_logger(__name__) 5 | 6 | 7 | class MBRConfig: 8 | r""" 9 | Class that holds a configuration for minimum Bayes risk decoding (MBR). Pass this config when calling 10 | `MBRGenerationMixin.generate()`: 11 | 12 | Example: 13 | 14 | ```python 15 | >>> config = MBRConfig(num_samples=10, num_references=10, metric="fastchrf") 16 | >>> model.generate(..., mbr_config=config) 17 | ``` 18 | 19 | The class is inspired by `transformers.GenerationConfig`. 20 | Note that `MBRConfig` does not control the sampling strategy. Pass separate `GenerationConfig` objects to control 21 | sampling: 22 | 23 | ```python 24 | >>> generation_config = GenerationConfig(do_sample=True, num_beams=1, top_p=0.9) 25 | >>> references_config = GenerationConfig(do_sample=True, num_beams=1, epsilon_cutoff=0.02) 26 | >>> model.generate(..., mbr_config=config, generation_config=generation_config, references_config=references_config) 27 | ``` 28 | 29 | Arg: 30 | num_samples (`int`, *optional*, defaults to 10): 31 | Number of samples generated. 1 means no MBR decoding. 32 | num_references (`int`, *optional*, defaults to `num_samples`): 33 | Number of pseudo-references used for MBR decoding. 34 | metric (`str` or `~evaluate.Metric`, *optional*, defaults to 'fastchrf'): 35 | Metric used for MBR decoding. 36 | metric_config_name (`str`, *optional*, defaults to None): 37 | Metric configuration to pass to `evaluate.load` (e.g., the model for a trained metric, such as 38 | "eamt22-cometinho-da"). If not specified, the default configuration is used. 39 | metric_output_field (`str`, *optional*, defaults to 'score'): 40 | Field of the metric output that is used 41 | metric_kwargs (optional): 42 | Additional arguments for the metric's `compute` method. The default MetricRunner requires it to be hashable. 43 | metric_cache_size (`int`, *optional*, defaults to `num_samples` * `num_references`): 44 | Size of the cache for the metric. Set to `None` to disable caching (not recommended). 45 | lower_is_better (`bool`, *optional*, defaults to `False`): 46 | Set to true if lower metric scores are better (e.g., perplexity). 47 | 48 | > Parameters that define the output variables of `generate` 49 | 50 | output_attentions (`bool`, *optional*, defaults to `False`): 51 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 52 | tensors for more details. 53 | output_hidden_states (`bool`, *optional*, defaults to `False`): 54 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 55 | more details. 56 | output_scores (`bool`, *optional*, defaults to `False`): 57 | Whether or not to return the prediction scores. See `scores` under returned tensors for more details. 58 | output_all_samples (`bool`, *optional*, defaults to `False`): 59 | Whether or not to return all sampled sequences. See `all_sampled_sequences` under returned tensors for more 60 | details. 61 | output_reference_sequences (`bool`, *optional*, defaults to `False`): 62 | Whether or not to return the reference sequences. See `reference_sequences` under returned tensors for more 63 | details. 
64 | output_metric_scores (`bool`, *optional*, defaults to `False`): 65 | Whether or not to return the metric scores. See `metric_scores` under returned tensors for more details. 66 | return_dict_in_generate (`bool`, *optional*, defaults to `False`): 67 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 68 | """ 69 | 70 | def __init__(self, **kwargs): 71 | # Parameters that control the generation strategy used 72 | self.num_samples = kwargs.pop("num_samples", 10) 73 | self.num_references = kwargs.pop("num_references", self.num_samples) 74 | self.metric = kwargs.pop("metric", "fastchrf") 75 | self.metric_config_name = kwargs.pop("metric_config_name", None) 76 | self.metric_output_field = kwargs.pop("metric_output_field", "score") 77 | self.metric_kwargs = kwargs.pop("metric_kwargs", {}) 78 | self.metric_cache_size = kwargs.pop("metric_cache_size", self.num_samples * self.num_references) 79 | self.lower_is_better = kwargs.pop("lower_is_better", False) 80 | 81 | # Parameters that define the output variables of `generate` 82 | self.output_attentions = kwargs.pop("output_attentions", False) 83 | self.output_hidden_states = kwargs.pop("output_hidden_states", False) 84 | self.output_scores = kwargs.pop("output_scores", False) 85 | self.output_all_samples = kwargs.pop("output_all_samples", False) 86 | self.output_reference_sequences = kwargs.pop("output_reference_sequences", False) 87 | self.output_metric_scores = kwargs.pop("output_metric_scores", False) 88 | self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False) 89 | 90 | # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub 91 | # interface. 92 | self._from_model_config = kwargs.pop("_from_model_config", False) 93 | self._commit_hash = kwargs.pop("_commit_hash", None) 94 | self.transformers_version = kwargs.pop("transformers_version", transformers_version) 95 | import mbr 96 | self.mbr_version = kwargs.pop("mbr_version", mbr.__version__) 97 | 98 | # Additional attributes without default values 99 | if not self._from_model_config: 100 | # we don't want to copy values from the model config if we're initializing an `MBRConfig` from a 101 | # model's default configuration file 102 | for key, value in kwargs.items(): 103 | try: 104 | setattr(self, key, value) 105 | except AttributeError as err: 106 | logger.error(f"Can't set {key} with value {value} for {self}") 107 | raise err 108 | 109 | # Validate the values of the attributes 110 | self.validate(is_init=True) 111 | 112 | def validate(self, is_init=False): 113 | """ 114 | Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence 115 | of parameterization that can be detected as incorrect from the configuration instance alone. 116 | 117 | Note that some parameters are best validated at generate runtime, as they may depend on other inputs and/or the 118 | model, such as parameters related to the generation length. 
119 | """ 120 | if self.metric_cache_size <= 0: 121 | raise ValueError(f"`metric_cache_size` ({self.metric_cache_size}) must be greater than 0.") 122 | -------------------------------------------------------------------------------- /src/mbr/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from mbr import MBRConfig 2 | from mbr.metrics.base import metric_is_source_based, MetricRunner 3 | 4 | 5 | def load_metric_runner(mbr_config: MBRConfig, tokenizer=None) -> MetricRunner: 6 | if mbr_config.metric in {"fastchrf", "aggregate_chrf", "fastchrf.aggregate_chrf"}: 7 | from mbr.metrics.fastchrf import FastChrfMetricRunner 8 | return FastChrfMetricRunner(mbr_config, tokenizer, compute_pairwise_average=False) 9 | elif mbr_config.metric in {"pairwise_chrf", "fastchrf.pairwise_chrf"}: 10 | from mbr.metrics.fastchrf import FastChrfMetricRunner 11 | return FastChrfMetricRunner(mbr_config, tokenizer, compute_pairwise_average=True) 12 | elif mbr_config.metric == "comet": 13 | from mbr.metrics.comet import CometMetricRunner 14 | return CometMetricRunner(mbr_config, tokenizer, 15 | device=0, 16 | batch_size_embed=64, 17 | batch_size_estimate=64, 18 | progress_bar=True, 19 | ) 20 | elif mbr_config.metric == "aggregate_comet": 21 | from mbr.metrics.comet import AggregateCometMetricRunner 22 | mbr_config.metric = "comet" 23 | return AggregateCometMetricRunner(mbr_config, tokenizer, 24 | device=0, 25 | batch_size_embed=64, 26 | batch_size_estimate=64, 27 | progress_bar=True, 28 | ) 29 | else: 30 | return MetricRunner(mbr_config, tokenizer) 31 | -------------------------------------------------------------------------------- /src/mbr/metrics/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Tuple, Union, List, Optional 3 | 4 | import evaluate 5 | import torch 6 | from cachetools.func import fifo_cache 7 | from datasets import Metric 8 | from evaluate import EvaluationModule 9 | from transformers import PreTrainedTokenizerBase 10 | from transformers.utils import ModelOutput 11 | 12 | from mbr import MBRConfig 13 | 14 | MetricType = Union[Metric, EvaluationModule] 15 | 16 | 17 | @dataclass 18 | class MetricOutput(ModelOutput): 19 | """ 20 | Args: 21 | scores (`torch.FloatTensor` of shape `(batch_size, num_samples)`): 22 | The metric scores for each sample (aggregated over all references). 23 | scores_per_reference (`torch.FloatTensor` of shape `(batch_size, num_samples, num_references)`): 24 | The pairwise metric scores for each sample and reference. `None` if the metric is computed corpus-level. 25 | """ 26 | scores: torch.FloatTensor 27 | scores_per_reference: Optional[torch.FloatTensor] = None 28 | 29 | 30 | class MetricRunner: 31 | """ 32 | Applies the metric to samples and references (and optionally inputs) and calculates a metric score for each sample. 33 | This implementation uses the most basic approach, where samples and references are compared pairwise. 34 | Some metrics may support multi-reference evaluation or batching. Consider creating a subclass to make use of these 35 | features. 
36 | """ 37 | 38 | def __init__(self, mbr_config: MBRConfig, tokenizer: PreTrainedTokenizerBase): 39 | self.mbr_config = mbr_config 40 | # Ensure that mbr_config.metric_kwargs is hashable (because _compute_metric() uses lru_cache) 41 | if mbr_config.metric_kwargs: 42 | try: 43 | hash(tuple(self.mbr_config.metric_kwargs)) 44 | except TypeError as e: 45 | raise TypeError(f"mbr_config.metric_kwargs must be hashable.") from e 46 | self.tokenizer = tokenizer 47 | self.metric = self._load_metric() 48 | self.metric_is_source_based = metric_is_source_based(self.metric) 49 | self._compute_metric_cached = fifo_cache(maxsize=self.mbr_config.metric_cache_size)(self._compute_metric) 50 | 51 | def _load_metric(self) -> MetricType: 52 | metric = self.mbr_config.metric 53 | if isinstance(metric, EvaluationModule): 54 | return metric 55 | elif isinstance(metric, str): 56 | metric = evaluate.load(metric, self.mbr_config.metric_config_name) 57 | else: 58 | raise ValueError(f"Invalid metric type: {type(metric)}") 59 | if metric.name == "comet": 60 | metric.scorer.eval() 61 | return metric 62 | 63 | def __call__(self, 64 | input_ids: torch.LongTensor, 65 | sample_ids: Tuple[torch.LongTensor], 66 | reference_ids: Tuple[torch.LongTensor], 67 | ) -> MetricOutput: 68 | r""" 69 | Args: 70 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 71 | The input sequence ids. 72 | sample_ids (`tuple(torch.LongTensor)`): 73 | Tuple (one element for `num_samples`) of tensors of shape `(batch_size, sequence_length)` containing 74 | the sampled sequences. 75 | reference_ids: 76 | Tuple (one element for `num_references`) of tensors of shape `(batch_size, sequence_length)` containing 77 | the reference sequences. 78 | 79 | Returns: 80 | `MetricOutput` containing the metric scores. 
81 | """ 82 | 83 | # Detokenize 84 | if self.metric_is_source_based: 85 | str_inputs = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True) # shape: (batch_size,) 86 | else: 87 | str_inputs = None 88 | str_samples = [] # num_samples x batch_size 89 | for sample in sample_ids: 90 | str_samples.append(self.tokenizer.batch_decode(sample, skip_special_tokens=True)) 91 | str_references = [] # num_references x batch_size 92 | for reference in reference_ids: 93 | str_references.append(self.tokenizer.batch_decode(reference, skip_special_tokens=True)) 94 | 95 | if len(str_samples[0]) != len(str_references[0]): 96 | raise ValueError("Batch size of samples and references must match") 97 | if len(str_samples) != self.mbr_config.num_samples: 98 | raise ValueError("Number of samples must match `mbr_config.num_samples`") 99 | if len(str_references) != self.mbr_config.num_references: 100 | raise ValueError("Number of references must match `mbr_config.num_references`") 101 | 102 | # Compute metric 103 | scores_per_reference = self._compute_str_metric(str_samples, str_references, str_inputs) 104 | 105 | return MetricOutput( 106 | scores=scores_per_reference.mean(dim=-1), 107 | scores_per_reference=scores_per_reference, 108 | ) 109 | 110 | def _compute_str_metric(self, 111 | samples: List[List[str]], 112 | references: List[List[str]], 113 | inputs: List[str] = None, 114 | ) -> torch.FloatTensor: 115 | batch_size = len(samples[0]) 116 | metric_scores = torch.zeros((batch_size, len(samples), len(references))) 117 | for i in range(batch_size): 118 | for j in range(len(samples)): 119 | sample = samples[j][i] 120 | for k in range(len(references)): 121 | reference = references[k][i] 122 | if inputs is not None: 123 | score = self.compute_metric( 124 | sources=(inputs[i],), 125 | predictions=(sample,), 126 | references=(reference,), 127 | **self.mbr_config.metric_kwargs, 128 | ) 129 | else: 130 | score = self.compute_metric( 131 | predictions=(sample,), 132 | references=(reference,), 133 | **self.mbr_config.metric_kwargs, 134 | ) 135 | metric_scores[i, j, k] = score 136 | return metric_scores 137 | 138 | def _compute_metric(self, *args, **kwargs) -> float: 139 | # Call _compute() instead of compute() for performance reasons. 140 | # Since we are comparing individual samples, we do not need the overhead of compute(). 141 | output = self.metric._compute(*args, **kwargs) 142 | if self.mbr_config.metric_output_field not in output: 143 | raise ValueError(f"Metric output does not contain '{self.mbr_config.metric_output_field}' " 144 | f"Use `mbr_config.metric_output_field` to specify the correct field. 
" 145 | f"Available fields: {list(output.keys())}" 146 | ) 147 | score = output[self.mbr_config.metric_output_field] 148 | if isinstance(score, list): 149 | score = score[0] 150 | return score 151 | 152 | def compute_metric(self, *args, **kwargs) -> float: 153 | return self._compute_metric_cached(*args, **kwargs) 154 | 155 | 156 | def metric_is_source_based(metric: MetricType) -> bool: 157 | return "sources" in metric.features 158 | -------------------------------------------------------------------------------- /src/mbr/metrics/comet.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from typing import List, Tuple, Dict, Set 3 | 4 | import torch 5 | from cachetools import FIFOCache 6 | from comet.models import RegressionMetric 7 | from tqdm import tqdm 8 | 9 | from mbr import MetricRunner, MetricOutput 10 | 11 | 12 | class CometMetricRunner(MetricRunner): 13 | """ 14 | Efficient usage of COMET for MBR, based on https://github.com/Unbabel/COMET 15 | 16 | The implementation is inspired by https://github.com/chanberg/COMET-mbr and 17 | https://github.com/Unbabel/COMET/blob/master/comet/cli/mbr.py 18 | """ 19 | 20 | def __init__(self, 21 | *args, 22 | device=None, 23 | batch_size_embed: int = 1, 24 | batch_size_estimate: int = 1, 25 | progress_bar: bool = False, 26 | **kwargs, 27 | ): 28 | super().__init__(*args, **kwargs) 29 | if self.metric.__class__.__name__ != "COMET": 30 | raise ValueError( 31 | f"CometMetricRunner expects an evaluate.COMET metric, got {self.metric.__class__.__name__}") 32 | self.comet = self.metric 33 | if self.mbr_config.metric_output_field not in ["mean_score", "scores"]: 34 | raise ValueError(f"CometMetricRunner expects metric_output_field to be 'mean_score' or 'scores', " 35 | f"got {self.mbr_config.metric_output_field}") 36 | if self.mbr_config.metric_kwargs: 37 | raise NotImplementedError("CometMetricRunner does not support metric_kwargs") 38 | if not isinstance(self.comet.scorer, RegressionMetric): 39 | raise NotImplementedError("CometMetricRunner only supports COMET models that are an instance of " 40 | "comet.models.RegressionMetric") 41 | if device is not None: 42 | self.comet.scorer = self.comet.scorer.to(device) 43 | self.comet.scorer.eval() 44 | self.batch_size_embed = batch_size_embed 45 | self.batch_size_estimate = batch_size_estimate 46 | self.progress_bar = progress_bar 47 | # We use a key-value cache, which is needed if the metric is called multiple times with similar inputs 48 | # (e.g. for MBR with iterative pruning). 
49 | self.embedding_cache = FIFOCache(maxsize=self.mbr_config.metric_cache_size) 50 | self.score_cache = FIFOCache(maxsize=self.mbr_config.metric_cache_size) 51 | 52 | @torch.no_grad() 53 | def _compute_str_metric(self, 54 | samples: List[List[str]], 55 | references: List[List[str]], 56 | inputs: List[str] = None, 57 | ) -> torch.FloatTensor: 58 | if inputs is None: 59 | raise NotImplementedError("CometMetricRunner requires source sequences (`inputs`) to be provided") 60 | batch_size = len(samples[0]) 61 | metric_scores = torch.zeros((batch_size, len(samples), len(references))) 62 | for i in tqdm(list(range(batch_size)), desc="comet", disable=not self.progress_bar): 63 | # Embed all sequences 64 | all_samples = [sample[i] for sample in samples] 65 | all_references = [reference[i] for reference in references] 66 | all_sequences = set(all_samples + all_references + inputs) 67 | 68 | all_embeddings: Dict[str, torch.FloatTensor] = {} 69 | # Populate embeddings from cache 70 | for sequence in list(all_sequences): 71 | if sequence in self.embedding_cache: 72 | all_embeddings[sequence] = self.embedding_cache[sequence] 73 | all_sequences.remove(sequence) 74 | 75 | # Compute embeddings for remaining sequences 76 | if all_sequences: 77 | all_sequences = list(all_sequences) 78 | encodings = self.comet.scorer.encoder.prepare_sample(all_sequences).to(self.comet.scorer.device) 79 | batches = itertools.zip_longest(range(0, len(all_sequences), self.batch_size_embed), 80 | range(self.batch_size_embed, len(all_sequences), self.batch_size_embed)) 81 | for start_idx, end_idx in batches: 82 | embeddings = self.comet.scorer.get_sentence_embedding( 83 | input_ids=encodings["input_ids"][start_idx:end_idx], 84 | attention_mask=encodings["attention_mask"][start_idx:end_idx], 85 | ) 86 | for j in range(start_idx, end_idx if end_idx is not None else len(all_sequences)): 87 | embedding = embeddings[j - start_idx] 88 | all_embeddings[all_sequences[j]] = embedding 89 | self.embedding_cache[all_sequences[j]] = embedding 90 | 91 | # Collect all input triples in a list 92 | input_triples: Set[Tuple[str, str, str]] = set() 93 | for j in range(len(samples)): 94 | for k in range(len(references)): 95 | input_triples.add((inputs[i], samples[j][i], references[k][i])) 96 | 97 | input_triple_scores: Dict[Tuple[str, str, str], torch.FloatTensor] = {} 98 | # Populate scores from cache 99 | for triple in list(input_triples): 100 | if triple in self.score_cache: 101 | input_triple_scores[triple] = self.score_cache[triple] 102 | input_triples.remove(triple) 103 | 104 | # Compute scores for remaining input triples 105 | input_triples: List = list(input_triples) 106 | batches = itertools.zip_longest(range(0, len(input_triples), self.batch_size_estimate), 107 | range(self.batch_size_estimate, len(input_triples), 108 | self.batch_size_estimate)) 109 | for start_idx, end_idx in batches: 110 | batch = input_triples[start_idx:end_idx] 111 | batch_scores = self.comet.scorer.estimate( 112 | src_sentemb=torch.stack([all_embeddings[triple[0]] for triple in batch]), 113 | mt_sentemb=torch.stack([all_embeddings[triple[1]] for triple in batch]), 114 | ref_sentemb=torch.stack([all_embeddings[triple[2]] for triple in batch]), 115 | ) 116 | for j in range(start_idx, end_idx if end_idx is not None else len(input_triples)): 117 | triple = batch[j - start_idx] 118 | score = batch_scores.score[j - start_idx] 119 | input_triple_scores[triple] = score 120 | self.score_cache[triple] = score 121 | 122 | for j in range(len(samples)): 123 | for k in 
range(len(references)): 124 | metric_scores[i, j, k] = input_triple_scores[(inputs[i], samples[j][i], references[k][i])] 125 | 126 | return metric_scores 127 | 128 | 129 | class AggregateCometMetricRunner(CometMetricRunner): 130 | """ 131 | Implements reference aggregation as described in "Linear-time Minimum Bayes Risk Decoding with Reference Aggregation" 132 | (Vamvas & Sennrich, 2024) https://arxiv.org/abs/2402.04251 133 | """ 134 | 135 | def __call__(self, 136 | input_ids: torch.LongTensor, 137 | sample_ids: Tuple[torch.LongTensor], 138 | reference_ids: Tuple[torch.LongTensor], 139 | ) -> MetricOutput: 140 | r""" 141 | Args: 142 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 143 | The input sequence ids. 144 | sample_ids (`tuple(torch.LongTensor)`): 145 | Tuple (one element for `num_samples`) of tensors of shape `(batch_size, sequence_length)` containing 146 | the sampled sequences. 147 | reference_ids: 148 | Tuple (one element for `num_references`) of tensors of shape `(batch_size, sequence_length)` containing 149 | the reference sequences. 150 | 151 | Returns: 152 | `MetricOutput` containing the metric scores. 153 | """ 154 | 155 | # Detokenize 156 | if self.metric_is_source_based: 157 | str_inputs = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True) # shape: (batch_size,) 158 | else: 159 | str_inputs = None 160 | str_samples = [] # num_samples x batch_size 161 | for sample in sample_ids: 162 | str_samples.append(self.tokenizer.batch_decode(sample, skip_special_tokens=True)) 163 | str_references = [] # num_references x batch_size 164 | for reference in reference_ids: 165 | str_references.append(self.tokenizer.batch_decode(reference, skip_special_tokens=True)) 166 | 167 | if len(str_samples[0]) != len(str_references[0]): 168 | raise ValueError("Batch size of samples and references must match") 169 | 170 | # Compute metric 171 | scores = self._compute_str_metric(str_samples, str_references, str_inputs) 172 | 173 | return MetricOutput( 174 | scores=scores, 175 | scores_per_reference=None, 176 | ) 177 | 178 | @torch.no_grad() 179 | def _compute_str_metric(self, 180 | samples: List[List[str]], 181 | references: List[List[str]], 182 | inputs: List[str] = None, 183 | ) -> torch.FloatTensor: 184 | if inputs is None: 185 | raise NotImplementedError("CometMetricRunner requires source sequences (`inputs`) to be provided") 186 | batch_size = len(samples[0]) 187 | metric_scores = torch.zeros((batch_size, len(samples))) 188 | for i in tqdm(list(range(batch_size)), desc="comet", disable=not self.progress_bar): 189 | # Embed all sequences 190 | all_samples = [sample[i] for sample in samples] 191 | all_references = [reference[i] for reference in references] 192 | all_sequences = set(all_samples + all_references + inputs) 193 | 194 | all_embeddings: Dict[str, torch.FloatTensor] = {} 195 | # Populate embeddings from cache 196 | for sequence in list(all_sequences): 197 | if sequence in self.embedding_cache: 198 | all_embeddings[sequence] = self.embedding_cache[sequence] 199 | all_sequences.remove(sequence) 200 | 201 | # Compute embeddings for remaining sequences 202 | if all_sequences: 203 | all_sequences = list(all_sequences) 204 | encodings = self.comet.scorer.encoder.prepare_sample(all_sequences).to(self.comet.scorer.device) 205 | batches = itertools.zip_longest(range(0, len(all_sequences), self.batch_size_embed), 206 | range(self.batch_size_embed, len(all_sequences), self.batch_size_embed)) 207 | for start_idx, end_idx in batches: 208 | embeddings = 
self.comet.scorer.get_sentence_embedding( 209 | input_ids=encodings["input_ids"][start_idx:end_idx], 210 | attention_mask=encodings["attention_mask"][start_idx:end_idx], 211 | ) 212 | for j in range(start_idx, end_idx if end_idx is not None else len(all_sequences)): 213 | embedding = embeddings[j - start_idx] 214 | all_embeddings[all_sequences[j]] = embedding 215 | self.embedding_cache[all_sequences[j]] = embedding 216 | 217 | # Compute average reference embedding 218 | avg_reference_embedding = torch.stack([all_embeddings[reference] for reference in all_references]).mean(dim=0) 219 | 220 | # Collect all input triples in a list 221 | input_triples: Set[Tuple[str, str, str]] = set() 222 | for j in range(len(samples)): 223 | input_triples.add((inputs[i], samples[j][i], "avg")) 224 | 225 | input_triple_scores: Dict[Tuple[str, str, str], torch.FloatTensor] = {} 226 | # Populate scores from cache 227 | for triple in list(input_triples): 228 | if triple in self.score_cache: 229 | input_triple_scores[triple] = self.score_cache[triple] 230 | input_triples.remove(triple) 231 | 232 | # Compute scores for remaining input triples 233 | input_triples: List = list(input_triples) 234 | batches = itertools.zip_longest(range(0, len(input_triples), self.batch_size_estimate), 235 | range(self.batch_size_estimate, len(input_triples), 236 | self.batch_size_estimate)) 237 | for start_idx, end_idx in batches: 238 | batch = input_triples[start_idx:end_idx] 239 | batch_scores = self.comet.scorer.estimate( 240 | src_sentemb=torch.stack([all_embeddings[triple[0]] for triple in batch]), 241 | mt_sentemb=torch.stack([all_embeddings[triple[1]] for triple in batch]), 242 | ref_sentemb=avg_reference_embedding.unsqueeze(0).repeat(len(batch), 1), 243 | ) 244 | for j in range(start_idx, end_idx if end_idx is not None else len(input_triples)): 245 | triple = batch[j - start_idx] 246 | score = batch_scores.score[j - start_idx] 247 | input_triple_scores[triple] = score 248 | self.score_cache[triple] = score 249 | 250 | for j in range(len(samples)): 251 | metric_scores[i, j] = input_triple_scores[(inputs[i], samples[j][i], "avg")] 252 | 253 | return metric_scores 254 | -------------------------------------------------------------------------------- /src/mbr/metrics/fastchrf.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import torch 4 | from fastchrf import pairwise_chrf, aggregate_chrf 5 | from transformers import PreTrainedTokenizerBase 6 | 7 | from mbr import MetricRunner, MBRConfig, MetricOutput 8 | 9 | 10 | class FastChrfMetricRunner(MetricRunner): 11 | """ 12 | MetricRunner for fastChrF. See https://github.com/jvamvas/fastChrF for more information. 13 | 14 | Args: 15 | mbr_config 16 | tokenizer 17 | compute_pairwise_average: Default: False. If True, use fastchr.chrf_pairwise() to calculate exact ChrF scores 18 | for each sample-reference pair and then average them; this corresponds to a fast implementation of the 19 | original ChrF metric. If False, use fastchr.chrf_aggregate() to directly calculate aggregate fastChrF scores 20 | across all references; note that the result will be different from the original ChrF metric. 
21 | """ 22 | 23 | def __init__(self, 24 | mbr_config: MBRConfig, 25 | tokenizer: PreTrainedTokenizerBase, 26 | compute_pairwise_average: bool = False, 27 | ): 28 | self.mbr_config = mbr_config 29 | self.tokenizer = tokenizer 30 | self.metric_is_source_based = False 31 | self.char_order = mbr_config.metric_kwargs.get("char_order", 6) 32 | self.beta = mbr_config.metric_kwargs.get("beta", 2) 33 | self.remove_whitespace = mbr_config.metric_kwargs.get("remove_whitespace", True) 34 | self.eps_smoothing = mbr_config.metric_kwargs.get("eps_smoothing", False) 35 | self.compute_pairwise_average = compute_pairwise_average 36 | 37 | def __call__(self, 38 | input_ids: torch.LongTensor, 39 | sample_ids: Tuple[torch.LongTensor], 40 | reference_ids: Tuple[torch.LongTensor], 41 | ) -> MetricOutput: 42 | r""" 43 | Args: 44 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 45 | The input sequence ids. 46 | sample_ids (`tuple(torch.LongTensor)`): 47 | Tuple (one element for `num_samples`) of tensors of shape `(batch_size, sequence_length)` containing 48 | the sampled sequences. 49 | reference_ids: 50 | Tuple (one element for `num_references`) of tensors of shape `(batch_size, sequence_length)` containing 51 | the reference sequences. 52 | 53 | Returns: 54 | `MetricOutput` containing the metric scores. 55 | """ 56 | 57 | # Detokenize 58 | str_samples = [] # num_samples x batch_size 59 | for sample in sample_ids: 60 | str_samples.append(self.tokenizer.batch_decode(sample, skip_special_tokens=True)) 61 | str_references = [] # num_references x batch_size 62 | for reference in reference_ids: 63 | str_references.append(self.tokenizer.batch_decode(reference, skip_special_tokens=True)) 64 | 65 | if len(str_samples[0]) != len(str_references[0]): 66 | raise ValueError("Batch size of samples and references must match") 67 | if len(str_samples) != self.mbr_config.num_samples: 68 | raise ValueError("Number of samples must match `mbr_config.num_samples`") 69 | if len(str_references) != self.mbr_config.num_references: 70 | raise ValueError("Number of references must match `mbr_config.num_references`") 71 | 72 | # Transpose to batch_size x num_samples/num_references 73 | str_samples = list(zip(*str_samples)) 74 | str_references = list(zip(*str_references)) 75 | 76 | if self.compute_pairwise_average: 77 | output = self._compute_pairwise_chrf(str_samples, str_references) 78 | else: 79 | output = self._compute_aggregate_chrf(str_samples, str_references) 80 | return output 81 | 82 | def _compute_pairwise_chrf(self, samples: List[List[str]], references: List[List[str]]) -> MetricOutput: 83 | scores_per_reference = pairwise_chrf( 84 | samples, 85 | references, 86 | char_order=self.char_order, 87 | beta=self.beta, 88 | remove_whitespace=self.remove_whitespace, 89 | eps_smoothing=self.eps_smoothing, 90 | ) 91 | scores_per_reference = torch.tensor(scores_per_reference) 92 | scores = scores_per_reference.mean(dim=-1) 93 | return MetricOutput( 94 | scores=scores, 95 | scores_per_reference=scores_per_reference, 96 | ) 97 | 98 | def _compute_aggregate_chrf(self, samples: List[List[str]], references: List[List[str]]) -> MetricOutput: 99 | scores = aggregate_chrf( 100 | samples, 101 | references, 102 | char_order=self.char_order, 103 | beta=self.beta, 104 | remove_whitespace=self.remove_whitespace, 105 | eps_smoothing=self.eps_smoothing, 106 | ) 107 | scores = torch.tensor(scores) 108 | return MetricOutput( 109 | scores=scores, 110 | scores_per_reference=None, 111 | ) 112 | 
-------------------------------------------------------------------------------- /src/mbr/modeling.py: -------------------------------------------------------------------------------- 1 | from transformers import GenerationMixin 2 | 3 | from mbr.generation.utils import MBRGenerationMixin 4 | 5 | 6 | def MBR(model_class: type) -> type: 7 | """ 8 | Utility function for converting a model class into a class that inherits from `~generation.MBRGenerationMixin`. 9 | """ 10 | if not issubclass(model_class, GenerationMixin): 11 | raise ValueError( 12 | f"MBR() can only be applied to classes that inherit from `transformers.GenerationMixin`, " 13 | f"but got {model_class}." 14 | ) 15 | return type("MBR" + model_class.__name__, (MBRGenerationMixin, model_class), {}) 16 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZurichNLP/mbr/84b4c6b0d2fa2974d0a717f5729e02612f2e9bcd/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_config.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from mbr import MBRConfig 4 | 5 | 6 | class MBRConfigTestCase(TestCase): 7 | 8 | def test_default_config(self): 9 | config = MBRConfig() 10 | self.assertEqual(config.num_samples, 10) 11 | self.assertEqual(config.num_references, 10) 12 | -------------------------------------------------------------------------------- /tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | from unittest import TestCase 4 | 5 | import evaluate 6 | import torch 7 | from transformers import AutoTokenizer 8 | 9 | from mbr import MetricRunner, MBRConfig 10 | from mbr.metrics import metric_is_source_based 11 | 12 | 13 | class MetricUtilsTestCase(TestCase): 14 | 15 | def setUp(self): 16 | self.mbr_config = MBRConfig( 17 | metric="chrf", 18 | metric_output_field="score", 19 | num_samples=3, 20 | num_references=2, 21 | ) 22 | self.tokenizer = AutoTokenizer.from_pretrained("distilgpt2") 23 | self.tokenizer.pad_token = self.tokenizer.eos_token 24 | self.metric_runner = MetricRunner(self.mbr_config, self.tokenizer) 25 | self.inputs = [ # shape: (batch_size,) 26 | "This is an input sentence.", 27 | "This is another input sentence.", 28 | ] 29 | self.samples = [ # num_samples x batch_size 30 | ["This is a sample sentence.", "Something totally different."], 31 | ["This is a sample sentence.", "This a third sample sentence."], 32 | ["Something totally different.", "This is a fourth sample sentence."], 33 | ] 34 | self.references = [ # num_references x batch_size 35 | ["This is a reference sentence.", "This is another reference sentence."], 36 | ["This is a reference sentence.", "This is a fourth reference sentence."], 37 | ] 38 | self.input_ids = self.tokenizer(self.inputs, return_tensors="pt", padding=True).input_ids 39 | self.sample_ids = tuple([self.tokenizer(sample, return_tensors="pt", padding=True).input_ids for sample in self.samples]) 40 | self.reference_ids = tuple([self.tokenizer(reference, return_tensors="pt", padding=True).input_ids for reference in self.references]) 41 | 42 | def test_is_source_based__chrf(self): 43 | chrf = evaluate.load("chrf") 44 | self.assertFalse(metric_is_source_based(chrf)) 45 | 46 | def test_is_source_based__comet(self): 47 | comet = 
evaluate.load("comet", "eamt22-cometinho-da") 48 | self.assertTrue(metric_is_source_based(comet)) 49 | 50 | @unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies") 51 | def test_is_source_based__bleurt(self): 52 | bleurt = evaluate.load("bleurt") 53 | self.assertFalse(metric_is_source_based(bleurt)) 54 | 55 | def test_load_metric(self): 56 | self.mbr_config.metric = "chrf" 57 | metric = self.metric_runner._load_metric() 58 | self.assertIsInstance(metric, evaluate.Metric) 59 | self.assertEqual(metric.name, "chr_f") 60 | self.mbr_config.metric = evaluate.load("chrf") 61 | metric = self.metric_runner._load_metric() 62 | self.assertIsInstance(metric, evaluate.Metric) 63 | self.assertEqual(metric.name, "chr_f") 64 | 65 | def test_metric_config_name(self): 66 | self.mbr_config.metric = "comet" 67 | self.mbr_config.metric_config_name = "eamt22-cometinho-da" 68 | self.mbr_config.metric_output_field = "mean_score" 69 | metric = self.metric_runner._load_metric() 70 | self.assertIsInstance(metric, evaluate.Metric) 71 | self.assertEqual(metric.name, "comet") 72 | # Test custom metric_config_name 73 | self.assertEqual(metric.scorer.encoder.__class__.__name__, "MiniLMEncoder") 74 | 75 | def test_compute_metric__chrf(self): 76 | metric_output = self.metric_runner(self.input_ids, self.sample_ids, self.reference_ids) 77 | self.assertTrue(torch.is_floating_point(metric_output.scores)) 78 | self.assertTrue(torch.is_floating_point(metric_output.scores_per_reference)) 79 | torch.testing.assert_close(metric_output.scores_per_reference.mean(dim=-1), metric_output.scores) 80 | self.assertEqual(metric_output.scores.shape, (2, 3)) # batch_size x num_samples 81 | self.assertEqual(metric_output.scores_per_reference.shape, (2, 3, 2)) # batch_size x num_samples x num_references 82 | # Duplicate samples should have the same scores 83 | torch.testing.assert_close(metric_output.scores[0, 0], metric_output.scores[0, 1]) 84 | torch.testing.assert_close(metric_output.scores_per_reference[0, 0, 0], metric_output.scores_per_reference[0, 1, 0]) 85 | # The metric scores should rank as expected, given the test strings in self.samples and self.references 86 | self.assertGreater(metric_output.scores[0, 0], metric_output.scores[0, 2]) 87 | self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 1]) 88 | self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 2]) 89 | 90 | def test_compute_metric__comet(self): 91 | self.mbr_config.metric = evaluate.load("comet", "eamt22-cometinho-da") 92 | self.mbr_config.metric.scorer.eval() 93 | self.mbr_config.metric_output_field = "mean_score" 94 | self.metric_runner = MetricRunner(self.mbr_config, self.tokenizer) 95 | self.assertEqual(self.metric_runner.metric.name, "comet") 96 | metric_output = self.metric_runner(self.input_ids, self.sample_ids, self.reference_ids) 97 | self.assertTrue(torch.is_floating_point(metric_output.scores)) 98 | self.assertTrue(torch.is_floating_point(metric_output.scores_per_reference)) 99 | torch.testing.assert_close(metric_output.scores_per_reference.mean(dim=-1), metric_output.scores) 100 | self.assertEqual(metric_output.scores.shape, (2, 3)) # batch_size x num_samples 101 | self.assertEqual(metric_output.scores_per_reference.shape, (2, 3, 2)) # batch_size x num_samples x num_references 102 | # Duplicate samples should have the same scores 103 | torch.testing.assert_close(metric_output.scores[0, 0], metric_output.scores[0, 1]) 104 | torch.testing.assert_close(metric_output.scores_per_reference[0, 0, 0], 
metric_output.scores_per_reference[0, 1, 0]) 105 | # The metric scores should rank as expected, given the test strings in self.samples and self.references 106 | self.assertGreater(metric_output.scores[0, 0], metric_output.scores[0, 2]) 107 | self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 1]) 108 | self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 2]) 109 | 110 | @unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies") 111 | def test_compute_metric__bleurt(self): 112 | self.mbr_config.metric = evaluate.load("bleurt") 113 | self.mbr_config.metric_output_field = "scores" 114 | self.metric_runner = MetricRunner(self.mbr_config, self.tokenizer) 115 | self.assertEqual(self.metric_runner.metric.name, "bleurt") 116 | metric_output = self.metric_runner(self.input_ids, self.sample_ids, self.reference_ids) 117 | self.assertTrue(torch.is_floating_point(metric_output.scores)) 118 | self.assertTrue(torch.is_floating_point(metric_output.scores_per_reference)) 119 | torch.testing.assert_close(metric_output.scores_per_reference.mean(dim=-1), metric_output.scores) 120 | self.assertEqual(metric_output.scores.shape, (2, 3)) # batch_size x num_samples 121 | self.assertEqual(metric_output.scores_per_reference.shape, (2, 3, 2)) # batch_size x num_samples x num_references 122 | # Duplicate samples should have the same scores 123 | torch.testing.assert_close(metric_output.scores[0, 0], metric_output.scores[0, 1]) 124 | torch.testing.assert_close(metric_output.scores_per_reference[0, 0, 0], metric_output.scores_per_reference[0, 1, 0]) 125 | # The metric scores should rank as expected, given the test strings in self.samples and self.references 126 | self.assertGreater(metric_output.scores[0, 0], metric_output.scores[0, 2]) 127 | self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 1]) 128 | self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 2]) 129 | 130 | def test_comet_metric_runner(self): 131 | from mbr.metrics.comet import CometMetricRunner 132 | self.mbr_config.metric = evaluate.load("comet", "eamt22-cometinho-da") 133 | self.mbr_config.metric.scorer.eval() 134 | self.mbr_config.metric_output_field = "mean_score" 135 | base_metric_runner = MetricRunner(self.mbr_config, self.tokenizer) 136 | self.assertEqual(base_metric_runner.metric.name, "comet") 137 | self.assertFalse(base_metric_runner.metric.scorer.training) 138 | comet_metric_runner = CometMetricRunner(self.mbr_config, self.tokenizer) 139 | self.assertFalse(comet_metric_runner.metric.scorer.training) 140 | # Output should be the same as the base MetricRunner 141 | base_metric_scores = base_metric_runner(self.input_ids, self.sample_ids, self.reference_ids) 142 | metric_scores = comet_metric_runner(self.input_ids, self.sample_ids, self.reference_ids) 143 | torch.testing.assert_close(base_metric_scores, metric_scores) 144 | 145 | def test_comet_metric_runner__cache(self): 146 | """Output should be identical irrespective of cache size""" 147 | from mbr.metrics.comet import CometMetricRunner 148 | self.mbr_config.metric = evaluate.load("comet", "eamt22-cometinho-da") 149 | self.mbr_config.metric_output_field = "mean_score" 150 | base_metric_runner = MetricRunner(self.mbr_config, self.tokenizer) 151 | base_metric_scores = base_metric_runner(self.input_ids, self.sample_ids, self.reference_ids) 152 | self.assertEqual(base_metric_runner.metric.name, "comet") 153 | for cache_size in [1, 4, 8]: 154 | self.mbr_config.metric_cache_size = cache_size 155 | 
comet_metric_runner = CometMetricRunner(self.mbr_config, self.tokenizer) 156 | metric_scores = comet_metric_runner(self.input_ids, self.sample_ids, self.reference_ids) 157 | torch.testing.assert_close(base_metric_scores, metric_scores) 158 | 159 | def test_comet_metric_runner__aggregate(self): 160 | from mbr.metrics.comet import AggregateCometMetricRunner 161 | self.mbr_config.metric = evaluate.load("comet", "eamt22-cometinho-da") 162 | self.mbr_config.metric.scorer.eval() 163 | self.mbr_config.metric_output_field = "mean_score" 164 | base_metric_runner = MetricRunner(self.mbr_config, self.tokenizer) 165 | self.assertEqual(base_metric_runner.metric.name, "comet") 166 | self.assertFalse(base_metric_runner.metric.scorer.training) 167 | comet_metric_runner = AggregateCometMetricRunner(self.mbr_config, self.tokenizer) 168 | self.assertFalse(comet_metric_runner.metric.scorer.training) 169 | metric_output = comet_metric_runner(self.input_ids, self.sample_ids, self.reference_ids) 170 | self.assertTrue(torch.is_floating_point(metric_output.scores)) 171 | self.assertIsNone(metric_output.scores_per_reference) 172 | self.assertEqual(metric_output.scores.shape, (2, 3)) # batch_size x num_samples 173 | # Duplicate samples should have the same scores 174 | torch.testing.assert_close(metric_output.scores[0, 0], metric_output.scores[0, 1]) 175 | # The metric scores should rank as expected, given the test strings in self.samples and self.references 176 | self.assertGreater(metric_output.scores[0, 0], metric_output.scores[0, 2]) 177 | self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 1]) 178 | self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 2]) 179 | 180 | def test_fastchrf_metric_runner__aggregate(self): 181 | from mbr.metrics.fastchrf import FastChrfMetricRunner 182 | metric_runner = FastChrfMetricRunner(self.mbr_config, self.tokenizer, compute_pairwise_average=False) 183 | metric_output = metric_runner(self.input_ids, self.sample_ids, self.reference_ids) 184 | self.assertTrue(torch.is_floating_point(metric_output.scores)) 185 | self.assertIsNone(metric_output.scores_per_reference) 186 | self.assertEqual(metric_output.scores.shape, (2, 3)) # batch_size x num_samples 187 | # Duplicate samples should have the same scores 188 | torch.testing.assert_close(metric_output.scores[0, 0], metric_output.scores[0, 1]) 189 | # The metric scores should rank as expected, given the test strings in self.samples and self.references 190 | self.assertGreater(metric_output.scores[0, 0], metric_output.scores[0, 2]) 191 | self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 1]) 192 | self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 2]) 193 | 194 | def test_fastchrf_metric_runner__pairwise(self): 195 | from mbr.metrics.fastchrf import FastChrfMetricRunner 196 | metric_runner = FastChrfMetricRunner(self.mbr_config, self.tokenizer, compute_pairwise_average=True) 197 | metric_output = metric_runner(self.input_ids, self.sample_ids, self.reference_ids) 198 | self.assertTrue(torch.is_floating_point(metric_output.scores)) 199 | self.assertTrue(torch.is_floating_point(metric_output.scores_per_reference)) 200 | self.assertEqual(metric_output.scores.shape, (2, 3)) # batch_size x num_samples 201 | self.assertEqual(metric_output.scores_per_reference.shape, (2, 3, 2)) # batch_size x num_samples x num_references 202 | # Duplicate samples should have the same scores 203 | torch.testing.assert_close(metric_output.scores[0, 0], metric_output.scores[0, 1]) 204 | 
torch.testing.assert_close(metric_output.scores_per_reference[0, 0, 0], metric_output.scores_per_reference[0, 1, 0]) 205 | # The metric scores should rank as expected, given the test strings in self.samples and self.references 206 | self.assertGreater(metric_output.scores[0, 0], metric_output.scores[0, 2]) 207 | self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 1]) 208 | self.assertLess(metric_output.scores[1, 0], metric_output.scores[1, 2]) 209 | -------------------------------------------------------------------------------- /tests/test_pipelines.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | from unittest import TestCase 4 | 5 | from transformers import AutoTokenizer, pipeline, GPT2LMHeadModel, M2M100ForConditionalGeneration, set_seed 6 | 7 | from mbr import MBRConfig 8 | from mbr import MBR 9 | 10 | 11 | class TextGenerationTestCase(TestCase): 12 | 13 | def setUp(self): 14 | set_seed(42) 15 | self.model = MBR(GPT2LMHeadModel).from_pretrained("distilgpt2").eval() 16 | self.tokenizer = AutoTokenizer.from_pretrained("distilgpt2") 17 | self.pipeline = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer) 18 | 19 | def test_pipeline(self): 20 | mbr_config = MBRConfig( 21 | num_samples=5, 22 | ) 23 | output = self.pipeline( 24 | "Hello,", 25 | mbr_config=mbr_config, 26 | tokenizer=self.tokenizer, 27 | ) 28 | self.assertEqual(1, len(output)) 29 | self.assertIn("generated_text", output[0]) 30 | 31 | 32 | @unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies") 33 | class TranslationTestCase(TestCase): 34 | 35 | def setUp(self): 36 | set_seed(42) 37 | self.model = MBR(M2M100ForConditionalGeneration).from_pretrained("alirezamsh/small100").eval() 38 | self.tokenizer = AutoTokenizer.from_pretrained("alirezamsh/small100") 39 | self.pipeline = pipeline("translation_en_to_fr", model=self.model, tokenizer=self.tokenizer) 40 | self.tokenizer.tgt_lang = "fr" 41 | 42 | def test_pipeline(self): 43 | mbr_config = MBRConfig( 44 | num_samples=5, 45 | ) 46 | output = self.pipeline( 47 | "Could you translate this for me, please?", 48 | mbr_config=mbr_config, 49 | tokenizer=self.tokenizer, 50 | do_sample=True, 51 | num_beams=1, 52 | ) 53 | self.assertEqual(1, len(output)) 54 | self.assertIn("translation_text", output[0]) 55 | --------------------------------------------------------------------------------
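Beyond the pipeline tests above, reference aggregation (as implemented by AggregateCometMetricRunner in src/mbr/metrics/comet.py) can also be driven directly, mirroring the setup of tests/test_metrics.py. A minimal sketch; the COMET checkpoint, tokenizer and example sentences are illustrative assumptions:

from transformers import AutoTokenizer

from mbr import MBRConfig
from mbr.metrics.comet import AggregateCometMetricRunner

mbr_config = MBRConfig(
    metric="comet",
    metric_config_name="eamt22-cometinho-da",  # assumed checkpoint, as in the tests
    metric_output_field="mean_score",
    num_samples=3,
    num_references=2,
)
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")  # only used to detokenize ids
tokenizer.pad_token = tokenizer.eos_token

runner = AggregateCometMetricRunner(mbr_config, tokenizer)

# Token ids shaped as in MetricRunner.__call__: tuples over num_samples / num_references.
inputs = ["This is an input sentence."]
samples = [["This is a sample sentence."], ["Something totally different."], ["A third sample."]]
references = [["This is a reference sentence."], ["Another reference sentence."]]
input_ids = tokenizer(inputs, return_tensors="pt", padding=True).input_ids
sample_ids = tuple(tokenizer(s, return_tensors="pt", padding=True).input_ids for s in samples)
reference_ids = tuple(tokenizer(r, return_tensors="pt", padding=True).input_ids for r in references)

output = runner(input_ids, sample_ids, reference_ids)
# The references are aggregated into one averaged embedding per segment, so
# output.scores has shape (batch_size, num_samples) and scores_per_reference is None.
print(output.scores)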