├── sacrebleu ├── py.typed ├── tokenizers │ ├── __init__.py │ ├── tokenizer_none.py │ ├── tokenizer_char.py │ ├── tokenizer_base.py │ ├── tokenizer_13a.py │ ├── tokenizer_re.py │ ├── tokenizer_ja_mecab.py │ ├── tokenizer_ko_mecab.py │ ├── tokenizer_intl.py │ ├── tokenizer_spm.py │ ├── tokenizer_zh.py │ └── tokenizer_ter.py ├── dataset │ ├── iwslt_xml.py │ ├── plain_text.py │ ├── __main__.py │ ├── tsv.py │ ├── fake_sgml.py │ ├── base.py │ └── wmt_xml.py ├── metrics │ ├── __init__.py │ ├── helpers.py │ ├── ter.py │ ├── chrf.py │ ├── lib_ter.py │ ├── base.py │ └── bleu.py ├── __main__.py ├── __init__.py ├── compat.py └── significance.py ├── pytest.ini ├── test ├── wmt17_en_de_systems.pkl.bz2 ├── test_tokenizer_ter.py ├── test_ter.py ├── test_sentence_bleu.py ├── test_api.py ├── test_dataset.py ├── test_significance.py ├── test_chrf.py └── test_bleu.py ├── .gitignore ├── Makefile ├── tox.ini ├── mypy.ini ├── scripts ├── add_wmt.sh └── perf_test.py ├── .github └── workflows │ ├── python-publish.yml │ └── check-build.yml ├── pyproject.toml ├── DATASETS.md └── LICENSE.txt /sacrebleu/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | addopts = -v 3 | testpaths = test 4 | -------------------------------------------------------------------------------- /test/wmt17_en_de_systems.pkl.bz2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mjpost/sacrebleu/HEAD/test/wmt17_en_de_systems.pkl.bz2 -------------------------------------------------------------------------------- /sacrebleu/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | # Base tokenizer to derive from 2 | from .tokenizer_base import BaseTokenizer # noqa: F401 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | .coverage 3 | build 4 | dist 5 | __pycache__ 6 | sacrebleu.egg-info 7 | .sacrebleu 8 | *~ 9 | .DS_Store 10 | .idea/ 11 | sacrebleu/version.py 12 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: test 2 | test: 3 | mypy sacrebleu scripts test 4 | python3 -m pytest 5 | bash test.sh 6 | 7 | pip: 8 | python3 -m build . 9 | 10 | publish: pip 11 | twine upload dist/* 12 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E501,E265 3 | 4 | [testenv:pkglint] 5 | base_python=python3.11 6 | deps= 7 | build 8 | twine 9 | commands= 10 | python -m build 11 | twine check dist/*.tar.gz dist/*.whl 12 | -------------------------------------------------------------------------------- /sacrebleu/dataset/iwslt_xml.py: -------------------------------------------------------------------------------- 1 | from .fake_sgml import FakeSGMLDataset 2 | 3 | 4 | class IWSLTXMLDataset(FakeSGMLDataset): 5 | """IWSLT dataset format. Can be parsed with the lxml parser.""" 6 | 7 | # Same as FakeSGMLDataset. Nothing to do here. 
8 | pass 9 | -------------------------------------------------------------------------------- /sacrebleu/tokenizers/tokenizer_none.py: -------------------------------------------------------------------------------- 1 | from .tokenizer_base import BaseTokenizer 2 | 3 | class NoneTokenizer(BaseTokenizer): 4 | """Don't apply any tokenization. Not recommended!.""" 5 | 6 | def signature(self): 7 | return 'none' 8 | 9 | def __call__(self, line): 10 | return line 11 | -------------------------------------------------------------------------------- /sacrebleu/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | """The implementation of various metrics.""" 2 | 3 | from .bleu import BLEU, BLEUScore # noqa: F401 4 | from .chrf import CHRF, CHRFScore # noqa: F401 5 | from .ter import TER, TERScore # noqa: F401 6 | 7 | METRICS = { 8 | 'BLEU': BLEU, 9 | 'CHRF': CHRF, 10 | 'TER': TER, 11 | } 12 | -------------------------------------------------------------------------------- /sacrebleu/tokenizers/tokenizer_char.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | from .tokenizer_base import BaseTokenizer 3 | 4 | 5 | class TokenizerChar(BaseTokenizer): 6 | def signature(self): 7 | return 'char' 8 | 9 | def __init__(self): 10 | pass 11 | 12 | @lru_cache(maxsize=2**16) 13 | def __call__(self, line): 14 | """Tokenizes all the characters in the input line. 15 | 16 | :param line: a segment to tokenize 17 | :return: the tokenized line 18 | """ 19 | return " ".join((char for char in line)) 20 | -------------------------------------------------------------------------------- /sacrebleu/tokenizers/tokenizer_base.py: -------------------------------------------------------------------------------- 1 | class BaseTokenizer: 2 | """A base dummy tokenizer to derive from.""" 3 | 4 | def signature(self): 5 | """ 6 | Returns a signature for the tokenizer. 7 | 8 | :return: signature string 9 | """ 10 | raise NotImplementedError() 11 | 12 | def __call__(self, line): 13 | """ 14 | Tokenizes an input line with the tokenizer. 
15 | 16 | :param line: a segment to tokenize 17 | :return: the tokenized line 18 | """ 19 | raise NotImplementedError() 20 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | python_version = 3.12 3 | 4 | [mypy-portalocker.*] 5 | ignore_missing_imports = True 6 | 7 | [mypy-colorama.*] 8 | ignore_missing_imports = True 9 | 10 | [mypy-numpy.*] 11 | ignore_missing_imports = True 12 | 13 | [mypy-regex.*] 14 | ignore_missing_imports = True 15 | 16 | [mypy-ipadic.*] 17 | ignore_missing_imports = True 18 | 19 | [mypy-MeCab.*] 20 | ignore_missing_imports = True 21 | 22 | [mypy-mecab_ko.*] 23 | ignore_missing_imports = True 24 | 25 | [mypy-mecab_ko_dic.*] 26 | ignore_missing_imports = True 27 | 28 | [mypy-sentencepiece.*] 29 | ignore_missing_imports = True 30 | -------------------------------------------------------------------------------- /scripts/add_wmt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ -z $1 ]]; then 4 | echo "Usage: add_wmt.sh TARBALL_URL" 5 | cat<<EOF -------------------------------------------------------------------------------- /sacrebleu/tokenizers/tokenizer_13a.py: -------------------------------------------------------------------------------- 24 | line = line.replace('&lt;skipped&gt;'.replace('&lt;', '<').replace('&gt;', '>'), '') if False else line.replace('<skipped>', '') 25 | line = line.replace('-\n', '') 26 | line = line.replace('\n', ' ') 27 | 28 | if '&' in line: 29 | line = line.replace('&quot;', '"') 30 | line = line.replace('&amp;', '&') 31 | line = line.replace('&lt;', '<') 32 | line = line.replace('&gt;', '>') 33 | 34 | return self._post_tokenizer(f' {line} ') 35 | -------------------------------------------------------------------------------- /sacrebleu/__main__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2017--2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"). You may not 7 | # use this file except in compliance with the License. A copy of the License 8 | # is located at 9 | # 10 | # http://aws.amazon.com/apache2.0/ 11 | # 12 | # or in the "license" file accompanying this file. This file is distributed on 13 | # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 14 | # express or implied. See the License for the specific language governing 15 | # permissions and limitations under the License. 16 | 17 | """ 18 | SacreBLEU provides hassle-free computation of shareable, comparable, and reproducible BLEU scores. 19 | Inspired by Rico Sennrich's `multi-bleu-detok.perl`, it produces the official WMT scores but works with plain text. 20 | It also knows all the standard test sets and handles downloading, processing, and tokenization for you. 21 | 22 | See the [README.md] file for more information.
23 | """ 24 | from .sacrebleu import main 25 | 26 | if __name__ == '__main__': 27 | main() 28 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | - push 8 | 9 | jobs: 10 | build-n-publish: 11 | name: Build and publish Python distributions to PyPI 12 | if: ${{ github.repository_owner == 'mjpost' }} 13 | 14 | runs-on: ubuntu-latest 15 | environment: release 16 | 17 | permissions: 18 | # IMPORTANT: this permission is mandatory for trusted publishing 19 | id-token: write 20 | 21 | steps: 22 | - uses: actions/checkout@v4 23 | # with: 24 | # fetch-depth: 0 25 | - name: Set up Python 26 | uses: actions/setup-python@v5 27 | with: 28 | python-version: '3.11' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | - name: Build sdist and wheel 34 | run: | 35 | python -m build 36 | - name: Publish distribution to PyPI 37 | if: startsWith(github.ref, 'refs/tags') 38 | uses: pypa/gh-action-pypi-publish@release/v1 39 | -------------------------------------------------------------------------------- /sacrebleu/dataset/plain_text.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from ..utils import smart_open 4 | from .base import Dataset 5 | 6 | 7 | class PlainTextDataset(Dataset): 8 | """ 9 | The plain text format. Data is separated into source and reference files. 10 | Each line of the two files is aligned. 11 | """ 12 | 13 | def process_to_text(self, langpair=None): 14 | """Processes raw files to plain text files. 15 | 16 | :param langpair: The language pair to process. e.g. "en-de". If None, all files will be processed. 
17 | """ 18 | # ensure that the dataset is downloaded 19 | self.maybe_download() 20 | langpairs = self._get_langpair_metadata(langpair) 21 | 22 | for langpair in langpairs: 23 | fieldnames = self.fieldnames(langpair) 24 | origin_files = [ 25 | os.path.join(self._rawdir, path) for path in langpairs[langpair] 26 | ] 27 | 28 | for field, origin_file in zip(fieldnames, origin_files): 29 | 30 | origin_file = os.path.join(self._rawdir, origin_file) 31 | output_file = self._get_txt_file_path(langpair, field) 32 | 33 | with smart_open(origin_file) as fin: 34 | with smart_open(output_file, "wt") as fout: 35 | for line in fin: 36 | print(line.rstrip(), file=fout) 37 | -------------------------------------------------------------------------------- /sacrebleu/tokenizers/tokenizer_re.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | import re 3 | 4 | from .tokenizer_base import BaseTokenizer 5 | 6 | 7 | class TokenizerRegexp(BaseTokenizer): 8 | 9 | def signature(self): 10 | return 're' 11 | 12 | def __init__(self): 13 | self._re = [ 14 | # language-dependent part (assuming Western languages) 15 | (re.compile(r'([\{-\~\[-\` -\&\(-\+\:-\@\/])'), r' \1 '), 16 | # tokenize period and comma unless preceded by a digit 17 | (re.compile(r'([^0-9])([\.,])'), r'\1 \2 '), 18 | # tokenize period and comma unless followed by a digit 19 | (re.compile(r'([\.,])([^0-9])'), r' \1 \2'), 20 | # tokenize dash when preceded by a digit 21 | (re.compile(r'([0-9])(-)'), r'\1 \2 '), 22 | # one space only between words 23 | # NOTE: Doing this in Python (below) is faster 24 | # (re.compile(r'\s+'), r' '), 25 | ] 26 | 27 | @lru_cache(maxsize=2**16) 28 | def __call__(self, line): 29 | """Common post-processing tokenizer for `13a` and `zh` tokenizers. 30 | 31 | :param line: a segment to tokenize 32 | :return: the tokenized line 33 | """ 34 | for (_re, repl) in self._re: 35 | line = _re.sub(repl, line) 36 | 37 | # no leading or trailing spaces, single space within words 38 | return ' '.join(line.split()) 39 | -------------------------------------------------------------------------------- /sacrebleu/dataset/__main__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from . 
import DATASETS 4 | 5 | try: 6 | cmd = sys.argv[1] 7 | except IndexError: 8 | print(f"Usage: {sys.argv[0]} --check | --dump") 9 | sys.exit(1) 10 | 11 | if cmd == "--check": 12 | import hashlib 13 | import urllib.request 14 | 15 | url_md5 = {} 16 | 17 | for item in DATASETS.values(): 18 | if item.md5 is not None: 19 | assert item.data 20 | assert item.md5 21 | assert len(item.data) == len(item.md5) 22 | pairs = zip(item.data, item.md5) 23 | for url, md5_hash in pairs: 24 | url_md5[url] = md5_hash 25 | 26 | for url, md5_hash in url_md5.items(): 27 | try: 28 | print("Downloading ", url) 29 | with urllib.request.urlopen(url) as f: 30 | data = f.read() 31 | except Exception as exc: 32 | raise (exc) 33 | 34 | if hashlib.md5(data).hexdigest() != md5_hash: 35 | print("MD5 check failed for", url) 36 | elif cmd == "--dump": 37 | import re 38 | 39 | # Dumps a table in markdown format 40 | print(f'| {"Dataset":<30} | {"Description":<115} |') 41 | header = "| " + "-" * 30 + " | " + "-" * 115 + " |" 42 | print(header) 43 | for name, item in DATASETS.items(): 44 | desc = re.sub(r"(http[s]?:\/\/\S+)", r"[URL](\1)", str(item.description)) 45 | print(f"| {name:<30} | {desc:<115} |") 46 | -------------------------------------------------------------------------------- /.github/workflows/check-build.yml: -------------------------------------------------------------------------------- 1 | name: check-build 2 | 3 | on: 4 | push: 5 | pull_request: 6 | workflow_dispatch: 7 | 8 | env: 9 | PYTHONUTF8: "1" 10 | 11 | # only run one at a time per branch 12 | concurrency: 13 | group: check-build-${{ github.ref }} 14 | cancel-in-progress: true 15 | 16 | jobs: 17 | check-build: 18 | runs-on: ${{ matrix.os }} 19 | strategy: 20 | matrix: 21 | os: [ubuntu-latest, macos-latest, windows-latest] 22 | python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] 23 | steps: 24 | - uses: actions/checkout@v4 25 | - name: Setup Python ${{ matrix.python-version }} 26 | uses: actions/setup-python@v5 27 | with: 28 | python-version: ${{ matrix.python-version }} 29 | - if: ${{ matrix.os == 'macos-latest' || matrix.os == 'macos-14' }} 30 | name: Install Mac OS requirements 31 | run: brew install bash 32 | - if: matrix.os == 'windows-latest' 33 | name: Install Windows requirements 34 | run: choco install wget unzip 35 | - name: Install python dependencies 36 | run: pip install ".[dev,ja,ko]" 37 | - name: Lint with Mypy 38 | run: mypy sacrebleu scripts test 39 | - name: Lint with Ruff 40 | uses: chartboost/ruff-action@v1 41 | - name: Python pytest test suite 42 | run: python -m pytest 43 | - name: CLI bash test suite 44 | shell: bash 45 | run: ./test.sh 46 | - name: Build 47 | run: | 48 | pip install build 49 | python -m build . 
50 | -------------------------------------------------------------------------------- /test/test_tokenizer_ter.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from sacrebleu.tokenizers.tokenizer_ter import TercomTokenizer 4 | 5 | 6 | test_cases_default = [ 7 | ("a b c d", "a b c d"), 8 | ("", ""), 9 | ("a b c d.", "a b c d."), 10 | ("A B C D.", "a b c d."), 11 | ] 12 | 13 | test_cases_no_punct = [ 14 | ("a b c d.", "a b c d"), 15 | ("A ; B ) C : D.", "a b c d"), 16 | ] 17 | 18 | test_cases_norm = [ 19 | ("a b (c) d.", "a b ( c ) d ."), 20 | ("Jim's car.", "jim 's car ."), 21 | ("4.2", "4.2"), 22 | ] 23 | 24 | test_cases_asian = [ 25 | ("美众院公布对特", "美 众 院 公 布 对 特"), # Chinese 26 | ("りの拳銃を持", "りの 拳 銃 を 持"), # Japanese, first two letters are Hiragana 27 | ] 28 | 29 | 30 | @pytest.mark.parametrize("input, expected", test_cases_default) 31 | def test_ter_tokenizer_default(input, expected): 32 | tokenizer = TercomTokenizer() 33 | assert tokenizer(input) == expected 34 | 35 | 36 | @pytest.mark.parametrize("input, expected", test_cases_no_punct) 37 | def test_ter_tokenizer_nopunct(input, expected): 38 | tokenizer = TercomTokenizer(no_punct=True) 39 | assert tokenizer(input) == expected 40 | 41 | 42 | @pytest.mark.parametrize("input, expected", test_cases_norm) 43 | def test_ter_tokenizer_norm(input, expected): 44 | tokenizer = TercomTokenizer(normalized=True) 45 | assert tokenizer(input) == expected 46 | 47 | 48 | @pytest.mark.parametrize("input, expected", test_cases_asian) 49 | def test_ter_tokenizer_asian(input, expected): 50 | tokenizer = TercomTokenizer(normalized=True, asian_support=True) 51 | assert tokenizer(input) == expected 52 | -------------------------------------------------------------------------------- /sacrebleu/tokenizers/tokenizer_ja_mecab.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | 3 | try: 4 | import MeCab 5 | import ipadic 6 | except ImportError: 7 | # Don't fail until the tokenizer is actually used 8 | MeCab = None 9 | 10 | from .tokenizer_base import BaseTokenizer 11 | 12 | FAIL_MESSAGE = """ 13 | Japanese tokenization requires extra dependencies, but you do not have them installed. 14 | Please install them like so. 15 | 16 | pip install sacrebleu[ja] 17 | """ 18 | 19 | 20 | class TokenizerJaMecab(BaseTokenizer): 21 | def __init__(self): 22 | if MeCab is None: 23 | raise RuntimeError(FAIL_MESSAGE) 24 | self.tagger = MeCab.Tagger(ipadic.MECAB_ARGS + " -Owakati") 25 | 26 | # make sure the dictionary is IPA 27 | d = self.tagger.dictionary_info() 28 | assert d.size == 392126, \ 29 | "Please make sure to use the IPA dictionary for MeCab" 30 | # This asserts that no user dictionary has been loaded 31 | assert d.next is None 32 | 33 | @lru_cache(maxsize=2**16) 34 | def __call__(self, line): 35 | """ 36 | Tokenizes an Japanese input line using MeCab morphological analyzer. 37 | 38 | :param line: a segment to tokenize 39 | :return: the tokenized line 40 | """ 41 | line = line.strip() 42 | sentence = self.tagger.parse(line).strip() 43 | return sentence 44 | 45 | def signature(self): 46 | """ 47 | Returns the MeCab parameters. 
48 | 49 | :return: signature string 50 | """ 51 | signature = self.tagger.version() + "-IPA" 52 | return 'ja-mecab-' + signature 53 | -------------------------------------------------------------------------------- /sacrebleu/tokenizers/tokenizer_ko_mecab.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | 3 | try: 4 | import mecab_ko as MeCab 5 | import mecab_ko_dic 6 | except ImportError: 7 | # Don't fail until the tokenizer is actually used 8 | MeCab = None 9 | 10 | from .tokenizer_base import BaseTokenizer 11 | 12 | FAIL_MESSAGE = """ 13 | Korean tokenization requires extra dependencies, but you do not have them installed. 14 | Please install them like so. 15 | 16 | pip install sacrebleu[ko] 17 | """ 18 | 19 | 20 | class TokenizerKoMecab(BaseTokenizer): 21 | def __init__(self): 22 | if MeCab is None: 23 | raise RuntimeError(FAIL_MESSAGE) 24 | self.tagger = MeCab.Tagger(mecab_ko_dic.MECAB_ARGS + " -Owakati") 25 | 26 | # make sure the dictionary is mecab-ko-dic 27 | d = self.tagger.dictionary_info() 28 | assert d.size == 811795, \ 29 | "Please make sure to use the mecab-ko-dic for MeCab-ko" 30 | # This asserts that no user dictionary has been loaded 31 | assert d.next is None 32 | 33 | @lru_cache(maxsize=2**16) 34 | def __call__(self, line): 35 | """ 36 | Tokenizes an Korean input line using MeCab-ko morphological analyzer. 37 | 38 | :param line: a segment to tokenize 39 | :return: the tokenized line 40 | """ 41 | line = line.strip() 42 | sentence = self.tagger.parse(line).strip() 43 | return sentence 44 | 45 | def signature(self): 46 | """ 47 | Returns the MeCab-ko parameters. 48 | 49 | :return: signature string 50 | """ 51 | signature = self.tagger.version() + "-KO" 52 | return 'ko-mecab-' + signature 53 | -------------------------------------------------------------------------------- /sacrebleu/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2017--2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"). You may not 7 | # use this file except in compliance with the License. A copy of the License 8 | # is located at 9 | # 10 | # http://aws.amazon.com/apache2.0/ 11 | # 12 | # or in the "license" file accompanying this file. This file is distributed on 13 | # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 14 | # express or implied. See the License for the specific language governing 15 | # permissions and limitations under the License. 
16 | 17 | __description__ = "Hassle-free computation of shareable, comparable, and reproducible BLEU, chrF, and TER scores" 18 | 19 | 20 | # Backward compatibility functions for old style API access (<= 1.4.10) 21 | from .compat import ( 22 | corpus_bleu, 23 | corpus_chrf, 24 | corpus_ter, 25 | raw_corpus_bleu, 26 | sentence_bleu, 27 | sentence_chrf, 28 | sentence_ter, 29 | ) 30 | from .dataset import DATASETS 31 | from .metrics import BLEU, CHRF, TER 32 | from .metrics.helpers import extract_char_ngrams, extract_word_ngrams 33 | from .utils import ( 34 | SACREBLEU_DIR, 35 | download_test_set, 36 | get_available_testsets, 37 | get_langpairs_for_testset, 38 | get_reference_files, 39 | get_source_file, 40 | smart_open, 41 | ) 42 | from .version import __version__ 43 | 44 | __all__ = [ 45 | "smart_open", 46 | "SACREBLEU_DIR", 47 | "download_test_set", 48 | "get_source_file", 49 | "get_reference_files", 50 | "get_available_testsets", 51 | "get_langpairs_for_testset", 52 | "extract_word_ngrams", 53 | "extract_char_ngrams", 54 | "DATASETS", 55 | "BLEU", 56 | "CHRF", 57 | "TER", 58 | "corpus_bleu", 59 | "raw_corpus_bleu", 60 | "sentence_bleu", 61 | "corpus_chrf", 62 | "sentence_chrf", 63 | "corpus_ter", 64 | "sentence_ter", 65 | "__version__", 66 | ] 67 | -------------------------------------------------------------------------------- /sacrebleu/tokenizers/tokenizer_intl.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | 3 | import regex 4 | 5 | from .tokenizer_base import BaseTokenizer 6 | 7 | 8 | class TokenizerV14International(BaseTokenizer): 9 | """Tokenizes a string following the official BLEU implementation. 10 | 11 | See github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983 12 | 13 | In our case, the input string is expected to be just one line. 14 | We just tokenize on punctuation and symbols, 15 | except when a punctuation is preceded and followed by a digit 16 | (e.g. a comma/dot as a thousand/decimal separator). 17 | We do not recover escaped forms of punctuations such as &apos; or &gt; 18 | as these should never appear in MT system outputs (see issue #138) 19 | 20 | Note that a number (e.g., a year) followed by a dot at the end of 21 | sentence is NOT tokenized, i.e. the dot stays with the number because 22 | `s/(\\p{P})(\\P{N})/ $1 $2/g` does not match this case (unless we add a 23 | space after each sentence). However, this error is already in the 24 | original mteval-v14.pl and we want to be consistent with it. 25 | The error is not present in the non-international version, 26 | which uses `$norm_text = " $norm_text "`. 27 | 28 | :param line: the input string to tokenize. 29 | :return: The tokenized string.
30 | """ 31 | 32 | def signature(self): 33 | return 'intl' 34 | 35 | def __init__(self): 36 | self._re = [ 37 | # Separate out punctuations preceded by a non-digit 38 | (regex.compile(r'(\P{N})(\p{P})'), r'\1 \2 '), 39 | # Separate out punctuations followed by a non-digit 40 | (regex.compile(r'(\p{P})(\P{N})'), r' \1 \2'), 41 | # Separate out symbols 42 | (regex.compile(r'(\p{S})'), r' \1 '), 43 | ] 44 | 45 | @lru_cache(maxsize=2**16) 46 | def __call__(self, line: str) -> str: 47 | for (_re, repl) in self._re: 48 | line = _re.sub(repl, line) 49 | 50 | return ' '.join(line.split()) 51 | -------------------------------------------------------------------------------- /sacrebleu/dataset/tsv.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from ..utils import smart_open 4 | from .base import Dataset 5 | 6 | 7 | class TSVDataset(Dataset): 8 | """ 9 | The format used by the MTNT datasets. Data is in a single TSV file. 10 | """ 11 | 12 | @staticmethod 13 | def _split_index_and_filename(meta, field): 14 | """ 15 | Splits the index and filename from a metadata string. 16 | 17 | e.g. meta="3:en-de.tsv", field=[Any value] -> (3, "en-de.tsv") 18 | "en-de.tsv", field="src" -> (0, "en-de.tsv") 19 | "en-de.tsv", field="tgt" -> (1, "en-de.tsv") 20 | """ 21 | arr = meta.split(":") 22 | if len(arr) == 2: 23 | try: 24 | index = int(arr[0]) 25 | except ValueError: 26 | raise Exception(f"Invalid meta for TSVDataset: {meta}") 27 | return index, arr[1] 28 | 29 | else: 30 | index = 0 if field == "src" else 1 31 | return index, meta 32 | 33 | def process_to_text(self, langpair=None): 34 | """Processes raw files to plain text files. 35 | 36 | :param langpair: The language pair to process. e.g. "en-de". If None, all files will be processed. 37 | """ 38 | # ensure that the dataset is downloaded 39 | self.maybe_download() 40 | langpairs = self._get_langpair_metadata(langpair) 41 | 42 | for langpair in langpairs: 43 | fieldnames = self.fieldnames(langpair) 44 | origin_files = [ 45 | os.path.join(self._rawdir, path) for path in langpairs[langpair] 46 | ] 47 | 48 | for field, origin_file, meta in zip( 49 | fieldnames, origin_files, langpairs[langpair] 50 | ): 51 | index, origin_file = self._split_index_and_filename(meta, field) 52 | 53 | origin_file = os.path.join(self._rawdir, origin_file) 54 | output_file = self._get_txt_file_path(langpair, field) 55 | 56 | with smart_open(origin_file) as fin: 57 | with smart_open(output_file, "wt") as fout: 58 | for line in fin: 59 | # be careful with empty source or reference lines 60 | # MTNT2019/ja-en.final.tsv:632 `'1033\t718\t\t\n'` 61 | print(line.rstrip("\n").split("\t")[index], file=fout) 62 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=77", "setuptools_scm>=8"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "sacrebleu" 7 | dynamic = ["version"] 8 | authors = [{ name = "Matt Post", email = "post@cs.jhu.edu" }] 9 | maintainers = [{ name = "Matt Post", email = "post@cs.jhu.edu" }] 10 | description = "Hassle-free computation of shareable, comparable, and reproducible BLEU, chrF, and TER scores" 11 | readme = "README.md" 12 | license = "Apache-2.0" 13 | classifiers = [ 14 | # How mature is this project?
Common values are 15 | # 3 - Alpha 16 | # 4 - Beta 17 | # 5 - Production/Stable 18 | "Development Status :: 5 - Production/Stable", 19 | 20 | # Indicate who your project is intended for 21 | "Intended Audience :: Developers", 22 | "Intended Audience :: Science/Research", 23 | "Topic :: Scientific/Engineering", 24 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 25 | "Topic :: Text Processing", 26 | 27 | # List operating systems 28 | "Operating System :: POSIX", 29 | "Operating System :: MacOS :: MacOS X", 30 | "Operating System :: Microsoft :: Windows", 31 | 32 | # Specify the Python versions you support here. In particular, ensure 33 | # that you indicate whether you support Python 2, Python 3 or both. 34 | "Programming Language :: Python :: 3 :: Only", 35 | 36 | # Indicate that type hints are provided 37 | "Typing :: Typed", 38 | ] 39 | 40 | requires-python = ">=3.9" 41 | 42 | keywords = [ 43 | "machine translation", 44 | "evaluation", 45 | "NLP", 46 | "natural language processing", 47 | "computational linguistics", 48 | ] 49 | 50 | dependencies = [ 51 | "portalocker", 52 | "regex", 53 | "tabulate>=0.8.9", 54 | "numpy>=1.17", 55 | "colorama", 56 | "lxml", 57 | ] 58 | 59 | [project.optional-dependencies] 60 | dev = ["wheel", "pytest", "mypy", "types-tabulate", "lxml-stubs", "setuptools"] 61 | ja = ["mecab-python3>=1.0.9,<2.0.0", "ipadic>=1.0,<2.0"] 62 | ko = ["mecab-ko>=1.0.2,<2.0.0", "mecab-ko-dic>=1.0,<2.0"] 63 | 64 | [project.scripts] 65 | sacrebleu = "sacrebleu.sacrebleu:main" 66 | 67 | [project.urls] 68 | Repository = "https://github.com/mjpost/sacrebleu" 69 | 70 | [tool.setuptools.packages.find] 71 | include = ["sacrebleu*"] 72 | 73 | [tool.setuptools.package-data] 74 | sacrebleu = ["py.typed"] 75 | 76 | [tool.setuptools_scm] 77 | version_file = "sacrebleu/version.py" 78 | -------------------------------------------------------------------------------- /test/test_ter.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sacrebleu 3 | 4 | EPSILON = 1e-3 5 | 6 | test_cases = [ 7 | (['aaaa bbbb cccc dddd'], ['aaaa bbbb cccc dddd'], 0), # perfect match 8 | (['dddd eeee ffff'], ['aaaa bbbb cccc'], 1), # no overlap 9 | ([''], ['a'], 1), # corner case, empty hypothesis 10 | (['a'], [''], 1), # corner case, empty reference 11 | ([''], [''], 0), # corner case, both reference and hypothesis empty, we define it as 0.0 12 | (['d e f g h a b c'], ['a b c d e f g h'], 1 / 8), # a single shift fixes MT 13 | ( 14 | [ 15 | 'wählen Sie " Bild neu berechnen , " um beim Ändern der Bildgröße Pixel hinzuzufügen oder zu entfernen , damit das Bild ungefähr dieselbe Größe aufweist wie die andere Größe .', 16 | 'wenn Sie alle Aufgaben im aktuellen Dokument aktualisieren möchten , wählen Sie im Menü des Aufgabenbedienfelds die Option " Alle Aufgaben aktualisieren . "', 17 | 'klicken Sie auf der Registerkarte " Optionen " auf die Schaltfläche " Benutzerdefiniert " und geben Sie Werte für " Fehlerkorrektur-Level " und " Y / X-Verhältnis " ein .', 18 | 'Sie können beispielsweise ein Dokument erstellen , das ein Auto über die Bühne enthält .', 19 | 'wählen Sie im Dialogfeld " Neu aus Vorlage " eine Vorlage aus und klicken Sie auf " Neu . 
"', 20 | ], 21 | [ 22 | 'wählen Sie " Bild neu berechnen , " um beim Ändern der Bildgröße Pixel hinzuzufügen oder zu entfernen , damit die Darstellung des Bildes in einer anderen Größe beibehalten wird .', 23 | 'wenn Sie alle Aufgaben im aktuellen Dokument aktualisieren möchten , wählen Sie im Menü des Aufgabenbedienfelds die Option " Alle Aufgaben aktualisieren . "', 24 | 'klicken Sie auf der Registerkarte " Optionen " auf die Schaltfläche " Benutzerdefiniert " und geben Sie für " Fehlerkorrektur-Level " und " Y / X-Verhältnis " niedrigere Werte ein .', 25 | 'Sie können beispielsweise ein Dokument erstellen , das ein Auto enthalt , das sich über die Bühne bewegt .', 26 | 'wählen Sie im Dialogfeld " Neu aus Vorlage " eine Vorlage aus und klicken Sie auf " Neu . "', 27 | ], 28 | 0.136 # realistic example from WMT dev data (2019) 29 | ), 30 | ] 31 | 32 | 33 | @pytest.mark.parametrize("hypotheses, references, expected_score", test_cases) 34 | def test_ter(hypotheses, references, expected_score): 35 | metric = sacrebleu.metrics.TER() 36 | score = metric.corpus_score(hypotheses, [references]).score 37 | assert abs(score - 100 * expected_score) < EPSILON 38 | -------------------------------------------------------------------------------- /sacrebleu/metrics/helpers.py: -------------------------------------------------------------------------------- 1 | """Various utility functions for word and character n-gram extraction.""" 2 | 3 | from collections import Counter 4 | from typing import List, Tuple 5 | 6 | 7 | def extract_all_word_ngrams(line: str, min_order: int, max_order: int) -> Tuple[Counter, int]: 8 | """Extracts all ngrams (min_order <= n <= max_order) from a sentence. 9 | 10 | :param line: A string sentence. 11 | :param min_order: Minimum n-gram order. 12 | :param max_order: Maximum n-gram order. 13 | :return: a Counter object with n-grams counts and the sequence length. 14 | """ 15 | 16 | ngrams = [] 17 | tokens = line.split() 18 | 19 | for n in range(min_order, max_order + 1): 20 | for i in range(0, len(tokens) - n + 1): 21 | ngrams.append(tuple(tokens[i: i + n])) 22 | 23 | return Counter(ngrams), len(tokens) 24 | 25 | 26 | def extract_word_ngrams(tokens: List[str], n: int) -> Counter: 27 | """Extracts n-grams with order `n` from a list of tokens. 28 | 29 | :param tokens: A list of tokens. 30 | :param n: The order of n-grams. 31 | :return: a Counter object with n-grams counts. 32 | """ 33 | return Counter([' '.join(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]) 34 | 35 | 36 | def extract_char_ngrams(line: str, n: int, include_whitespace: bool = False) -> Counter: 37 | """Yields counts of character n-grams from a sentence. 38 | 39 | :param line: A segment containing a sequence of words. 40 | :param n: The order of the n-grams. 41 | :param include_whitespace: If given, will not strip whitespaces from the line. 42 | :return: a dictionary containing ngrams and counts 43 | """ 44 | if not include_whitespace: 45 | line = ''.join(line.split()) 46 | 47 | return Counter([line[i:i + n] for i in range(len(line) - n + 1)]) 48 | 49 | 50 | def extract_all_char_ngrams( 51 | line: str, max_order: int, include_whitespace: bool = False) -> List[Counter]: 52 | """Extracts all character n-grams at once for convenience. 53 | 54 | :param line: A segment containing a sequence of words. 55 | :param max_order: The maximum order of the n-grams. 56 | :param include_whitespace: If given, will not strip whitespaces from the line. 57 | :return: a list of Counter objects containing ngrams and counts. 
58 | """ 59 | 60 | counters = [] 61 | 62 | if not include_whitespace: 63 | line = ''.join(line.split()) 64 | 65 | for n in range(1, max_order + 1): 66 | ngrams = Counter([line[i:i + n] for i in range(len(line) - n + 1)]) 67 | counters.append(ngrams) 68 | 69 | return counters 70 | -------------------------------------------------------------------------------- /sacrebleu/tokenizers/tokenizer_spm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import logging 5 | 6 | from functools import lru_cache 7 | from ..utils import SACREBLEU_DIR, download_file 8 | from .tokenizer_base import BaseTokenizer 9 | 10 | sacrelogger = logging.getLogger('sacrebleu') 11 | 12 | 13 | SPM_MODELS = { 14 | "spm": { 15 | "url": "https://dl.fbaipublicfiles.com/fairseq/models/flores/sacrebleu_tokenizer_spm.model", 16 | "signature": "flores101", 17 | }, 18 | # same as the default of "spm" 19 | "flores101": { 20 | "url": "https://dl.fbaipublicfiles.com/fairseq/models/flores/sacrebleu_tokenizer_spm.model", 21 | "signature": "flores101", 22 | }, 23 | "flores200": { 24 | "url": "https://tinyurl.com/flores200sacrebleuspm", 25 | "signature": "flores200", 26 | }, 27 | ### Added for spBLEU-1K tokenizer by AbdelRahim Elmadany 28 | "spBLEU-1K": { 29 | "url": "https://www.dlnlp.ai/spBLEU-1K/spbleu-1k_tokenizer_spm.model", 30 | "signature": "spBLEU-1K", 31 | }, 32 | } 33 | 34 | class TokenizerSPM(BaseTokenizer): 35 | def signature(self): 36 | return self.name 37 | 38 | def __init__(self, key="spm"): 39 | self.name = SPM_MODELS[key]["signature"] 40 | 41 | if key == "spm": 42 | sacrelogger.warn("Tokenizer 'spm' has been changed to 'flores101', and may be removed in the future.") 43 | 44 | try: 45 | import sentencepiece as spm 46 | except (ImportError, ModuleNotFoundError): 47 | raise ImportError( 48 | '\n\nPlease install the sentencepiece library for SPM tokenization:' 49 | '\n\n pip install sentencepiece ' 50 | ) 51 | self.sp = spm.SentencePieceProcessor() 52 | 53 | model_path = os.path.join(SACREBLEU_DIR, "models", os.path.basename(SPM_MODELS[key]["url"])) 54 | if not os.path.exists(model_path): 55 | url = SPM_MODELS[self.name]["url"] 56 | download_file(url, model_path) 57 | self.sp.Load(model_path) 58 | 59 | @lru_cache(maxsize=2**16) 60 | def __call__(self, line): 61 | """Tokenizes all the characters in the input line. 
62 | 63 | :param line: a segment to tokenize 64 | :return: the tokenized line 65 | """ 66 | return " ".join(self.sp.EncodeAsPieces(line)) 67 | 68 | 69 | class Flores200Tokenizer(TokenizerSPM): 70 | def __init__(self): 71 | super().__init__("flores200") 72 | 73 | class Flores101Tokenizer(TokenizerSPM): 74 | def __init__(self): 75 | super().__init__("flores101") 76 | 77 | ### Added for spBLEU-1K tokenizer by AbdelRahim Elmadany 78 | class spBLEU1KTokenizer(TokenizerSPM): 79 | def __init__(self): 80 | super().__init__("spBLEU-1K") -------------------------------------------------------------------------------- /test/test_sentence_bleu.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import sacrebleu 3 | 4 | EPSILON = 1e-3 5 | 6 | 7 | # Example taken from #98 8 | REF = "producţia de zahăr brut se exprimă în zahăr alb;" 9 | SYS = "Producția de zahăr primă va fi exprimată în ceea ce privește zahărul alb;" 10 | 11 | test_cases = [ 12 | # change smoothing 13 | ('exp', None, False, '13a', 8.493), 14 | ('none', None, False, '13a', 0.0), 15 | ('floor', None, False, '13a', 4.51688), # defaults to 0.1 16 | ('floor', 0.1, False, '13a', 4.51688), 17 | ('floor', 0.5, False, '13a', 10.10), 18 | ('add-k', None, False, '13a', 14.882), # defaults to 1 19 | ('add-k', 1, False, '13a', 14.882), 20 | ('add-k', 2, False, '13a', 21.389), 21 | # change tok 22 | ('exp', None, False, 'none', 7.347), 23 | ('exp', None, False, 'intl', 8.493), 24 | ('exp', None, False, 'char', 40.8759), 25 | # change case 26 | ('exp', None, True, 'char', 42.0267), 27 | ] 28 | 29 | 30 | # Example taken from #141 31 | REF_0 = "okay thanks" 32 | SYS_0 = "this is a cat" 33 | 34 | test_cases_zero_bleu = [ 35 | ('exp', None, False, '13a', 0.0), 36 | ('none', None, False, '13a', 0.0), 37 | ('floor', None, False, '13a', 0.0), # defaults to 0.1 38 | ('floor', 0.1, False, '13a', 0.0), 39 | ('add-k', None, False, '13a', 0.0), # defaults to 1 40 | ('add-k', 1, False, '13a', 0.0), 41 | ] 42 | 43 | 44 | @pytest.mark.parametrize("smooth_method, smooth_value, lowercase, tok, expected_score", test_cases) 45 | def test_compat_sentence_bleu(smooth_method, smooth_value, lowercase, tok, expected_score): 46 | score = sacrebleu.compat.sentence_bleu( 47 | SYS, [REF], smooth_method=smooth_method, smooth_value=smooth_value, 48 | tokenize=tok, 49 | lowercase=lowercase, 50 | use_effective_order=True) 51 | assert abs(score.score - expected_score) < EPSILON 52 | 53 | 54 | @pytest.mark.parametrize("smooth_method, smooth_value, lowercase, tok, expected_score", test_cases) 55 | def test_api_sentence_bleu(smooth_method, smooth_value, lowercase, tok, expected_score): 56 | metric = sacrebleu.metrics.BLEU( 57 | lowercase=lowercase, force=False, tokenize=tok, 58 | smooth_method=smooth_method, smooth_value=smooth_value, 59 | effective_order=True) 60 | score = metric.sentence_score(SYS, [REF]) 61 | 62 | assert abs(score.score - expected_score) < EPSILON 63 | 64 | 65 | @pytest.mark.parametrize("smooth_method, smooth_value, lowercase, tok, expected_score", test_cases_zero_bleu) 66 | def test_api_sentence_bleu_zero(smooth_method, smooth_value, lowercase, tok, expected_score): 67 | metric = sacrebleu.metrics.BLEU( 68 | lowercase=lowercase, force=False, tokenize=tok, 69 | smooth_method=smooth_method, smooth_value=smooth_value, 70 | effective_order=True) 71 | score = metric.sentence_score(SYS_0, [REF_0]) 72 | assert abs(score.score - expected_score) < EPSILON 73 | 
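# A minimal identity-check sketch using the same public API the cases above
# exercise (sacrebleu.metrics.BLEU / sentence_score); it assumes only the
# default 13a tokenizer and smoothing settings shown above. A hypothesis that
# is identical to its single reference should reach a BLEU score of 100.
def test_api_sentence_bleu_identity():
    metric = sacrebleu.metrics.BLEU(effective_order=True)
    score = metric.sentence_score(REF, [REF])
    assert abs(score.score - 100.0) < EPSILON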
-------------------------------------------------------------------------------- /scripts/perf_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | import time 4 | import statistics 5 | 6 | sys.path.insert(0, '.') 7 | 8 | import sacrebleu # noqa: E402 9 | from sacrebleu.metrics import BLEU, CHRF # noqa: E402 10 | 11 | 12 | N_REPEATS = 5 13 | 14 | 15 | sys_files = [ 16 | 'data/wmt17-submitted-data/txt/system-outputs/newstest2017/cs-en/newstest2017.PJATK.4760.cs-en', 17 | 'data/wmt17-submitted-data/txt/system-outputs/newstest2017/cs-en/newstest2017.uedin-nmt.4955.cs-en', 18 | 'data/wmt17-submitted-data/txt/system-outputs/newstest2017/cs-en/newstest2017.online-A.0.cs-en', 19 | 'data/wmt17-submitted-data/txt/system-outputs/newstest2017/cs-en/newstest2017.online-B.0.cs-en', 20 | ] 21 | 22 | ref_files = ['data/wmt17-submitted-data/txt/references/newstest2017-csen-ref.en'] 23 | 24 | metrics = [ 25 | # BLEU 26 | (BLEU, {}), 27 | (BLEU, {'tokenize': 'intl'}), 28 | (BLEU, {'tokenize': 'none', 'force': True}), 29 | # CHRF 30 | (CHRF, {}), 31 | (CHRF, {'whitespace': True}), 32 | # CHRF++ 33 | (CHRF, {'word_order': 2}), 34 | ] 35 | 36 | 37 | def create_metric(klass, kwargs, refs=None): 38 | if refs: 39 | # caching mode 40 | kwargs['references'] = refs 41 | return klass(**kwargs) 42 | 43 | 44 | def read_files(*args): 45 | lines = [] 46 | for fname in args: 47 | cur_lines = [] 48 | with open(fname) as f: 49 | for line in f: 50 | cur_lines.append(line.strip()) 51 | lines.append(cur_lines) 52 | return lines 53 | 54 | 55 | def measure(metric_klass, metric_kwargs, systems, refs, cache=False): 56 | scores = [] 57 | durations = [] 58 | 59 | if cache: 60 | # caching mode 61 | metric_kwargs['references'] = refs 62 | st = time.time() 63 | metric = metric_klass(**metric_kwargs) 64 | 65 | for system in systems: 66 | sc = metric.corpus_score(system, None if cache else refs).score 67 | dur = time.time() - st 68 | print(f'{dur:.3f}', end=' ') 69 | durations.append(dur) 70 | scores.append(sc) 71 | st = time.time() 72 | 73 | durations = sorted(durations) 74 | median = durations[len(durations) // 2] 75 | std = statistics.pstdev(durations) 76 | mean = sum(durations) / len(durations) 77 | print(f' || mean: {mean:.3f} -- median: {median:.3f} -- stdev: {std:.3f}') 78 | 79 | 80 | if __name__ == '__main__': 81 | 82 | systems = read_files(*sys_files) 83 | refs = read_files(*ref_files) 84 | 85 | msg = f'SacreBLEU {sacrebleu.__version__} performance tests' 86 | print('-' * len(msg) + '\n' + msg + '\n' + '-' * len(msg)) 87 | 88 | for klass, kwargs in metrics: 89 | print(klass.__name__, kwargs) 90 | 91 | print(' > [no-cache] ', end='') 92 | measure(klass, kwargs, systems, refs, cache=False) 93 | 94 | print(' > [cached] ', end='') 95 | measure(klass, kwargs, systems, refs, cache=True) 96 | -------------------------------------------------------------------------------- /test/test_api.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You may not 4 | # use this file except in compliance with the License. A copy of the License 5 | # is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. 
This file is distributed on 10 | # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 13 | 14 | import pytest 15 | 16 | from sacrebleu.utils import get_available_testsets, get_available_testsets_for_langpair, get_langpairs_for_testset 17 | from sacrebleu.utils import get_source_file, get_reference_files 18 | from sacrebleu.dataset import DATASETS 19 | 20 | test_api_get_data = [ 21 | ("wmt19", "de-en", 1, "Schöne Münchnerin 2018: Schöne Münchnerin 2018 in Hvar: Neun Dates", "The Beauty of Munich 2018: the Beauty of Munich 2018 in Hvar: Nine dates"), 22 | ("mtnt1.1/train", "ja-en", 10, "0歳から100歳の女性が登場する海外のスキンケアCM", "The overseas skin care commercial in which 0 to 100 year old females appear."), 23 | ("wmt19/google/ar", "en-de", 1, "Welsh AMs worried about 'looking like muppets'", "Walisische Abgeordnete befürchten als ,Idioten’ dazustehen."), 24 | ] 25 | 26 | 27 | @pytest.mark.parametrize("testset, langpair, sentno, source, reference", test_api_get_data) 28 | def test_api_get_source(testset, langpair, sentno, source, reference): 29 | with open(get_source_file(testset, langpair)) as fh: 30 | line = fh.readlines()[sentno - 1].strip() 31 | assert line == source 32 | 33 | 34 | @pytest.mark.parametrize("testset, langpair, sentno, source, reference", test_api_get_data) 35 | def test_api_get_reference(testset, langpair, sentno, source, reference): 36 | with open(get_reference_files(testset, langpair)[0]) as fh: 37 | line = fh.readlines()[sentno - 1].strip() 38 | assert line == reference 39 | 40 | 41 | def test_api_get_available_testsets(): 42 | """ 43 | Loop over the datasets directly, and ensure the API function returns 44 | the test sets found. 45 | """ 46 | available = get_available_testsets() 47 | assert isinstance(available, list) 48 | assert "wmt19" in available 49 | assert "wmt05" not in available 50 | 51 | for testset in DATASETS.keys(): 52 | assert testset in available 53 | assert "slashdot_" + testset not in available 54 | 55 | 56 | def test_api_get_available_testsets_for_langpair(): 57 | """ 58 | Loop over the datasets directly, and ensure the API function returns 59 | the test sets found. 60 | """ 61 | available = get_available_testsets_for_langpair('en-it') 62 | assert isinstance(available, list) 63 | assert "wmt09" in available 64 | assert "wmt15" not in available 65 | 66 | available = get_available_testsets_for_langpair('en-fr') 67 | assert isinstance(available, list) 68 | assert "wmt11" in available 69 | assert "mtedx/test" in available 70 | assert "wmt20" not in available 71 | 72 | 73 | def test_api_get_langpairs_for_testset(): 74 | """ 75 | Loop over the datasets directly, and ensure the API function 76 | returns each language pair in each test set. 
77 | """ 78 | for testset in DATASETS.keys(): 79 | available = get_langpairs_for_testset(testset) 80 | assert isinstance(available, list) 81 | for langpair in DATASETS[testset].langpairs.keys(): 82 | # skip non-language keys 83 | if "-" not in langpair: 84 | assert langpair not in available 85 | else: 86 | assert langpair in available 87 | assert "slashdot_" + langpair not in available 88 | -------------------------------------------------------------------------------- /test/test_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import random 4 | 5 | import sacrebleu.dataset as dataset 6 | from sacrebleu.utils import smart_open 7 | 8 | 9 | def test_maybe_download(): 10 | """ 11 | Test the maybe_download function in Dataset class. 12 | 13 | Check a few random datasets for downloading and correct file placement. 14 | """ 15 | # ensure all file have been downloaded 16 | selected_datasets = random.choices(list(dataset.DATASETS.values()), k=10) 17 | for ds in selected_datasets: 18 | shutil.rmtree(ds._rawdir, ignore_errors=True) 19 | ds.maybe_download() 20 | 21 | all_files = os.listdir(ds._rawdir) 22 | for url in ds.data: 23 | filename = ds._get_tarball_filename(url) 24 | assert filename in all_files 25 | filepath = os.path.join(ds._rawdir, filename) 26 | assert os.path.getsize(filepath) > 0 27 | 28 | 29 | def test_process_to_text(): 30 | """ 31 | Test the function `process_to_text` in Dataset class. 32 | 33 | Ensure each field of specified language pair have the same length. 34 | """ 35 | selected_datasets = random.choices(list(dataset.DATASETS.values()), k=10) 36 | for ds in selected_datasets: 37 | if os.path.exists(ds._outdir): 38 | for filename in os.listdir(ds._outdir): 39 | filepath = os.path.join(ds._outdir, filename) 40 | if os.path.isfile(filepath): 41 | os.remove(filepath) 42 | 43 | ds.process_to_text() 44 | 45 | for pair in ds.langpairs: 46 | all_files = ds.get_files(pair) 47 | 48 | # count the number of lines in each file 49 | num_lines = [sum(1 for _ in smart_open(f)) for f in all_files] 50 | 51 | # ensure no empty file 52 | assert num_lines[0] > 0 53 | 54 | # assert each file has the same length 55 | assert all(x == num_lines[0] for x in num_lines) 56 | 57 | 58 | def test_get_files_and_fieldnames(): 59 | """ 60 | Test the functions `get_files` and `fieldnames` in Dataset class. 61 | 62 | Ensure the length of the returned list is correct. 63 | `get_files()` should return the same number of items as `fieldnames()`. 64 | """ 65 | for ds in dataset.DATASETS.values(): 66 | for pair in ds.langpairs: 67 | assert len(ds.get_files(pair)) == len(ds.fieldnames(pair)) 68 | 69 | 70 | def test_source_and_references(): 71 | """ 72 | Test the functions `source` and `references` in Dataset class. 73 | 74 | Ensure the length of source and references are equal. 75 | """ 76 | for ds in dataset.DATASETS.values(): 77 | for pair in ds.langpairs: 78 | src_len = len(list(ds.source(pair))) 79 | ref_len = len(list(ds.references(pair))) 80 | assert src_len == ref_len, f"source/reference failure for {ds.name}:{pair} len(source)={src_len} len(references)={ref_len}" 81 | 82 | 83 | def test_wmt22_references(): 84 | """ 85 | WMT21 added the ability to specify which reference to use (among many in the XML). 86 | The default was "A" for everything. 
87 | WMT22 added the ability to override this default on a per-langpair basis, by 88 | replacing the langpair list of paths with a dict that had the list of paths and 89 | the annotator override. 90 | """ 91 | wmt22 = dataset.DATASETS["wmt22"] 92 | 93 | # make sure CS-EN returns all reference fields 94 | cs_en_fields = wmt22.fieldnames("cs-en") 95 | for ref in ["ref:B", "ref:C"]: 96 | assert ref in cs_en_fields 97 | assert "ref:A" not in cs_en_fields 98 | 99 | # make sure ref:B is the one used by default 100 | assert wmt22._get_langpair_allowed_refs("cs-en") == ["ref:B"] 101 | 102 | # similar check for another dataset: there should be no default ("A"), 103 | # and the only ref found should be the unannotated one 104 | assert "ref:A" not in wmt22.fieldnames("liv-en") 105 | assert "ref" in wmt22.fieldnames("liv-en") 106 | 107 | # and that ref:A is the default for all languages where it wasn't overridden 108 | for langpair, langpair_data in wmt22.langpairs.items(): 109 | if isinstance(langpair_data, dict): 110 | assert wmt22._get_langpair_allowed_refs(langpair) != ["ref:A"] 111 | else: 112 | assert wmt22._get_langpair_allowed_refs(langpair) == ["ref:A"] 113 | 114 | 115 | -------------------------------------------------------------------------------- /test/test_significance.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from collections import defaultdict 4 | from typing import DefaultDict 5 | 6 | from sacrebleu.metrics import BLEU 7 | from sacrebleu.significance import PairedTest, Result 8 | 9 | import pytest 10 | 11 | 12 | def _read_pickle_file(): 13 | import bz2 14 | import pickle as pkl 15 | with bz2.BZ2File('./test/wmt17_en_de_systems.pkl.bz2', 'rb') as f: 16 | data = pkl.load(f) 17 | return data 18 | 19 | 20 | # P-values obtained from Moses' significance script (mean of 3 runs) 21 | # Script: scripts/moses-sigdiff.pl (modified to bootstrap samples = 2000) 22 | MOSES_P_VALS = { 23 | "newstest2017.C-3MA.4959.en-de": 0.00000, 24 | "newstest2017.FBK.4870.en-de": 0.01267, 25 | "newstest2017.KIT.4950.en-de": 0.02233, 26 | "newstest2017.LMU-nmt-reranked.4934.en-de": 0.04383, 27 | "newstest2017.LMU-nmt-single.4893.en-de": 0.20783, 28 | "newstest2017.online-A.0.en-de": 0.00000, 29 | "newstest2017.online-B.0.en-de": 0.38100, 30 | "newstest2017.online-F.0.en-de": 0.00000, 31 | "newstest2017.online-G.0.en-de": 0.00000, 32 | "newstest2017.PROMT-Rule-based.4735.en-de": 0.00000, 33 | "newstest2017.RWTH-nmt-ensemble.4921.en-de": 0.01167, 34 | "newstest2017.SYSTRAN.4847.en-de": 0.20983, 35 | "newstest2017.TALP-UPC.4834.en-de": 0.00000, 36 | "newstest2017.uedin-nmt.4722.en-de": 0.00000, 37 | "newstest2017.xmu.4910.en-de": 0.25483, 38 | } 39 | 40 | # Obtained from the multeval toolkit, 10,000 AR trials, (BLEU and TER) 41 | # Code: github.com/mjclark/multeval.git 42 | MULTEVAL_P_VALS = { 43 | "newstest2017.C-3MA.4959.en-de": (0.0001, 0.0001), 44 | "newstest2017.FBK.4870.en-de": (0.0218, 0.09569), 45 | "newstest2017.KIT.4950.en-de": (0.0410, 0.0002), 46 | "newstest2017.LMU-nmt-reranked.4934.en-de": (0.09029, 0.0001), 47 | "newstest2017.LMU-nmt-single.4893.en-de": (0.58494, 0.0054), 48 | "newstest2017.online-A.0.en-de": (0.0001, 0.0001), 49 | "newstest2017.online-B.0.en-de": (0.94111, 0.82242), 50 | "newstest2017.online-F.0.en-de": (0.0001, 0.0001), 51 | "newstest2017.online-G.0.en-de": (0.0001, 0.0001), 52 | "newstest2017.PROMT-Rule-based.4735.en-de": (0.0001, 0.0001), 53 | "newstest2017.RWTH-nmt-ensemble.4921.en-de": (0.0207, 0.07539), 54 | 
"newstest2017.SYSTRAN.4847.en-de": (0.59914, 0.0001), 55 | "newstest2017.TALP-UPC.4834.en-de": (0.0001, 0.0001), 56 | "newstest2017.uedin-nmt.4722.en-de": (0.0001, 0.0001), 57 | "newstest2017.xmu.4910.en-de": (0.71073, 0.0001), 58 | } 59 | 60 | 61 | SACREBLEU_BS_P_VALS: DefaultDict[str, float] = defaultdict(float) 62 | SACREBLEU_AR_P_VALS: DefaultDict[str, float] = defaultdict(float) 63 | 64 | # Load data from pickled file to not bother with WMT17 downloading 65 | named_systems = _read_pickle_file() 66 | _, refs = named_systems.pop() 67 | metrics = {'BLEU': BLEU(references=refs, tokenize='none')} 68 | 69 | 70 | ######### 71 | # BS test 72 | ######### 73 | os.environ['SACREBLEU_SEED'] = str(12345) 74 | bs_scores = PairedTest( 75 | named_systems, metrics, references=None, 76 | test_type='bs', n_samples=2000)()[1] 77 | 78 | for name, result in zip(bs_scores['System'], bs_scores['BLEU']): 79 | assert isinstance(result, Result) 80 | if result.p_value is not None: 81 | assert isinstance(name, str) 82 | SACREBLEU_BS_P_VALS[name] += result.p_value 83 | 84 | 85 | ############################################### 86 | # AR test (1 run) 87 | # Test only BLEU as TER will take too much time 88 | ############################################### 89 | ar_scores = PairedTest(named_systems, metrics, references=None, 90 | test_type='ar', n_samples=10000)()[1] 91 | 92 | for name, result in zip(ar_scores['System'], ar_scores['BLEU']): 93 | assert isinstance(result, Result) 94 | if result.p_value is not None: 95 | assert isinstance(name, str) 96 | SACREBLEU_AR_P_VALS[name] += result.p_value 97 | 98 | 99 | @pytest.mark.parametrize("name, expected_p_val", MOSES_P_VALS.items()) 100 | def test_paired_bootstrap(name, expected_p_val): 101 | p_val = SACREBLEU_BS_P_VALS[name] 102 | assert abs(p_val - expected_p_val) < 1e-2 103 | 104 | 105 | @pytest.mark.parametrize("name, expected_p_vals", MULTEVAL_P_VALS.items()) 106 | def test_paired_approximate_randomization(name, expected_p_vals): 107 | expected_bleu_p_val = expected_p_vals[0] 108 | p_val = SACREBLEU_AR_P_VALS[name] 109 | assert abs(p_val - expected_bleu_p_val) < 1e-2 110 | -------------------------------------------------------------------------------- /sacrebleu/dataset/fake_sgml.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | from ..utils import smart_open 5 | from .base import Dataset 6 | 7 | 8 | class FakeSGMLDataset(Dataset): 9 | """ 10 | The fake SGML format used by WMT prior to 2021. Can't be properly parsed. 11 | Source and reference(s) in separate files. 12 | """ 13 | 14 | def _convert_format(self, input_file_path, output_filep_path): 15 | """ 16 | Extract data from raw file and convert to raw txt format. 17 | """ 18 | with smart_open(input_file_path) as fin, smart_open( 19 | output_filep_path, "wt" 20 | ) as fout: 21 | for line in fin: 22 | if line.startswith("(.*).*?", "\\1", line)) 24 | print(line, file=fout) 25 | 26 | def _convert_meta(self, input_file_path, field, output_filep_path): 27 | """ 28 | Extract metadata from document tags, projects across segments. 29 | """ 30 | with smart_open(input_file_path) as fin, smart_open( 31 | output_filep_path, "wt" 32 | ) as fout: 33 | value = "" 34 | for line in fin: 35 | if line.startswith("= 2 88 | ), f"Each language pair in {self.name} must have at least 2 fields." 
89 | 90 | fields = ["src"] 91 | 92 | if length == 2: 93 | fields.append("ref") 94 | else: 95 | for i, _ in enumerate(meta[langpair][1:]): 96 | fields.append(f"ref:{i}") 97 | 98 | if not self.name.startswith("wmt08"): 99 | fields += ["docid", "genre", "origlang"] 100 | 101 | return fields 102 | 103 | 104 | class WMTAdditionDataset(FakeSGMLDataset): 105 | """ 106 | Handle special case of WMT Google addition dataset. 107 | """ 108 | 109 | def _convert_format(self, input_file_path, output_filep_path): 110 | if input_file_path.endswith(".sgm"): 111 | return super()._convert_format(input_file_path, output_filep_path) 112 | else: 113 | with smart_open(input_file_path) as fin: 114 | with smart_open(output_filep_path, "wt") as fout: 115 | for line in fin: 116 | print(line.rstrip(), file=fout) 117 | -------------------------------------------------------------------------------- /test/test_chrf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You may not 4 | # use this file except in compliance with the License. A copy of the License 5 | # is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed on 10 | # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 13 | 14 | import pytest 15 | import sacrebleu 16 | 17 | EPSILON = 1e-4 18 | 19 | test_sentence_level_chrf = [ 20 | ( 21 | 'Co nás nejvíc trápí, protože lékaři si vybírají, kdo bude žít a kdo zemře.', 22 | ['Nejvíce smutní jsme z toho, že musíme rozhodovat o tom, kdo bude žít a kdo zemře.'], 23 | 39.14078509, 24 | ), 25 | ( 26 | 'Nebo prostě nemají vybavení, které by jim pomohlo, uvedli lékaři.', 27 | ['A někdy nemáme ani potřebný materiál, abychom jim pomohli, popsali lékaři.'], 28 | 31.22557079, 29 | ), 30 | ( 31 | 'Lapali po dechu, jejich životy skončily dřív, než skutečně začaly.', 32 | ['Lapali po dechu a pak jejich život skončil - dřív, než skutečně mohl začít, připomněli.'], 33 | 57.15704367, 34 | ), 35 | ] 36 | 37 | 38 | # hypothesis, reference, expected score 39 | # >= 2.0.0: some orders are not fulfilled in epsilon smoothing (chrF++.py and NLTK) 40 | test_cases = [ 41 | (["abcdefg"], ["hijklmnop"], 0.0), 42 | (["a"], ["b"], 0.0), 43 | ([""], ["b"], 0.0), 44 | ([""], ["ref"], 0.0), 45 | ([""], ["reference"], 0.0), 46 | (["aa"], ["ab"], 8.3333), 47 | (["a", "b"], ["a", "c"], 8.3333), 48 | (["a"], ["a"], 16.6667), 49 | (["a b c"], ["a b c"], 50.0), 50 | (["a b c"], ["abc"], 50.0), 51 | ([" risk assessment must be made of those who are qualified and expertise in the sector - these are the scientists ."], 52 | ["risk assessment has to be undertaken by those who are qualified and expert in that area - that is the scientists ."], 63.361730), 53 | ([" Die Beziehung zwischen Obama und Netanjahu ist nicht gerade freundlich. 
"], 54 | ["Das Verhältnis zwischen Obama und Netanyahu ist nicht gerade freundschaftlich."], 64.1302698), 55 | (["Niemand hat die Absicht, eine Mauer zu errichten"], ["Niemand hat die Absicht, eine Mauer zu errichten"], 100.0), 56 | ] 57 | 58 | # sacreBLEU < 2.0.0 mode 59 | # hypothesis, reference, expected score 60 | test_cases_effective_order = [ 61 | (["a"], ["a"], 100.0), 62 | ([""], ["reference"], 0.0), 63 | (["a b c"], ["a b c"], 100.0), 64 | (["a b c"], ["abc"], 100.0), 65 | ([""], ["c"], 0.0), 66 | (["a", "b"], ["a", "c"], 50.0), 67 | (["aa"], ["ab"], 25.0), 68 | ] 69 | 70 | test_cases_keep_whitespace = [ 71 | ( 72 | ["Die Beziehung zwischen Obama und Netanjahu ist nicht gerade freundlich."], 73 | ["Das Verhältnis zwischen Obama und Netanyahu ist nicht gerade freundschaftlich."], 74 | 67.3481606, 75 | ), 76 | ( 77 | ["risk assessment must be made of those who are qualified and expertise in the sector - these are the scientists ."], 78 | ["risk assessment has to be undertaken by those who are qualified and expert in that area - that is the scientists ."], 79 | 65.2414427, 80 | ), 81 | ] 82 | 83 | 84 | @pytest.mark.parametrize("hypotheses, references, expected_score", test_cases) 85 | def test_chrf(hypotheses, references, expected_score): 86 | score = sacrebleu.corpus_chrf( 87 | hypotheses, [references], char_order=6, word_order=0, beta=3, 88 | eps_smoothing=True).score 89 | assert abs(score - expected_score) < EPSILON 90 | 91 | 92 | @pytest.mark.parametrize("hypotheses, references, expected_score", test_cases_effective_order) 93 | def test_chrf_eff_order(hypotheses, references, expected_score): 94 | score = sacrebleu.corpus_chrf( 95 | hypotheses, [references], char_order=6, word_order=0, beta=3, 96 | eps_smoothing=False).score 97 | assert abs(score - expected_score) < EPSILON 98 | 99 | 100 | @pytest.mark.parametrize("hypotheses, references, expected_score", test_cases_keep_whitespace) 101 | def test_chrf_keep_whitespace(hypotheses, references, expected_score): 102 | score = sacrebleu.corpus_chrf( 103 | hypotheses, [references], char_order=6, word_order=0, beta=3, 104 | remove_whitespace=False).score 105 | assert abs(score - expected_score) < EPSILON 106 | 107 | 108 | @pytest.mark.parametrize("hypothesis, references, expected_score", test_sentence_level_chrf) 109 | def test_chrf_sentence_level(hypothesis, references, expected_score): 110 | score = sacrebleu.sentence_chrf(hypothesis, references, eps_smoothing=True).score 111 | assert abs(score - expected_score) < EPSILON 112 | -------------------------------------------------------------------------------- /sacrebleu/tokenizers/tokenizer_zh.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017--2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You may not 4 | # use this file except in compliance with the License. A copy of the License 5 | # is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed on 10 | # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 
13 | 14 | ############## 15 | 16 | # MIT License 17 | # Copyright (c) 2017 - Shujian Huang 18 | 19 | # Permission is hereby granted, free of charge, to any person obtaining a copy 20 | # of this software and associated documentation files (the "Software"), to deal 21 | # in the Software without restriction, including without limitation the rights 22 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 23 | # copies of the Software, and to permit persons to whom the Software is 24 | # furnished to do so, subject to the following conditions: 25 | 26 | # The above copyright notice and this permission notice shall be included in 27 | # all copies or substantial portions of the Software. 28 | 29 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 30 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 31 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 32 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 33 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 34 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 35 | # SOFTWARE. 36 | 37 | # Author: Shujian Huang huangsj@nju.edu.cn 38 | 39 | 40 | from functools import lru_cache 41 | 42 | from .tokenizer_base import BaseTokenizer 43 | from .tokenizer_re import TokenizerRegexp 44 | 45 | _UCODE_RANGES = [ 46 | (u'\u3400', u'\u4db5'), # CJK Unified Ideographs Extension A, release 3.0 47 | (u'\u4e00', u'\u9fa5'), # CJK Unified Ideographs, release 1.1 48 | (u'\u9fa6', u'\u9fbb'), # CJK Unified Ideographs, release 4.1 49 | (u'\uf900', u'\ufa2d'), # CJK Compatibility Ideographs, release 1.1 50 | (u'\ufa30', u'\ufa6a'), # CJK Compatibility Ideographs, release 3.2 51 | (u'\ufa70', u'\ufad9'), # CJK Compatibility Ideographs, release 4.1 52 | (u'\u20000', u'\u2a6d6'), # (UTF16) CJK Unified Ideographs Extension B, release 3.1 53 | (u'\u2f800', u'\u2fa1d'), # (UTF16) CJK Compatibility Supplement, release 3.1 54 | (u'\uff00', u'\uffef'), # Full width ASCII, full width of English punctuation, 55 | # half width Katakana, half wide half width kana, Korean alphabet 56 | (u'\u2e80', u'\u2eff'), # CJK Radicals Supplement 57 | (u'\u3000', u'\u303f'), # CJK punctuation mark 58 | (u'\u31c0', u'\u31ef'), # CJK stroke 59 | (u'\u2f00', u'\u2fdf'), # Kangxi Radicals 60 | (u'\u2ff0', u'\u2fff'), # Chinese character structure 61 | (u'\u3100', u'\u312f'), # Phonetic symbols 62 | (u'\u31a0', u'\u31bf'), # Phonetic symbols (Taiwanese and Hakka expansion) 63 | (u'\ufe10', u'\ufe1f'), 64 | (u'\ufe30', u'\ufe4f'), 65 | (u'\u2600', u'\u26ff'), 66 | (u'\u2700', u'\u27bf'), 67 | (u'\u3200', u'\u32ff'), 68 | (u'\u3300', u'\u33ff'), 69 | ] 70 | 71 | 72 | class TokenizerZh(BaseTokenizer): 73 | 74 | def signature(self): 75 | return 'zh' 76 | 77 | def __init__(self): 78 | self._post_tokenizer = TokenizerRegexp() 79 | 80 | @staticmethod 81 | @lru_cache(maxsize=2**16) 82 | def _is_chinese_char(uchar): 83 | """ 84 | :param uchar: input char in unicode 85 | :return: whether the input char is a Chinese character. 86 | """ 87 | for start, end in _UCODE_RANGES: 88 | if start <= uchar <= end: 89 | return True 90 | return False 91 | 92 | @lru_cache(maxsize=2**16) 93 | def __call__(self, line): 94 | """The tokenization of Chinese text in this script contains two 95 | steps: separate each Chinese characters (by utf-8 encoding); tokenize 96 | the non Chinese part (following the `13a` i.e. mteval tokenizer). 
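        A rough example of the first step (illustrative): "这是个测试 abc."
        is padded to "这  是  个  测  试  abc." before the regex tokenizer
        normalizes spacing and punctuation in the remaining text.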
97 | 98 | Author: Shujian Huang huangsj@nju.edu.cn 99 | 100 | :param line: input sentence 101 | :return: tokenized sentence 102 | """ 103 | 104 | line = line.strip() 105 | line_in_chars = "" 106 | 107 | # TODO: the below code could probably be replaced with the following: 108 | # @ozan: Gives slightly different scores, need to investigate 109 | # import regex 110 | # line = regex.sub(r'(\p{Han})', r' \1 ', line) 111 | for char in line: 112 | if self._is_chinese_char(char): 113 | line_in_chars += " " 114 | line_in_chars += char 115 | line_in_chars += " " 116 | else: 117 | line_in_chars += char 118 | 119 | return self._post_tokenizer(line_in_chars) 120 | -------------------------------------------------------------------------------- /test/test_bleu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You may not 4 | # use this file except in compliance with the License. A copy of the License 5 | # is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed on 10 | # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 13 | 14 | from collections import namedtuple 15 | import pytest 16 | 17 | import sacrebleu 18 | 19 | from sacrebleu.metrics import BLEU 20 | 21 | 22 | EPSILON = 1e-8 23 | 24 | Statistics = namedtuple('Statistics', ['common', 'total']) 25 | 26 | test_raw_bleu_cases = [ 27 | # This now returns 0.0 score (#141) 28 | (["this is a test", "another test"], [["ref1", "ref2"]], 0.0), 29 | (["this is a test"], [["this is a test"]], 1.0), 30 | (["this is a fest"], [["this is a test"]], 0.223606797749979)] 31 | 32 | # test for README example with empty hypothesis strings check 33 | _refs = [ 34 | ['The dog bit the man.', 'It was not unexpected.', 'The man bit him first.'], 35 | ['The dog had bit the man.', 'No one was surprised.', 'The man had bitten the dog.'], 36 | ] 37 | 38 | _hyps = [ 39 | 'The dog bit the man.', 40 | "It wasn't surprising.", 41 | 'The man had just bitten him.', 42 | ] 43 | 44 | test_corpus_bleu_cases = [ 45 | (_hyps, _refs, {}, 48.530827), # test for default BLEU settings 46 | (('', '', ''), _refs, {}, 0.0), # ensure that empty hypotheses are not removed 47 | (_hyps, _refs, {'tokenize': 'none'}, 49.1919566), 48 | (_hyps, _refs, {'tokenize': '13a'}, 48.530827), 49 | (_hyps, _refs, {'tokenize': 'intl'}, 43.91623493), 50 | (_hyps, _refs, {'smooth_method': 'none'}, 48.530827), 51 | ] 52 | 53 | test_case_offset = [(["am I am a character sequence"], [["I am a symbol string sequence a a"]], 0.1555722182, 0)] 54 | 55 | # statistic structure: 56 | # - common counts 57 | # - total counts 58 | # - hyp_count 59 | # - ref_count 60 | 61 | test_case_statistics = [(["am I am a character sequence"], [["I am a symbol string sequence a a"]], 62 | Statistics([4, 2, 1, 0], [6, 5, 4, 3]))] 63 | 64 | test_case_scoring = [((Statistics([9, 7, 5, 3], [10, 8, 6, 4]), 11, 11), 0.8375922397)] 65 | 66 | test_case_effective_order = [(["test"], [["a test"]], 0.3678794411714425), 67 | (["a test"], [["a test"]], 1.0), 68 | (["a little test"], [["a test"]], 0.03218297948685433)] 69 | 70 | 71 | # testing that right score is returned for null statistics and different offsets 72 | # format: 
stat, offset, expected score 73 | test_case_degenerate_stats = [((Statistics([0, 0, 0, 0], [4, 4, 2, 1]), 0, 1), 0.0, 0.0), 74 | ((Statistics([0, 0, 0, 0], [10, 11, 12, 0]), 14, 10), 0.0, 0.0), 75 | ((Statistics([0, 0, 0, 0], [0, 0, 0, 0]), 0, 0), 0.0, 0.0), 76 | ((Statistics([6, 5, 4, 0], [6, 5, 4, 3]), 6, 6), 0.0, 0.0), 77 | ((Statistics([0, 0, 0, 0], [0, 0, 0, 0]), 0, 0), 0.1, 0.0), 78 | ((Statistics([0, 0, 0, 0], [0, 0, 0, 0]), 1, 5), 0.01, 0.0)] 79 | 80 | 81 | @pytest.mark.parametrize("hypotheses, references, expected_bleu", test_raw_bleu_cases) 82 | def test_raw_bleu(hypotheses, references, expected_bleu): 83 | bleu = sacrebleu.raw_corpus_bleu(hypotheses, references, .01).score / 100 84 | assert abs(bleu - expected_bleu) < EPSILON 85 | 86 | 87 | @pytest.mark.parametrize("hypotheses, references, kwargs, expected_bleu", test_corpus_bleu_cases) 88 | def test_corpus_bleu(hypotheses, references, kwargs, expected_bleu): 89 | bleu = sacrebleu.corpus_bleu(hypotheses, references, **kwargs).score 90 | assert abs(bleu - expected_bleu) < EPSILON 91 | 92 | 93 | @pytest.mark.parametrize("hypotheses, references, expected_bleu", test_case_effective_order) 94 | def test_effective_order(hypotheses, references, expected_bleu): 95 | bleu = sacrebleu.raw_corpus_bleu(hypotheses, references, .01).score / 100 96 | assert abs(bleu - expected_bleu) < EPSILON 97 | 98 | 99 | @pytest.mark.parametrize("hypothesis, reference, expected_stat", test_case_statistics) 100 | def test_statistics(hypothesis, reference, expected_stat): 101 | result = sacrebleu.raw_corpus_bleu(hypothesis, reference, .01) 102 | stat = Statistics(result.counts, result.totals) 103 | assert stat == expected_stat 104 | 105 | 106 | @pytest.mark.parametrize("statistics, expected_score", test_case_scoring) 107 | def test_scoring(statistics, expected_score): 108 | score = BLEU.compute_bleu(statistics[0].common, statistics[0].total, statistics[1], statistics[2]).score / 100 109 | assert abs(score - expected_score) < EPSILON 110 | 111 | 112 | @pytest.mark.parametrize("hypothesis, reference, expected_with_offset, expected_without_offset", 113 | test_case_offset) 114 | def test_offset(hypothesis, reference, expected_with_offset, expected_without_offset): 115 | score_without_offset = sacrebleu.raw_corpus_bleu(hypothesis, reference, 0.0).score / 100 116 | assert abs(expected_without_offset - score_without_offset) < EPSILON 117 | 118 | # let it use BLEU's internal default of 0.1 through passing `None` 119 | score_with_offset = sacrebleu.raw_corpus_bleu(hypothesis, reference, None).score / 100 120 | assert abs(expected_with_offset - score_with_offset) < EPSILON 121 | 122 | # let it use BLEU's internal default of 0.1 123 | score_with_offset = sacrebleu.raw_corpus_bleu(hypothesis, reference).score / 100 124 | assert abs(expected_with_offset - score_with_offset) < EPSILON 125 | 126 | 127 | @pytest.mark.parametrize("statistics, offset, expected_score", test_case_degenerate_stats) 128 | def test_degenerate_statistics(statistics, offset, expected_score): 129 | score = BLEU.compute_bleu( 130 | statistics[0].common, 131 | statistics[0].total, 132 | statistics[1], 133 | statistics[2], 134 | smooth_method='floor', smooth_value=offset).score / 100 135 | assert score == expected_score 136 | 137 | 138 | test_bleu_max_order = [ 139 | (1, _hyps, _refs, "77.65"), 140 | (2, _hyps, _refs, "60.50"), 141 | (3, _hyps, _refs, "53.93"), 142 | (4, _hyps, _refs, "48.53"), 143 | (5, _hyps, _refs, "46.14"), 144 | (6, _hyps, _refs, "43.28"), 145 | ] 146 | 147 | 148 | 
@pytest.mark.parametrize("order, hyps, refs, expected_bleu", test_bleu_max_order) 149 | def test_max_ngram_order(order, hyps, refs, expected_bleu): 150 | bleu = BLEU(max_ngram_order=order).corpus_score(hyps, refs) 151 | assert f"{bleu.score:.2f}" == expected_bleu 152 | -------------------------------------------------------------------------------- /sacrebleu/tokenizers/tokenizer_ter.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Memsource 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import re 17 | from functools import lru_cache 18 | 19 | from .tokenizer_base import BaseTokenizer 20 | 21 | 22 | def _normalize_general_and_western(sent: str) -> str: 23 | # language-independent (general) part 24 | 25 | # strip end-of-line hyphenation and join lines 26 | sent = re.sub(r"\n-", "", sent) 27 | 28 | # join lines 29 | sent = re.sub(r"\n", " ", sent) 30 | 31 | # handle XML escaped symbols 32 | sent = re.sub(r""", "\"", sent) 33 | sent = re.sub(r"&", "&", sent) 34 | sent = re.sub(r"<", "<", sent) 35 | sent = re.sub(r">", ">", sent) 36 | 37 | # language-dependent (Western) part 38 | sent = f" {sent} " 39 | 40 | # tokenize punctuation 41 | sent = re.sub(r"([{-~[-` -&(-+:-@/])", r" \1 ", sent) 42 | 43 | # handle possesives 44 | sent = re.sub(r"'s ", r" 's ", sent) 45 | sent = re.sub(r"'s$", r" 's", sent) 46 | 47 | # tokenize period and comma unless preceded by a digit 48 | sent = re.sub(r"([^0-9])([\.,])", r"\1 \2 ", sent) 49 | 50 | # tokenize period and comma unless followed by a digit 51 | sent = re.sub(r"([\.,])([^0-9])", r" \1 \2", sent) 52 | 53 | # tokenize dash when preceded by a digit 54 | sent = re.sub(r"([0-9])(-)", r"\1 \2 ", sent) 55 | 56 | return sent 57 | 58 | 59 | def _normalize_asian(sent: str) -> str: 60 | # Split Chinese chars and Japanese kanji down to character level 61 | 62 | # 4E00—9FFF CJK Unified Ideographs 63 | # 3400—4DBF CJK Unified Ideographs Extension A 64 | sent = re.sub(r"([\u4e00-\u9fff\u3400-\u4dbf])", r" \1 ", sent) 65 | 66 | # 31C0—31EF CJK Strokes 67 | # 2E80—2EFF CJK Radicals Supplement 68 | sent = re.sub(r"([\u31c0-\u31ef\u2e80-\u2eff])", r" \1 ", sent) 69 | 70 | # 3300—33FF CJK Compatibility 71 | # F900—FAFF CJK Compatibility Ideographs 72 | # FE30—FE4F CJK Compatibility Forms 73 | sent = re.sub( 74 | r"([\u3300-\u33ff\uf900-\ufaff\ufe30-\ufe4f])", r" \1 ", sent) 75 | 76 | # 3200—32FF Enclosed CJK Letters and Months 77 | sent = re.sub(r"([\u3200-\u3f22])", r" \1 ", sent) 78 | 79 | # Split Hiragana, Katakana, and KatakanaPhoneticExtensions 80 | # only when adjacent to something else 81 | # 3040—309F Hiragana 82 | # 30A0—30FF Katakana 83 | # 31F0—31FF Katakana Phonetic Extensions 84 | sent = re.sub( 85 | r"(^|^[\u3040-\u309f])([\u3040-\u309f]+)(?=$|^[\u3040-\u309f])", 86 | r"\1 \2 ", sent) 87 | sent = re.sub( 88 | r"(^|^[\u30a0-\u30ff])([\u30a0-\u30ff]+)(?=$|^[\u30a0-\u30ff])", 89 | r"\1 \2 ", sent) 90 | sent = re.sub( 91 | 
r"(^|^[\u31f0-\u31ff])([\u31f0-\u31ff]+)(?=$|^[\u31f0-\u31ff])", 92 | r"\1 \2 ", sent) 93 | 94 | sent = re.sub(TercomTokenizer.ASIAN_PUNCT, r" \1 ", sent) 95 | sent = re.sub(TercomTokenizer.FULL_WIDTH_PUNCT, r" \1 ", sent) 96 | return sent 97 | 98 | 99 | def _remove_punct(sent: str) -> str: 100 | return re.sub(r"[\.,\?:;!\"\(\)]", "", sent) 101 | 102 | 103 | def _remove_asian_punct(sent: str) -> str: 104 | sent = re.sub(TercomTokenizer.ASIAN_PUNCT, r"", sent) 105 | sent = re.sub(TercomTokenizer.FULL_WIDTH_PUNCT, r"", sent) 106 | return sent 107 | 108 | 109 | class TercomTokenizer(BaseTokenizer): 110 | """Re-implementation of Tercom Tokenizer in Python 3. 111 | 112 | See src/ter/core/Normalizer.java in https://github.com/jhclark/tercom 113 | 114 | Note that Python doesn't support named Unicode blocks so the mapping for 115 | relevant blocks was taken from here: 116 | 117 | https://unicode-table.com/en/blocks/ 118 | """ 119 | ASIAN_PUNCT = r"([\u3001\u3002\u3008-\u3011\u3014-\u301f\uff61-\uff65\u30fb])" 120 | FULL_WIDTH_PUNCT = r"([\uff0e\uff0c\uff1f\uff1a\uff1b\uff01\uff02\uff08\uff09])" 121 | 122 | def __init__(self, 123 | normalized: bool = False, 124 | no_punct: bool = False, 125 | asian_support: bool = False, 126 | case_sensitive: bool = False): 127 | """Initialize the tokenizer. 128 | 129 | :param normalized: Enable character normalization. By default, normalizes a couple of things such as 130 | newlines being stripped, retrieving XML encoded characters, and fixing tokenization for punctuation. When 131 | 'asian_support' is enabled, also normalizes specific Asian (CJK) character sequences, i.e. 132 | split them down to the character level. 133 | :param no_punct: Remove punctuation. Can be used in conjunction with 'asian_support' to also remove typical 134 | punctuation markers in Asian languages (CJK). 135 | :param asian_support: Enable special treatment of Asian characters. This option only has an effect when 136 | 'normalized' and/or 'no_punct' is enabled. If 'normalized' is also enabled, then Asian (CJK) 137 | characters are split down to the character level. If 'no_punct' is enabled alongside 'asian_support', 138 | specific unicode ranges for CJK and full-width punctuations are also removed. 139 | :param case_sensitive: Enable case sensitivity, i.e., do not lower case data. 140 | """ 141 | self._normalized = normalized 142 | self._no_punct = no_punct 143 | self._asian_support = asian_support 144 | self._case_sensitive = case_sensitive 145 | 146 | @lru_cache(maxsize=2**16) 147 | # Although the cache is shared across different instances, same sentence 148 | # queries do not return invalid returns across different instances since 149 | # `self` becomes part of the query as well. 150 | def __call__(self, sent: str) -> str: 151 | if not sent: 152 | return "" 153 | 154 | if not self._case_sensitive: 155 | sent = sent.lower() 156 | 157 | if self._normalized: 158 | sent = _normalize_general_and_western(sent) 159 | if self._asian_support: 160 | sent = _normalize_asian(sent) 161 | 162 | if self._no_punct: 163 | sent = _remove_punct(sent) 164 | if self._asian_support: 165 | sent = _remove_asian_punct(sent) 166 | 167 | # Strip extra whitespaces 168 | return ' '.join(sent.split()) 169 | 170 | def signature(self): 171 | return 'tercom' 172 | -------------------------------------------------------------------------------- /sacrebleu/dataset/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | The base class for all types of datasets. 
3 | """ 4 | import os 5 | import re 6 | from abc import ABCMeta, abstractmethod 7 | from typing import Dict, List, Optional 8 | 9 | from ..utils import SACREBLEU_DIR, download_file, smart_open 10 | 11 | 12 | class Dataset(metaclass=ABCMeta): 13 | def __init__( 14 | self, 15 | name: str, 16 | data: Optional[List[str]] = None, 17 | description: Optional[str] = None, 18 | citation: Optional[str] = None, 19 | md5: Optional[List[str]] = None, 20 | langpairs=Dict[str, List[str]], 21 | **kwargs, 22 | ): 23 | """ 24 | Params come from the values in DATASETS. 25 | 26 | :param name: Name of the dataset. 27 | :param data: URL of the raw data of the dataset. 28 | :param description: Description of the dataset. 29 | :param citation: Citation for the dataset. 30 | :param md5: MD5 checksum of the dataset. 31 | :param langpairs: List of available language pairs. 32 | """ 33 | self.name = name 34 | self.data = data 35 | self.description = description 36 | self.citation = citation 37 | self.md5 = md5 38 | self.langpairs = langpairs 39 | self.kwargs = kwargs 40 | 41 | # Don't do any downloading or further processing now. 42 | # Only do that lazily, when asked. 43 | 44 | # where to store the dataset 45 | self._outdir = os.path.join(SACREBLEU_DIR, self.name) 46 | self._rawdir = os.path.join(self._outdir, "raw") 47 | 48 | def maybe_download(self): 49 | """ 50 | If the dataset isn't downloaded, use utils/download_file() 51 | This can be implemented here in the base class. It should write 52 | to ~/.sacreleu/DATASET/raw exactly as it does now. 53 | """ 54 | os.makedirs(self._rawdir, exist_ok=True) 55 | 56 | expected_checksums = self.md5 if self.md5 else [None] * len(self.data) 57 | 58 | for url, expected_md5 in zip(self.data, expected_checksums): 59 | tarball = os.path.join(self._rawdir, self._get_tarball_filename(url)) 60 | 61 | download_file( 62 | url, tarball, extract_to=self._rawdir, expected_md5=expected_md5 63 | ) 64 | 65 | @staticmethod 66 | def _clean(s): 67 | """ 68 | Removes trailing and leading spaces and collapses multiple consecutive internal spaces to a single one. 69 | 70 | :param s: The string. 71 | :return: A cleaned-up string. 72 | """ 73 | return re.sub(r"\s+", " ", s.strip()) 74 | 75 | def _get_tarball_filename(self, url): 76 | """ 77 | Produces a local filename for tarball. 78 | :param url: The url to download. 79 | :return: A name produced from the dataset identifier and the URL basename. 80 | """ 81 | return self.name.replace("/", "_") + "." + os.path.basename(url) 82 | 83 | def _get_txt_file_path(self, langpair, fieldname): 84 | """ 85 | Given the language pair and fieldname, return the path to the text file. 86 | The format is: ~/.sacrebleu/DATASET/DATASET.LANGPAIR.FIELDNAME 87 | 88 | :param langpair: The language pair. 89 | :param fieldname: The fieldname. 90 | :return: The path to the text file. 91 | """ 92 | # handle the special case of subsets. e.g. "wmt21/dev" > "wmt21_dev" 93 | name = self.name.replace("/", "_") 94 | # Colons are used to distinguish multiple references, but are not supported in Windows filenames 95 | fieldname = fieldname.replace(":", "-") 96 | return os.path.join(self._outdir, f"{name}.{langpair}.{fieldname}") 97 | 98 | def _get_langpair_metadata(self, langpair): 99 | """ 100 | Given a language pair, return the metadata for that language pair. 101 | Deal with errors if the language pair is not available. 102 | 103 | :param langpair: The language pair. e.g. "en-de" 104 | :return: Dict format which is same as self.langpairs. 
105 | """ 106 | if langpair is None: 107 | langpairs = self.langpairs 108 | elif langpair not in self.langpairs: 109 | raise Exception(f"No such language pair {self.name}/{langpair}") 110 | else: 111 | langpairs = {langpair: self.langpairs[langpair]} 112 | 113 | return langpairs 114 | 115 | @abstractmethod 116 | def process_to_text(self, langpair=None) -> None: 117 | """Processes raw files to plain text files. 118 | 119 | :param langpair: The language pair to process. e.g. "en-de". If None, all files will be processed. 120 | """ 121 | pass 122 | 123 | def fieldnames(self, langpair) -> List[str]: 124 | """ 125 | Return a list of all the field names. For most source, this is just 126 | the source and the reference. For others, it might include the document 127 | ID for each line, or the original language (origLang). 128 | 129 | get_files() should return the same number of items as this. 130 | 131 | :param langpair: The language pair (e.g., "de-en") 132 | :return: a list of field names 133 | """ 134 | return ["src", "ref"] 135 | 136 | def __iter__(self, langpair): 137 | """ 138 | Iterates over all fields (source, references, and other metadata) defined 139 | by the dataset. 140 | """ 141 | all_files = self.get_files(langpair) 142 | all_fins = [smart_open(f) for f in all_files] 143 | 144 | for item in zip(*all_fins): 145 | yield item 146 | 147 | def source(self, langpair): 148 | """ 149 | Return an iterable over the source lines. 150 | """ 151 | source_file = self.get_source_file(langpair) 152 | with smart_open(source_file) as fin: 153 | for line in fin: 154 | yield line.strip() 155 | 156 | def references(self, langpair): 157 | """ 158 | Return an iterable over the references. 159 | """ 160 | ref_files = self.get_reference_files(langpair) 161 | ref_fins = [smart_open(f) for f in ref_files] 162 | 163 | for item in zip(*ref_fins): 164 | yield item 165 | 166 | def get_source_file(self, langpair): 167 | all_files = self.get_files(langpair) 168 | all_fields = self.fieldnames(langpair) 169 | index = all_fields.index("src") 170 | return all_files[index] 171 | 172 | def get_reference_files(self, langpair): 173 | all_files = self.get_files(langpair) 174 | all_fields = self.fieldnames(langpair) 175 | ref_files = [ 176 | f for f, field in zip(all_files, all_fields) if field.startswith("ref") 177 | ] 178 | return ref_files 179 | 180 | def get_files(self, langpair): 181 | """ 182 | Returns the path of the source file and all reference files for 183 | the provided test set / language pair. 184 | Downloads the references first if they are not already local. 185 | 186 | :param langpair: The language pair (e.g., "de-en") 187 | :return: a list of the source file and all reference files 188 | """ 189 | fields = self.fieldnames(langpair) 190 | files = [self._get_txt_file_path(langpair, field) for field in fields] 191 | 192 | for file in files: 193 | if not os.path.exists(file): 194 | self.process_to_text(langpair) 195 | return files 196 | -------------------------------------------------------------------------------- /sacrebleu/metrics/ter.py: -------------------------------------------------------------------------------- 1 | """The implementation of the TER metric (Snover et al., 2006).""" 2 | 3 | # Copyright 2020 Memsource 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | 18 | from typing import List, Dict, Sequence, Optional, Any 19 | 20 | from ..tokenizers.tokenizer_ter import TercomTokenizer 21 | from ..utils import sum_of_lists 22 | from .base import Score, Signature, Metric 23 | from .lib_ter import translation_edit_rate 24 | 25 | 26 | class TERSignature(Signature): 27 | """A convenience class to represent the reproducibility signature for TER. 28 | 29 | :param args: key-value dictionary passed from the actual metric instance. 30 | """ 31 | def __init__(self, args: dict): 32 | """`TERSignature` initializer.""" 33 | super().__init__(args) 34 | self._abbr.update({ 35 | 'case': 'c', 36 | 'tok': 't', 37 | 'norm': 'nr', 38 | 'punct': 'pn', 39 | 'asian': 'as', 40 | }) 41 | 42 | self.info.update({ 43 | 'case': 'mixed' if args['case_sensitive'] else 'lc', 44 | 'tok': args['tokenizer_signature'], 45 | 'norm': args['normalized'], 46 | 'punct': not args['no_punct'], 47 | 'asian': args['asian_support'], 48 | }) 49 | 50 | 51 | class TERScore(Score): 52 | """A convenience class to represent TER scores. 53 | 54 | :param score: The TER score. 55 | :param num_edits: The cumulative number of edits. 56 | :param ref_length: The cumulative average reference length. 57 | """ 58 | def __init__(self, score: float, num_edits: float, ref_length: float): 59 | """`TERScore` initializer.""" 60 | super().__init__('TER', score) 61 | self.num_edits = int(num_edits) 62 | self.ref_length = ref_length 63 | 64 | 65 | class TER(Metric): 66 | """Translation edit rate (TER). A near-exact reimplementation of the Tercom 67 | algorithm, produces identical results on all "sane" outputs. 68 | 69 | Tercom original implementation: https://github.com/jhclark/tercom 70 | 71 | The beam edit distance algorithm uses a slightly different approach (we stay 72 | around the diagonal which is faster, at least in Python) so in some 73 | (extreme) corner cases, the output could differ. 74 | 75 | Caching in the edit distance is based partly on the PyTer package by Hiroyuki 76 | Tanaka (MIT license). (https://github.com/aflc/pyter) 77 | 78 | :param normalized: Enable character normalization. By default, normalizes a couple of things such as 79 | newlines being stripped, retrieving XML encoded characters, and fixing tokenization for punctuation. When 80 | 'asian_support' is enabled, also normalizes specific Asian (CJK) character sequences, i.e. 81 | split them down to the character level. 82 | :param no_punct: Remove punctuation. Can be used in conjunction with 'asian_support' to also remove typical 83 | punctuation markers in Asian languages (CJK). 84 | :param asian_support: Enable special treatment of Asian characters. This option only has an effect when 85 | 'normalized' and/or 'no_punct' is enabled. If 'normalized' is also enabled, then Asian (CJK) 86 | characters are split down to the character level. If 'no_punct' is enabled alongside 'asian_support', 87 | specific unicode ranges for CJK and full-width punctuations are also removed. 88 | :param case_sensitive: If `True`, does not lowercase sentences. 
89 | :param references: A sequence of reference documents with document being 90 | defined as a sequence of reference strings. If given, the reference info 91 | will be pre-computed and cached for faster re-computation across many systems. 92 | """ 93 | 94 | _SIGNATURE_TYPE = TERSignature 95 | 96 | def __init__(self, normalized: bool = False, 97 | no_punct: bool = False, 98 | asian_support: bool = False, 99 | case_sensitive: bool = False, 100 | references: Optional[Sequence[Sequence[str]]] = None): 101 | """`TER` initializer.""" 102 | super().__init__() 103 | 104 | self.no_punct = no_punct 105 | self.normalized = normalized 106 | self.asian_support = asian_support 107 | self.case_sensitive = case_sensitive 108 | 109 | self.tokenizer = TercomTokenizer( 110 | normalized=self.normalized, 111 | no_punct=self.no_punct, 112 | asian_support=self.asian_support, 113 | case_sensitive=self.case_sensitive, 114 | ) 115 | self.tokenizer_signature = self.tokenizer.signature() 116 | 117 | if references is not None: 118 | self._ref_cache = self._cache_references(references) 119 | 120 | def _preprocess_segment(self, sent: str) -> str: 121 | """Given a sentence, apply tokenization if enabled. 122 | 123 | :param sent: The input sentence string. 124 | :return: The pre-processed output string. 125 | """ 126 | return self.tokenizer(sent.rstrip()) 127 | 128 | def _compute_score_from_stats(self, stats: List[float]) -> TERScore: 129 | """Computes the final score from already aggregated statistics. 130 | 131 | :param stats: A list or numpy array of segment-level statistics. 132 | :return: A `TERScore` object. 133 | """ 134 | total_edits, sum_ref_lengths = stats[0], stats[1] 135 | 136 | if sum_ref_lengths > 0: 137 | score = total_edits / sum_ref_lengths 138 | elif total_edits > 0: 139 | score = 1.0 # empty reference(s) and non-empty hypothesis 140 | else: 141 | score = 0.0 # both reference(s) and hypothesis are empty 142 | 143 | return TERScore(100 * score, total_edits, sum_ref_lengths) 144 | 145 | def _aggregate_and_compute(self, stats: List[List[float]]) -> TERScore: 146 | """Computes the final TER score given the pre-computed corpus statistics. 147 | 148 | :param stats: A list of segment-level statistics 149 | :return: A `TERScore` instance. 150 | """ 151 | return self._compute_score_from_stats(sum_of_lists(stats)) 152 | 153 | def _compute_segment_statistics( 154 | self, hypothesis: str, ref_kwargs: Dict) -> List[float]: 155 | """Given a (pre-processed) hypothesis sentence and already computed 156 | reference words, returns the segment statistics required to compute 157 | the full TER score. 158 | 159 | :param hypothesis: Hypothesis sentence. 160 | :param ref_kwargs: A dictionary with `ref_words` key which is a list 161 | where each sublist contains reference words. 162 | :return: A two-element list that contains the 'minimum number of edits' 163 | and 'the average reference length'. 
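        Rough worked example: with two references of lengths 9 and 11 and a
        best (minimum) edit count of 3 over those references, the return
        value is [3, 10.0]; corpus-level TER is later computed as
        100 * total_edits / total_average_reference_length.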
164 | """ 165 | 166 | ref_lengths = 0 167 | best_num_edits = int(1e16) 168 | 169 | words_hyp = hypothesis.split() 170 | 171 | # Iterate the references 172 | ref_words = ref_kwargs['ref_words'] 173 | for words_ref in ref_words: 174 | num_edits, ref_len = translation_edit_rate(words_hyp, words_ref) 175 | ref_lengths += ref_len 176 | if num_edits < best_num_edits: 177 | best_num_edits = num_edits 178 | 179 | avg_ref_len = ref_lengths / len(ref_words) 180 | return [best_num_edits, avg_ref_len] 181 | 182 | def _extract_reference_info(self, refs: Sequence[str]) -> Dict[str, Any]: 183 | """Given a list of reference segments, applies pre-processing & tokenization 184 | and returns list of tokens for each reference. 185 | 186 | :param refs: A sequence of strings. 187 | :return: A dictionary that will be passed to `_compute_segment_statistics()` 188 | through keyword arguments. 189 | """ 190 | ref_words = [] 191 | 192 | for ref in refs: 193 | ref_words.append(self._preprocess_segment(ref).split()) 194 | 195 | return {'ref_words': ref_words} 196 | -------------------------------------------------------------------------------- /sacrebleu/dataset/wmt_xml.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import lxml.etree as ET 4 | 5 | from ..utils import smart_open 6 | from .base import Dataset 7 | 8 | from collections import defaultdict 9 | 10 | 11 | def _get_field_by_translator(translator): 12 | if not translator: 13 | return "ref" 14 | else: 15 | return f"ref:{translator}" 16 | 17 | class WMTXMLDataset(Dataset): 18 | """ 19 | The 2021+ WMT dataset format. Everything is contained in a single file. 20 | Can be parsed with the lxml parser. 21 | """ 22 | @staticmethod 23 | def _unwrap_wmt21_or_later(raw_file): 24 | """ 25 | Unwraps the XML file from wmt21 or later. 26 | This script is adapted from https://github.com/wmt-conference/wmt-format-tools 27 | 28 | :param raw_file: The raw xml file to unwrap. 29 | :return: Dictionary which contains the following fields 30 | (each a list with values for each sentence): 31 | - `src`: The source sentences. 32 | - `docid`: ID indicating which document the sentences belong to. 33 | - `origlang`: The original language of the document. 34 | - `domain`: Domain of the document. 35 | - `ref:{translator}`: The references produced by each translator. 36 | - `ref`: An alias for the references from the first translator. 
37 | """ 38 | tree = ET.parse(raw_file) 39 | # Find and check the documents (src, ref, hyp) 40 | src_langs, ref_langs, translators = set(), set(), set() 41 | for src_doc in tree.getroot().findall(".//src"): 42 | src_langs.add(src_doc.get("lang")) 43 | 44 | for ref_doc in tree.getroot().findall(".//ref"): 45 | ref_langs.add(ref_doc.get("lang")) 46 | translator = ref_doc.get("translator") 47 | translators.add(translator) 48 | 49 | assert ( 50 | len(src_langs) == 1 51 | ), f"Multiple source languages found in the file: {raw_file}" 52 | assert ( 53 | len(ref_langs) == 1 54 | ), f"Found {len(ref_langs)} reference languages found in the file: {raw_file}" 55 | 56 | src = [] 57 | docids = [] 58 | orig_langs = [] 59 | domains = [] 60 | 61 | refs = { _get_field_by_translator(translator): [] for translator in translators } 62 | 63 | systems = defaultdict(list) 64 | 65 | src_sent_count, doc_count, seen_domain = 0, 0, False 66 | for doc in tree.getroot().findall(".//doc"): 67 | # Skip the testsuite 68 | if "testsuite" in doc.attrib: 69 | continue 70 | 71 | doc_count += 1 72 | src_sents = { 73 | int(seg.get("id")): seg.text for seg in doc.findall(".//src//seg") 74 | } 75 | 76 | def get_sents(doc): 77 | return { 78 | int(seg.get("id")): seg.text if seg.text else "" 79 | for seg in doc.findall(".//seg") 80 | } 81 | 82 | ref_docs = doc.findall(".//ref") 83 | 84 | trans_to_ref = { 85 | ref_doc.get("translator"): get_sents(ref_doc) for ref_doc in ref_docs 86 | } 87 | 88 | hyp_docs = doc.findall(".//hyp") 89 | hyps = { 90 | hyp_doc.get("system"): get_sents(hyp_doc) for hyp_doc in hyp_docs 91 | } 92 | 93 | for seg_id in sorted(src_sents.keys()): 94 | # no ref translation is available for this segment 95 | if not any([value.get(seg_id, "") for value in trans_to_ref.values()]): 96 | continue 97 | for translator in translators: 98 | refs[_get_field_by_translator(translator)].append( 99 | trans_to_ref.get(translator, {translator: {}}).get(seg_id, "") 100 | ) 101 | src.append(src_sents[seg_id]) 102 | for system_name in hyps.keys(): 103 | systems[system_name].append(hyps[system_name][seg_id]) 104 | docids.append(doc.attrib["id"]) 105 | orig_langs.append(doc.attrib["origlang"]) 106 | # The "domain" attribute is missing in WMT21 and WMT22 107 | domains.append(doc.get("domain")) 108 | seen_domain = doc.get("domain") is not None 109 | src_sent_count += 1 110 | 111 | fields = {"src": src, **refs, "docid": docids, "origlang": orig_langs, **systems} 112 | if seen_domain: 113 | fields["domain"] = domains 114 | return fields 115 | 116 | def _get_langpair_path(self, langpair): 117 | """ 118 | Returns the path for this language pair. 119 | This is useful because in WMT22, the language-pair data structure can be a dict, 120 | in order to allow for overriding which test set to use. 121 | """ 122 | langpair_data = self._get_langpair_metadata(langpair)[langpair] 123 | rel_path = langpair_data["path"] if isinstance(langpair_data, dict) else langpair_data[0] 124 | return os.path.join(self._rawdir, rel_path) 125 | 126 | def process_to_text(self, langpair=None): 127 | """Processes raw files to plain text files. 128 | 129 | :param langpair: The language pair to process. e.g. "en-de". If None, all files will be processed. 
130 | """ 131 | # ensure that the dataset is downloaded 132 | self.maybe_download() 133 | 134 | for langpair in sorted(self._get_langpair_metadata(langpair).keys()): 135 | # The data type can be a list of paths, or a dict, containing the "path" 136 | # and an override on which labeled reference to use (key "refs") 137 | rawfile = self._get_langpair_path(langpair) 138 | 139 | with smart_open(rawfile) as fin: 140 | fields = self._unwrap_wmt21_or_later(fin) 141 | 142 | for fieldname in fields: 143 | textfile = self._get_txt_file_path(langpair, fieldname) 144 | 145 | # skip if the file already exists 146 | if os.path.exists(textfile) and os.path.getsize(textfile) > 0: 147 | continue 148 | 149 | with smart_open(textfile, "w") as fout: 150 | for line in fields[fieldname]: 151 | print(self._clean(line), file=fout) 152 | 153 | def _get_langpair_allowed_refs(self, langpair): 154 | """ 155 | Returns the preferred references for this language pair. 156 | This can be set in the language pair block (as in WMT22), and backs off to the 157 | test-set-level default, or nothing. 158 | 159 | There is one exception. In the metadata, sometimes there is no translator field 160 | listed (e.g., wmt22:liv-en). In this case, the reference is set to "", and the 161 | field "ref" is returned. 162 | """ 163 | defaults = self.kwargs.get("refs", []) 164 | langpair_data = self._get_langpair_metadata(langpair)[langpair] 165 | if isinstance(langpair_data, dict): 166 | allowed_refs = langpair_data.get("refs", defaults) 167 | else: 168 | allowed_refs = defaults 169 | allowed_refs = [_get_field_by_translator(ref) for ref in allowed_refs] 170 | 171 | return allowed_refs 172 | 173 | def get_reference_files(self, langpair): 174 | """ 175 | Returns the requested reference files. 176 | This is defined as a default at the test-set level, and can be overridden per language. 177 | """ 178 | # Iterate through the (label, file path) pairs, looking for permitted labels 179 | allowed_refs = self._get_langpair_allowed_refs(langpair) 180 | all_files = self.get_files(langpair) 181 | all_fields = self.fieldnames(langpair) 182 | ref_files = [ 183 | f for f, field in zip(all_files, all_fields) if field in allowed_refs 184 | ] 185 | return ref_files 186 | 187 | def fieldnames(self, langpair): 188 | """ 189 | Return a list of all the field names. For most source, this is just 190 | the source and the reference. For others, it might include the document 191 | ID for each line, or the original language (origLang). 192 | 193 | get_files() should return the same number of items as this. 194 | 195 | :param langpair: The language pair (e.g., "de-en") 196 | :return: a list of field names 197 | """ 198 | self.maybe_download() 199 | rawfile = self._get_langpair_path(langpair) 200 | 201 | with smart_open(rawfile) as fin: 202 | fields = self._unwrap_wmt21_or_later(fin) 203 | 204 | return list(fields.keys()) 205 | -------------------------------------------------------------------------------- /DATASETS.md: -------------------------------------------------------------------------------- 1 | | Dataset | Description | 2 | | ------------------------------ | ------------------------------------------------------------------------------------------------------------------- | 3 | | mtedx/valid | mTEDx evaluation data, valid: [URL](http://openslr.org/100) | 4 | | mtedx/test | mTEDx evaluation data, test: [URL](http://openslr.org/100) | 5 | | wmt23 | Official evaluation and system data for WMT23. | 6 | | wmt22 | Official evaluation and system data for WMT22. 
| 7 | | wmt21/systems | WMT21 system output. | 8 | | wmt21/dev | Development data for WMT21,if multiple references are available, the first one is used. | 9 | | wmt21/D | Official evaluation data for WMT21 with reference D | 10 | | wmt21/C | Official evaluation data for WMT21 with reference C | 11 | | wmt21/B | Official evaluation data for WMT21 with reference B. | 12 | | wmt21/AC | Official evaluation data for WMT21 with references A and C | 13 | | wmt21/AB | Official evaluation data for WMT21 with references A and B. | 14 | | wmt21 | Official evaluation data for WMT21. | 15 | | wmt20/robust/set1 | WMT20 robustness task, set 1 | 16 | | wmt20/robust/set2 | WMT20 robustness task, set 2 | 17 | | wmt20/robust/set3 | WMT20 robustness task, set 3 | 18 | | wmt20/tworefs | WMT20 news test sets with two references | 19 | | wmt20 | Official evaluation data for WMT20 | 20 | | mtnt2019 | Test set for the WMT 19 robustness shared task | 21 | | mtnt1.1/test | Test data for the Machine Translation of Noisy Text task: [URL](http://www.cs.cmu.edu/~pmichel1/mtnt/) | 22 | | mtnt1.1/valid | Validation data for the Machine Translation of Noisy Text task: [URL](http://www.cs.cmu.edu/~pmichel1/mtnt/) | 23 | | mtnt1.1/train | Training data for the Machine Translation of Noisy Text task: [URL](http://www.cs.cmu.edu/~pmichel1/mtnt/) | 24 | | wmt20/dev | Development data for tasks new to 2020. | 25 | | wmt19 | Official evaluation data. | 26 | | wmt19/dev | Development data for tasks new to 2019. | 27 | | wmt19/google/ar | Additional high-quality reference for WMT19/en-de. | 28 | | wmt19/google/arp | Additional paraphrase of wmt19/google/ar. | 29 | | wmt19/google/wmtp | Additional paraphrase of the official WMT19 reference. | 30 | | wmt19/google/hqr | Best human selected-reference between wmt19 and wmt19/google/ar. | 31 | | wmt19/google/hqp | Best human-selected reference between wmt19/google/arp and wmt19/google/wmtp. | 32 | | wmt19/google/hqall | Best human-selected reference among original official reference and the Google reference and paraphrases. | 33 | | wmt18 | Official evaluation data. | 34 | | wmt18/test-ts | Official evaluation sources with extra test sets interleaved. | 35 | | wmt18/dev | Development data (Estonian<>English). | 36 | | wmt17 | Official evaluation data. | 37 | | wmt17/B | Additional reference for EN-FI and FI-EN. | 38 | | wmt17/tworefs | Systems with two references. | 39 | | wmt17/improved | Improved zh-en and en-zh translations. | 40 | | wmt17/dev | Development sets released for new languages in 2017. | 41 | | wmt17/ms | Additional Chinese-English references from Microsoft Research. | 42 | | wmt16 | Official evaluation data. | 43 | | wmt16/B | Additional reference for EN-FI. | 44 | | wmt16/tworefs | EN-FI with two references. | 45 | | wmt16/dev | Development sets released for new languages in 2016. | 46 | | wmt15 | Official evaluation data. | 47 | | wmt14 | Official evaluation data. | 48 | | wmt14/full | Evaluation data released after official evaluation for further research. | 49 | | wmt13 | Official evaluation data. | 50 | | wmt12 | Official evaluation data. | 51 | | wmt11 | Official evaluation data. | 52 | | wmt10 | Official evaluation data. | 53 | | wmt09 | Official evaluation data. | 54 | | wmt08 | Official evaluation data. | 55 | | wmt08/nc | Official evaluation data (news commentary). | 56 | | wmt08/europarl | Official evaluation data (Europarl). | 57 | | iwslt17 | Official evaluation data for IWSLT. | 58 | | iwslt17/tst2016 | Development data for IWSLT 2017. 
| 59 | | iwslt17/tst2015 | Development data for IWSLT 2017. | 60 | | iwslt17/tst2014 | Development data for IWSLT 2017. | 61 | | iwslt17/tst2013 | Development data for IWSLT 2017. | 62 | | iwslt17/tst2012 | Development data for IWSLT 2017. | 63 | | iwslt17/tst2011 | Development data for IWSLT 2017. | 64 | | iwslt17/tst2010 | Development data for IWSLT 2017. | 65 | | iwslt17/dev2010 | Development data for IWSLT 2017. | 66 | | multi30k/2016 | 2016 flickr test set of Multi30k dataset | 67 | | multi30k/2017 | 2017 flickr test set of Multi30k dataset | 68 | | multi30k/2018 | 2018 flickr test set of Multi30k dataset. See [URL](https://competitions.codalab.org/competitions/19917) for evaluation. | 69 | -------------------------------------------------------------------------------- /sacrebleu/compat.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence, Optional 2 | 3 | from .metrics import BLEU, CHRF, TER, BLEUScore, CHRFScore, TERScore 4 | 5 | 6 | ###################################################################### 7 | # Backward compatibility functions for old style API access (< 1.4.11) 8 | ###################################################################### 9 | def corpus_bleu(hypotheses: Sequence[str], 10 | references: Sequence[Sequence[str]], 11 | smooth_method='exp', 12 | smooth_value=None, 13 | force=False, 14 | lowercase=False, 15 | tokenize=BLEU.TOKENIZER_DEFAULT, 16 | use_effective_order=False) -> BLEUScore: 17 | """Computes BLEU for a corpus against a single (or multiple) reference(s). 18 | This is the main CLI entry point for computing BLEU between a system output 19 | and a reference sentence. 20 | 21 | :param hypotheses: A sequence of hypothesis strings. 22 | :param references: A sequence of reference documents with document being 23 | defined as a sequence of reference strings. 24 | :param smooth_method: The smoothing method to use ('floor', 'add-k', 'exp' or 'none') 25 | :param smooth_value: The smoothing value for `floor` and `add-k` methods. `None` falls back to default value. 26 | :param force: Ignore data that looks already tokenized 27 | :param lowercase: Lowercase the data 28 | :param tokenize: The tokenizer to use 29 | :param use_effective_order: Don't take into account n-gram orders without any match. 30 | :return: a `BLEUScore` object 31 | """ 32 | metric = BLEU( 33 | lowercase=lowercase, force=force, tokenize=tokenize, 34 | smooth_method=smooth_method, smooth_value=smooth_value, 35 | effective_order=use_effective_order) 36 | 37 | return metric.corpus_score(hypotheses, references) 38 | 39 | 40 | def raw_corpus_bleu(hypotheses: Sequence[str], 41 | references: Sequence[Sequence[str]], 42 | smooth_value: Optional[float] = BLEU.SMOOTH_DEFAULTS['floor']) -> BLEUScore: 43 | """Computes BLEU for a corpus against a single (or multiple) reference(s). 44 | This convenience function assumes a particular set of arguments i.e. 45 | it disables tokenization and applies a `floor` smoothing with value `0.1`. 46 | 47 | This convenience call does not apply any tokenization at all, 48 | neither to the system output nor the reference. It just computes 49 | BLEU on the "raw corpus" (hence the name). 50 | 51 | :param hypotheses: A sequence of hypothesis strings. 52 | :param references: A sequence of reference documents with document being 53 | defined as a sequence of reference strings. 54 | :param smooth_value: The smoothing value for `floor`. If not given, the default of 0.1 is used. 
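    Example (illustrative):
        raw_corpus_bleu(["the cat sat"], [["the cat sat on the mat"]]).score
        scores a one-sentence corpus against a single reference stream; each
        inner sequence must be aligned line-by-line with the hypotheses.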
55 | :return: Returns a `BLEUScore` object. 56 | 57 | """ 58 | return corpus_bleu( 59 | hypotheses, references, smooth_method='floor', 60 | smooth_value=smooth_value, force=True, tokenize='none', 61 | use_effective_order=True) 62 | 63 | 64 | def sentence_bleu(hypothesis: str, 65 | references: Sequence[str], 66 | smooth_method: str = 'exp', 67 | smooth_value: Optional[float] = None, 68 | lowercase: bool = False, 69 | tokenize=BLEU.TOKENIZER_DEFAULT, 70 | use_effective_order: bool = True) -> BLEUScore: 71 | """ 72 | Computes BLEU for a single sentence against a single (or multiple) reference(s). 73 | 74 | Disclaimer: Computing BLEU at the sentence level is not its intended use as 75 | BLEU is a corpus-level metric. 76 | 77 | :param hypothesis: A single hypothesis string. 78 | :param references: A sequence of reference strings. 79 | :param smooth_method: The smoothing method to use ('floor', 'add-k', 'exp' or 'none') 80 | :param smooth_value: The smoothing value for `floor` and `add-k` methods. `None` falls back to default value. 81 | :param lowercase: Lowercase the data 82 | :param tokenize: The tokenizer to use 83 | :param use_effective_order: Don't take into account n-gram orders without any match. 84 | :return: Returns a `BLEUScore` object. 85 | """ 86 | metric = BLEU( 87 | lowercase=lowercase, tokenize=tokenize, force=False, 88 | smooth_method=smooth_method, smooth_value=smooth_value, 89 | effective_order=use_effective_order) 90 | 91 | return metric.sentence_score(hypothesis, references) 92 | 93 | 94 | def corpus_chrf(hypotheses: Sequence[str], 95 | references: Sequence[Sequence[str]], 96 | char_order: int = CHRF.CHAR_ORDER, 97 | word_order: int = CHRF.WORD_ORDER, 98 | beta: int = CHRF.BETA, 99 | remove_whitespace: bool = True, 100 | eps_smoothing: bool = False) -> CHRFScore: 101 | """ 102 | Computes chrF for a corpus against a single (or multiple) reference(s). 103 | If `word_order` equals to 2, the metric is referred to as chrF++. 104 | 105 | :param hypotheses: A sequence of hypothesis strings. 106 | :param references: A sequence of reference documents with document being 107 | defined as a sequence of reference strings. 108 | :param char_order: Character n-gram order. 109 | :param word_order: Word n-gram order. If equals to 2, the metric is referred to as chrF++. 110 | :param beta: Determine the importance of recall w.r.t precision. 111 | :param eps_smoothing: If `True`, applies epsilon smoothing similar 112 | to reference chrF++.py, NLTK and Moses implementations. Otherwise, 113 | it takes into account effective match order similar to sacreBLEU < 2.0.0. 114 | :param remove_whitespace: If `True`, removes whitespaces prior to character n-gram extraction. 115 | :return: A `CHRFScore` object. 116 | """ 117 | metric = CHRF( 118 | char_order=char_order, 119 | word_order=word_order, 120 | beta=beta, 121 | whitespace=not remove_whitespace, 122 | eps_smoothing=eps_smoothing) 123 | return metric.corpus_score(hypotheses, references) 124 | 125 | 126 | def sentence_chrf(hypothesis: str, 127 | references: Sequence[str], 128 | char_order: int = CHRF.CHAR_ORDER, 129 | word_order: int = CHRF.WORD_ORDER, 130 | beta: int = CHRF.BETA, 131 | remove_whitespace: bool = True, 132 | eps_smoothing: bool = False) -> CHRFScore: 133 | """ 134 | Computes chrF for a single sentence against a single (or multiple) reference(s). 135 | If `word_order` equals to 2, the metric is referred to as chrF++. 136 | 137 | :param hypothesis: A single hypothesis string. 138 | :param references: A sequence of reference strings. 
139 | :param char_order: Character n-gram order. 140 | :param word_order: Word n-gram order. If equals to 2, the metric is referred to as chrF++. 141 | :param beta: Determine the importance of recall w.r.t precision. 142 | :param eps_smoothing: If `True`, applies epsilon smoothing similar 143 | to reference chrF++.py, NLTK and Moses implementations. Otherwise, 144 | it takes into account effective match order similar to sacreBLEU < 2.0.0. 145 | :param remove_whitespace: If `True`, removes whitespaces prior to character n-gram extraction. 146 | :return: A `CHRFScore` object. 147 | """ 148 | metric = CHRF( 149 | char_order=char_order, 150 | word_order=word_order, 151 | beta=beta, 152 | whitespace=not remove_whitespace, 153 | eps_smoothing=eps_smoothing) 154 | return metric.sentence_score(hypothesis, references) 155 | 156 | 157 | def corpus_ter(hypotheses: Sequence[str], 158 | references: Sequence[Sequence[str]], 159 | normalized: bool = False, 160 | no_punct: bool = False, 161 | asian_support: bool = False, 162 | case_sensitive: bool = False) -> TERScore: 163 | """ 164 | Computes TER for a corpus against a single (or multiple) reference(s). 165 | 166 | :param hypotheses: A sequence of hypothesis strings. 167 | :param references: A sequence of reference documents with document being 168 | defined as a sequence of reference strings. 169 | :param normalized: Enable character normalization. 170 | :param no_punct: Remove punctuation. 171 | :param asian_support: Enable special treatment of Asian characters. 172 | :param case_sensitive: Enables case-sensitivity. 173 | :return: A `TERScore` object. 174 | """ 175 | metric = TER( 176 | normalized=normalized, 177 | no_punct=no_punct, 178 | asian_support=asian_support, 179 | case_sensitive=case_sensitive) 180 | return metric.corpus_score(hypotheses, references) 181 | 182 | 183 | def sentence_ter(hypothesis: str, 184 | references: Sequence[str], 185 | normalized: bool = False, 186 | no_punct: bool = False, 187 | asian_support: bool = False, 188 | case_sensitive: bool = False) -> TERScore: 189 | """ 190 | Computes TER for a single hypothesis against a single (or multiple) reference(s). 191 | 192 | :param hypothesis: A single hypothesis string. 193 | :param references: A sequence of reference strings. 194 | :param normalized: Enable character normalization. 195 | :param no_punct: Remove punctuation. 196 | :param asian_support: Enable special treatment of Asian characters. 197 | :param case_sensitive: Enable case-sensitivity. 198 | :return: A `TERScore` object. 199 | """ 200 | metric = TER( 201 | normalized=normalized, 202 | no_punct=no_punct, 203 | asian_support=asian_support, 204 | case_sensitive=case_sensitive) 205 | return metric.sentence_score(hypothesis, references) 206 | -------------------------------------------------------------------------------- /sacrebleu/metrics/chrf.py: -------------------------------------------------------------------------------- 1 | """The implementation of chrF (Popović 2015) and chrF++ (Popović 2017) metrics.""" 2 | 3 | from typing import List, Sequence, Optional, Dict 4 | from collections import Counter 5 | 6 | from ..utils import sum_of_lists 7 | from .base import Score, Signature, Metric 8 | from .helpers import extract_all_char_ngrams, extract_word_ngrams 9 | 10 | 11 | class CHRFSignature(Signature): 12 | """A convenience class to represent the reproducibility signature for chrF. 13 | 14 | :param args: key-value dictionary passed from the actual metric instance. 
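    For illustration, the signature is usually obtained from the metric object
    after scoring; the field values below are only a plausible sketch and the
    version component depends on the installed release:

        >>> from sacrebleu.metrics import CHRF
        >>> chrf = CHRF()
        >>> _ = chrf.corpus_score(["a toy hypothesis"], [["a toy reference"]])
        >>> print(chrf.get_signature())  # doctest: +SKIP
        nrefs:1|case:mixed|eff:yes|nc:6|nw:0|space:no|version:2.0.0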
15 | """ 16 | def __init__(self, args: dict): 17 | """`CHRFSignature` initializer.""" 18 | super().__init__(args) 19 | self._abbr.update({ 20 | 'case': 'c', 21 | 'eff': 'e', 22 | 'nc': 'nc', 23 | 'nw': 'nw', 24 | 'space': 's', 25 | }) 26 | 27 | self.info.update({ 28 | 'case': 'lc' if args['lowercase'] else 'mixed', 29 | 'eff': 'yes' if not args['eps_smoothing'] else 'no', 30 | 'nc': args['char_order'], 31 | 'nw': args['word_order'], 32 | 'space': 'yes' if args['whitespace'] else 'no', 33 | }) 34 | 35 | 36 | class CHRFScore(Score): 37 | """A convenience class to represent chrF scores. 38 | 39 | :param score: The chrF (chrF++) score. 40 | :param char_order: The character n-gram order. 41 | :param word_order: The word n-gram order. If equals to 2, the metric is referred to as chrF++. 42 | :param beta: Determine the importance of recall w.r.t precision. 43 | """ 44 | def __init__(self, score: float, char_order: int, word_order: int, beta: int): 45 | """`CHRFScore` initializer.""" 46 | self.beta = beta 47 | self.char_order = char_order 48 | self.word_order = word_order 49 | 50 | # Add + signs to denote chrF+ variant 51 | name = f'chrF{self.beta}' + '+' * self.word_order 52 | 53 | super().__init__(name, score) 54 | 55 | 56 | class CHRF(Metric): 57 | """Computes the chrF(++) metric given hypotheses and references. 58 | 59 | :param char_order: Character n-gram order. 60 | :param word_order: Word n-gram order. If equals to 2, the metric is referred to as chrF++. 61 | :param beta: Determine the importance of recall w.r.t precision. 62 | :param lowercase: Enable case-insensitivity. 63 | :param whitespace: If `True`, include whitespaces when extracting character n-grams. 64 | :param eps_smoothing: If `True`, applies epsilon smoothing similar 65 | to reference chrF++.py, NLTK and Moses implementations. Otherwise, 66 | it takes into account effective match order similar to sacreBLEU < 2.0.0. 67 | :param references: A sequence of reference documents with document being 68 | defined as a sequence of reference strings. If given, the reference n-grams 69 | will be pre-computed and cached for faster re-computation across many systems. 70 | """ 71 | 72 | # Maximum character n-gram order to take into account 73 | CHAR_ORDER = 6 74 | 75 | # chrF+ additionally takes into account some of the word n-grams 76 | WORD_ORDER = 0 77 | 78 | # Defaults to 2 (per http://www.aclweb.org/anthology/W16-2341) 79 | BETA = 2 80 | 81 | # Cache string.punctuation for chrF+' punctuation stripper 82 | _PUNCTS = set('!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~') 83 | 84 | _SIGNATURE_TYPE = CHRFSignature 85 | 86 | def __init__(self, char_order: int = CHAR_ORDER, 87 | word_order: int = WORD_ORDER, 88 | beta: int = BETA, 89 | lowercase: bool = False, 90 | whitespace: bool = False, 91 | eps_smoothing: bool = False, 92 | references: Optional[Sequence[Sequence[str]]] = None): 93 | """`CHRF` initializer.""" 94 | super().__init__() 95 | 96 | self.beta = beta 97 | self.char_order = char_order 98 | self.word_order = word_order 99 | self.order = self.char_order + self.word_order 100 | self.lowercase = lowercase 101 | self.whitespace = whitespace 102 | self.eps_smoothing = eps_smoothing 103 | 104 | if references is not None: 105 | # Pre-compute reference ngrams 106 | self._ref_cache = self._cache_references(references) 107 | 108 | @staticmethod 109 | def _get_match_statistics(hyp_ngrams: Counter, ref_ngrams: Counter) -> List[int]: 110 | """Computes the match statistics between hypothesis and reference n-grams. 
111 | 112 | :param hyp_ngrams: A `Counter` holding hypothesis n-grams. 113 | :param ref_ngrams: A `Counter` holding reference n-grams. 114 | :return: A list of three numbers denoting hypothesis n-gram count, 115 | reference n-gram count and the intersection count. 116 | """ 117 | # Counter's internal intersection is not that fast, count manually 118 | match_count, hyp_count = 0, 0 119 | for ng, count in hyp_ngrams.items(): 120 | hyp_count += count 121 | if ng in ref_ngrams: 122 | match_count += min(count, ref_ngrams[ng]) 123 | 124 | return [ 125 | # Don't count hits if no reference exists for that n-gram 126 | hyp_count if ref_ngrams else 0, 127 | sum(ref_ngrams.values()), 128 | match_count, 129 | ] 130 | 131 | def _remove_punctuation(self, sent: str) -> List[str]: 132 | """Separates out punctuations from beginning and end of words for chrF. 133 | Adapted from https://github.com/m-popovic/chrF 134 | 135 | :param sent: A string. 136 | :return: A list of words. 137 | """ 138 | tokenized = [] 139 | for w in sent.split(): 140 | if len(w) == 1: 141 | tokenized.append(w) 142 | else: 143 | # NOTE: This splits '(hi)' to '(hi' and ')' (issue #124) 144 | if w[-1] in self._PUNCTS: 145 | tokenized += [w[:-1], w[-1]] 146 | elif w[0] in self._PUNCTS: 147 | tokenized += [w[0], w[1:]] 148 | else: 149 | tokenized.append(w) 150 | return tokenized 151 | 152 | def _preprocess_segment(self, sent: str) -> str: 153 | """Given a sentence, apply optional lowercasing. 154 | 155 | :param sent: The input sentence string. 156 | :return: The pre-processed output string. 157 | """ 158 | return sent.lower() if self.lowercase else sent 159 | 160 | def _compute_f_score(self, statistics: List[int]) -> float: 161 | """Compute the chrF score given the n-gram match statistics. 162 | 163 | :param statistics: A flattened list of 3 * (`char_order` + `word_order`) 164 | elements giving the [hyp, ref, match] counts for each order. 165 | :return: The final f_beta score between [0, 100]. 166 | """ 167 | eps = 1e-16 168 | score = 0.0 169 | effective_order = 0 170 | factor = self.beta ** 2 171 | avg_prec, avg_rec = 0.0, 0.0 172 | 173 | for i in range(self.order): 174 | n_hyp, n_ref, n_match = statistics[3 * i: 3 * i + 3] 175 | 176 | # chrF++.py style EPS smoothing (also used by Moses and NLTK) 177 | prec = n_match / n_hyp if n_hyp > 0 else eps 178 | rec = n_match / n_ref if n_ref > 0 else eps 179 | 180 | denom = factor * prec + rec 181 | score += ((1 + factor) * prec * rec / denom) if denom > 0 else eps 182 | 183 | # sacreBLEU <2.0.0 style effective order smoothing 184 | if n_hyp > 0 and n_ref > 0: 185 | avg_prec += prec 186 | avg_rec += rec 187 | effective_order += 1 188 | 189 | if self.eps_smoothing: 190 | return 100 * score / self.order 191 | 192 | if effective_order == 0: 193 | avg_prec = avg_rec = 0.0 194 | else: 195 | avg_prec /= effective_order 196 | avg_rec /= effective_order 197 | 198 | if avg_prec + avg_rec: 199 | score = (1 + factor) * avg_prec * avg_rec 200 | score /= ((factor * avg_prec) + avg_rec) 201 | return 100 * score 202 | else: 203 | return 0.0 204 | 205 | def _compute_score_from_stats(self, stats: List[int]) -> CHRFScore: 206 | """Computes the final score from already aggregated statistics. 207 | 208 | :param stats: A list or numpy array of segment-level statistics. 209 | :return: A `CHRFScore` object. 
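        As an end-to-end illustration of the score this method finalises,
        an identical hypothesis/reference pair is expected to reach 100:

            >>> from sacrebleu.metrics import CHRF
            >>> chrf = CHRF()
            >>> round(chrf.corpus_score(["the cat is on the mat"],
            ...                         [["the cat is on the mat"]]).score, 2)
            100.0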
210 | """ 211 | return CHRFScore( 212 | self._compute_f_score(stats), self.char_order, 213 | self.word_order, self.beta) 214 | 215 | def _aggregate_and_compute(self, stats: List[List[int]]) -> CHRFScore: 216 | """Computes the final score given the pre-computed corpus statistics. 217 | 218 | :param stats: A list of segment-level statistics 219 | :return: A `CHRFScore` object. 220 | """ 221 | return self._compute_score_from_stats(sum_of_lists(stats)) 222 | 223 | def _extract_reference_info(self, refs: Sequence[str]) -> Dict[str, List[List[Counter]]]: 224 | """Given a list of reference segments, extract the character and word n-grams. 225 | 226 | :param refs: A sequence of reference segments. 227 | :return: A list where each element contains n-grams per reference segment. 228 | """ 229 | ngrams = [] 230 | 231 | for ref in refs: 232 | # extract character n-grams 233 | stats = extract_all_char_ngrams(ref, self.char_order, self.whitespace) 234 | 235 | # Check chrF+ mode 236 | if self.word_order > 0: 237 | ref_words = self._remove_punctuation(ref) 238 | 239 | for n in range(self.word_order): 240 | stats.append(extract_word_ngrams(ref_words, n + 1)) 241 | 242 | ngrams.append(stats) 243 | 244 | return {'ref_ngrams': ngrams} 245 | 246 | def _compute_segment_statistics( 247 | self, hypothesis: str, ref_kwargs: Dict) -> List[int]: 248 | """Given a (pre-processed) hypothesis sentence and already computed 249 | reference n-grams, returns the best match statistics across the 250 | references. 251 | 252 | :param hypothesis: Hypothesis sentence. 253 | :param ref_kwargs: A dictionary with key `ref_ngrams` which is a list 254 | where each sublist contains n-gram counters for a particular reference sentence. 255 | :return: A list of integers where each triplet denotes [hyp, ref, match] 256 | statistics. 257 | """ 258 | best_stats = [] 259 | best_f_score = -1.0 260 | 261 | # extract character n-grams 262 | all_hyp_ngrams = extract_all_char_ngrams( 263 | hypothesis, self.char_order, self.whitespace) 264 | 265 | # Check chrF+ mode to see if we'll add word n-grams as well 266 | if self.word_order > 0: 267 | # Primitive tokenization: separate out punctuations 268 | hwords = self._remove_punctuation(hypothesis) 269 | _range = range(1, self.word_order + 1) 270 | all_hyp_ngrams.extend([extract_word_ngrams(hwords, n) for n in _range]) 271 | 272 | # Iterate over multiple references, pick the one with best F score 273 | for _ref_ngrams in ref_kwargs['ref_ngrams']: 274 | stats = [] 275 | # Traverse all orders 276 | for h, r in zip(all_hyp_ngrams, _ref_ngrams): 277 | stats.extend(self._get_match_statistics(h, r)) 278 | f_score = self._compute_f_score(stats) 279 | 280 | if f_score > best_f_score: 281 | best_f_score = f_score 282 | best_stats = stats 283 | 284 | return best_stats 285 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /sacrebleu/metrics/lib_ter.py: -------------------------------------------------------------------------------- 1 | """This module implements various utility functions for the TER metric.""" 2 | 3 | # Copyright 2020 Memsource 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | 18 | import math 19 | from typing import List, Tuple, Dict 20 | 21 | 22 | _COST_INS = 1 23 | _COST_DEL = 1 24 | _COST_SUB = 1 25 | 26 | # Tercom-inspired limits 27 | _MAX_SHIFT_SIZE = 10 28 | _MAX_SHIFT_DIST = 50 29 | _BEAM_WIDTH = 25 30 | 31 | # Our own limits 32 | _MAX_CACHE_SIZE = 10000 33 | _MAX_SHIFT_CANDIDATES = 1000 34 | _INT_INFINITY = int(1e16) 35 | 36 | _OP_INS = 'i' 37 | _OP_DEL = 'd' 38 | _OP_NOP = ' ' 39 | _OP_SUB = 's' 40 | _OP_UNDEF = 'x' 41 | 42 | _FLIP_OPS = str.maketrans(_OP_INS + _OP_DEL, _OP_DEL + _OP_INS) 43 | 44 | 45 | def translation_edit_rate(words_hyp: List[str], words_ref: List[str]) -> Tuple[int, int]: 46 | """Calculate the translation edit rate. 47 | 48 | :param words_hyp: Tokenized translation hypothesis. 49 | :param words_ref: Tokenized reference translation. 50 | :return: tuple (number of edits, length) 51 | """ 52 | n_words_ref = len(words_ref) 53 | n_words_hyp = len(words_hyp) 54 | if n_words_ref == 0: 55 | # FIXME: This trace here is not used? 56 | trace = _OP_DEL * n_words_hyp 57 | # special treatment of empty refs 58 | return n_words_hyp, 0 59 | 60 | cached_ed = BeamEditDistance(words_ref) 61 | shifts = 0 62 | 63 | input_words = words_hyp 64 | checked_candidates = 0 65 | while True: 66 | # do shifts until they stop reducing the edit distance 67 | delta, new_input_words, checked_candidates = _shift( 68 | input_words, words_ref, cached_ed, checked_candidates) 69 | 70 | if checked_candidates >= _MAX_SHIFT_CANDIDATES: 71 | break 72 | 73 | if delta <= 0: 74 | break 75 | shifts += 1 76 | input_words = new_input_words 77 | 78 | edit_distance, trace = cached_ed(input_words) 79 | total_edits = shifts + edit_distance 80 | 81 | return total_edits, n_words_ref 82 | 83 | 84 | def _shift(words_h: List[str], words_r: List[str], cached_ed, 85 | checked_candidates: int) -> Tuple[int, List[str], int]: 86 | """Attempt to shift words in hypothesis to match reference. 87 | 88 | Returns the shift that reduces the edit distance the most. 89 | 90 | Note that the filtering of possible shifts and shift selection are heavily 91 | based on somewhat arbitrary heuristics. 
The code here follows as closely 92 | as possible the logic in Tercom, not always justifying the particular design 93 | choices. 94 | 95 | :param words_h: Hypothesis. 96 | :param words_r: Reference. 97 | :param cached_ed: Cached edit distance. 98 | :param checked_candidates: Number of shift candidates that were already 99 | evaluated. 100 | :return: (score, shifted_words, checked_candidates). Best shift and updated 101 | number of evaluated shift candidates. 102 | """ 103 | pre_score, inv_trace = cached_ed(words_h) 104 | 105 | # to get alignment, we pretend we are rewriting reference into hypothesis, 106 | # so we need to flip the trace of edit operations 107 | trace = _flip_trace(inv_trace) 108 | align, ref_err, hyp_err = trace_to_alignment(trace) 109 | 110 | best = None 111 | 112 | for start_h, start_r, length in _find_shifted_pairs(words_h, words_r): 113 | # don't do the shift unless both the hypothesis was wrong and the 114 | # reference doesn't match hypothesis at the target position 115 | if sum(hyp_err[start_h: start_h + length]) == 0: 116 | continue 117 | 118 | if sum(ref_err[start_r: start_r + length]) == 0: 119 | continue 120 | 121 | # don't try to shift within the subsequence 122 | if start_h <= align[start_r] < start_h + length: 123 | continue 124 | 125 | prev_idx = -1 126 | for offset in range(-1, length): 127 | if start_r + offset == -1: 128 | idx = 0 # insert before the beginning 129 | elif start_r + offset in align: 130 | # Unlike Tercom which inserts *after* the index, we insert 131 | # *before* the index. 132 | idx = align[start_r + offset] + 1 133 | else: 134 | break # offset is out of bounds => aims past reference 135 | 136 | if idx == prev_idx: 137 | continue # skip idx if already tried 138 | 139 | prev_idx = idx 140 | 141 | shifted_words = _perform_shift(words_h, start_h, length, idx) 142 | assert(len(shifted_words) == len(words_h)) 143 | 144 | # Elements of the tuple are designed to replicate Tercom ranking 145 | # of shifts: 146 | candidate = ( 147 | pre_score - cached_ed(shifted_words)[0], # highest score first 148 | length, # then, longest match first 149 | -start_h, # then, earliest match first 150 | -idx, # then, earliest target position first 151 | shifted_words, 152 | ) 153 | 154 | checked_candidates += 1 155 | 156 | if not best or candidate > best: 157 | best = candidate 158 | 159 | if checked_candidates >= _MAX_SHIFT_CANDIDATES: 160 | break 161 | 162 | if not best: 163 | return 0, words_h, checked_candidates 164 | else: 165 | best_score, _, _, _, shifted_words = best 166 | return best_score, shifted_words, checked_candidates 167 | 168 | 169 | def _perform_shift(words: List[str], start: int, length: int, target: int) -> List[str]: 170 | """Perform a shift in `words` from `start` to `target`. 171 | 172 | :param words: Words to shift. 173 | :param start: Where from. 174 | :param length: How many words. 175 | :param target: Where to. 176 | :return: Shifted words. 
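    For example (illustrative), moving the one-word block at position 3 so
    that it lands at index 1:

        >>> from sacrebleu.metrics.lib_ter import _perform_shift
        >>> _perform_shift(['a', 'b', 'c', 'd', 'e'], 3, 1, 1)
        ['a', 'd', 'b', 'c', 'e']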
177 | """ 178 | if target < start: 179 | # shift before previous position 180 | return words[:target] + words[start: start + length] \ 181 | + words[target: start] + words[start + length:] 182 | elif target > start + length: 183 | # shift after previous position 184 | return words[:start] + words[start + length: target] \ 185 | + words[start: start + length] + words[target:] 186 | else: 187 | # shift within the shifted string 188 | return words[:start] + words[start + length: length + target] \ 189 | + words[start: start + length] + words[length + target:] 190 | 191 | 192 | def _find_shifted_pairs(words_h: List[str], words_r: List[str]): 193 | """Find matching word sub-sequences in two lists of words. 194 | 195 | Ignores sub-sequences starting at the same position. 196 | 197 | :param words_h: First word list. 198 | :param words_r: Second word list. 199 | :return: Yields tuples of (h_start, r_start, length) such that: 200 | words_h[h_start:h_start+length] = words_r[r_start:r_start+length] 201 | """ 202 | n_words_h = len(words_h) 203 | n_words_r = len(words_r) 204 | for start_h in range(n_words_h): 205 | for start_r in range(n_words_r): 206 | # this is slightly different from what tercom does but this should 207 | # really only kick in in degenerate cases 208 | if abs(start_r - start_h) > _MAX_SHIFT_DIST: 209 | continue 210 | 211 | length = 0 212 | while words_h[start_h + length] == words_r[start_r + length] and length < _MAX_SHIFT_SIZE: 213 | length += 1 214 | 215 | yield start_h, start_r, length 216 | 217 | # If one sequence is consumed, stop processing 218 | if n_words_h == start_h + length or n_words_r == start_r + length: 219 | break 220 | 221 | 222 | def _flip_trace(trace): 223 | """Flip the trace of edit operations. 224 | 225 | Instead of rewriting a->b, get a recipe for rewriting b->a. 226 | 227 | Simply flips insertions and deletions. 228 | """ 229 | return trace.translate(_FLIP_OPS) 230 | 231 | 232 | def trace_to_alignment(trace: str) -> Tuple[Dict, List, List]: 233 | """Transform trace of edit operations into an alignment of the sequences. 234 | 235 | :param trace: Trace of edit operations (' '=no change or 's'/'i'/'d'). 236 | :return: Alignment, error positions in reference, error positions in hypothesis. 237 | """ 238 | pos_hyp = -1 239 | pos_ref = -1 240 | hyp_err = [] 241 | ref_err = [] 242 | align = {} 243 | 244 | # we are rewriting a into b 245 | for op in trace: 246 | if op == _OP_NOP: 247 | pos_hyp += 1 248 | pos_ref += 1 249 | align[pos_ref] = pos_hyp 250 | hyp_err.append(0) 251 | ref_err.append(0) 252 | elif op == _OP_SUB: 253 | pos_hyp += 1 254 | pos_ref += 1 255 | align[pos_ref] = pos_hyp 256 | hyp_err.append(1) 257 | ref_err.append(1) 258 | elif op == _OP_INS: 259 | pos_hyp += 1 260 | hyp_err.append(1) 261 | elif op == _OP_DEL: 262 | pos_ref += 1 263 | align[pos_ref] = pos_hyp 264 | ref_err.append(1) 265 | else: 266 | raise Exception(f"unknown operation {op!r}") 267 | 268 | return align, ref_err, hyp_err 269 | 270 | 271 | class BeamEditDistance: 272 | """Edit distance with several features required for TER calculation. 273 | 274 | * internal cache 275 | * "beam" search 276 | * tracking of edit operations 277 | 278 | The internal self._cache works like this: 279 | 280 | Keys are words of the hypothesis. 
Values are tuples (next_node, row) where: 281 | 282 | * next_node is the cache for the next word in the sequence 283 | * row is the stored row of the edit distance matrix 284 | 285 | Effectively, caching allows to skip several rows in the edit distance 286 | matrix calculation and instead, to initialize the computation with the last 287 | matching matrix row. 288 | 289 | Beam search, as implemented here, only explores a fixed-size sub-row of 290 | candidates around the matrix diagonal (more precisely, it's a 291 | "pseudo"-diagonal since we take the ratio of sequence lengths into account). 292 | 293 | Tracking allows to reconstruct the optimal sequence of edit operations. 294 | 295 | :param words_ref: A list of reference tokens. 296 | """ 297 | def __init__(self, words_ref: List[str]): 298 | """`BeamEditDistance` initializer.""" 299 | self._words_ref = words_ref 300 | self._n_words_ref = len(self._words_ref) 301 | 302 | # first row corresponds to insertion operations of the reference, 303 | # so we do 1 edit operation per reference word 304 | self._initial_row = [(i * _COST_INS, _OP_INS) 305 | for i in range(self._n_words_ref + 1)] 306 | 307 | self._cache = {} # type: Dict[str, Tuple] 308 | self._cache_size = 0 309 | 310 | # Precomputed empty matrix row. Contains infinities so that beam search 311 | # avoids using the uninitialized cells. 312 | self._empty_row = [(_INT_INFINITY, _OP_UNDEF)] * (self._n_words_ref + 1) 313 | 314 | def __call__(self, words_hyp: List[str]) -> Tuple[int, str]: 315 | """Calculate edit distance between self._words_ref and the hypothesis. 316 | 317 | Uses cache to skip some of the computation. 318 | 319 | :param words_hyp: Words in translation hypothesis. 320 | :return: Edit distance score. 321 | """ 322 | 323 | # skip initial words in the hypothesis for which we already know the 324 | # edit distance 325 | start_position, dist = self._find_cache(words_hyp) 326 | 327 | # calculate the rest of the edit distance matrix 328 | edit_distance, newly_created_matrix, trace = self._edit_distance( 329 | words_hyp, start_position, dist) 330 | 331 | # update our cache with the newly calculated rows 332 | self._add_cache(words_hyp, newly_created_matrix) 333 | 334 | return edit_distance, trace 335 | 336 | def _edit_distance(self, words_h: List[str], start_h: int, 337 | cache: List[List[Tuple[int, str]]]) -> Tuple[int, List, str]: 338 | """Actual edit distance calculation. 339 | 340 | Can be initialized with the last cached row and a start position in 341 | the hypothesis that it corresponds to. 342 | 343 | :param words_h: Words in translation hypothesis. 344 | :param start_h: Position from which to start the calculation. 345 | (This is zero if no cache match was found.) 346 | :param cache: Precomputed rows corresponding to edit distance matrix 347 | before `start_h`. 348 | :return: Edit distance value, newly computed rows to update the 349 | cache, trace. 
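        As an illustration through the public `__call__` wrapper, a single
        substitution yields a distance of 1 and an 's' in the trace:

            >>> from sacrebleu.metrics.lib_ter import BeamEditDistance
            >>> ed = BeamEditDistance(["the", "cat", "sat"])
            >>> ed(["the", "dog", "sat"])
            (1, ' s ')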
350 | """ 351 | 352 | n_words_h = len(words_h) 353 | 354 | # initialize the rest of the matrix with infinite edit distances 355 | rest_empty = [list(self._empty_row) 356 | for _ in range(n_words_h - start_h)] 357 | 358 | dist = cache + rest_empty 359 | 360 | assert len(dist) == n_words_h + 1 361 | 362 | length_ratio = self._n_words_ref / n_words_h if words_h else 1 363 | 364 | # in some crazy sentences, the difference in length is so large that 365 | # we may end up with zero overlap with previous row 366 | if _BEAM_WIDTH < length_ratio / 2: 367 | beam_width = math.ceil(length_ratio / 2 + _BEAM_WIDTH) 368 | else: 369 | beam_width = _BEAM_WIDTH 370 | 371 | # calculate the Levenshtein distance 372 | for i in range(start_h + 1, n_words_h + 1): 373 | pseudo_diag = math.floor(i * length_ratio) 374 | min_j = max(0, pseudo_diag - beam_width) 375 | max_j = min(self._n_words_ref + 1, pseudo_diag + beam_width) 376 | 377 | if i == n_words_h: 378 | max_j = self._n_words_ref + 1 379 | 380 | for j in range(min_j, max_j): 381 | if j == 0: 382 | dist[i][j] = (dist[i - 1][j][0] + _COST_DEL, _OP_DEL) 383 | else: 384 | if words_h[i - 1] == self._words_ref[j - 1]: 385 | cost_sub = 0 386 | op_sub = _OP_NOP 387 | else: 388 | cost_sub = _COST_SUB 389 | op_sub = _OP_SUB 390 | 391 | # Tercom prefers no-op/sub, then insertion, then deletion. 392 | # But since we flip the trace and compute the alignment from 393 | # the inverse, we need to swap order of insertion and 394 | # deletion in the preference. 395 | ops = ( 396 | (dist[i - 1][j - 1][0] + cost_sub, op_sub), 397 | (dist[i - 1][j][0] + _COST_DEL, _OP_DEL), 398 | (dist[i][j - 1][0] + _COST_INS, _OP_INS), 399 | ) 400 | 401 | for op_cost, op_name in ops: 402 | if dist[i][j][0] > op_cost: 403 | dist[i][j] = op_cost, op_name 404 | 405 | # get the trace 406 | trace = "" 407 | i = n_words_h 408 | j = self._n_words_ref 409 | 410 | while i > 0 or j > 0: 411 | op = dist[i][j][1] 412 | trace = op + trace 413 | if op in (_OP_SUB, _OP_NOP): 414 | i -= 1 415 | j -= 1 416 | elif op == _OP_INS: 417 | j -= 1 418 | elif op == _OP_DEL: 419 | i -= 1 420 | else: 421 | raise Exception(f"unknown operation {op!r}") 422 | 423 | return dist[-1][-1][0], dist[len(cache):], trace 424 | 425 | def _add_cache(self, words_hyp: List[str], mat: List[List[Tuple]]): 426 | """Add newly computed rows to cache. 427 | 428 | Since edit distance is only calculated on the hypothesis suffix that 429 | was not in cache, the number of rows in `mat` may be shorter than 430 | hypothesis length. In that case, we skip over these initial words. 431 | 432 | :param words_hyp: Hypothesis words. 433 | :param mat: Edit distance matrix rows for each position. 434 | """ 435 | if self._cache_size >= _MAX_CACHE_SIZE: 436 | return 437 | 438 | node = self._cache 439 | 440 | n_mat = len(mat) 441 | 442 | # how many initial words to skip 443 | skip_num = len(words_hyp) - n_mat 444 | 445 | # jump through the cache to the current position 446 | for i in range(skip_num): 447 | node = node[words_hyp[i]][0] 448 | 449 | assert len(words_hyp[skip_num:]) == n_mat 450 | 451 | # update cache with newly computed rows 452 | for word, row in zip(words_hyp[skip_num:], mat): 453 | if word not in node: 454 | node[word] = ({}, tuple(row)) 455 | self._cache_size += 1 456 | value = node[word] 457 | node = value[0] 458 | 459 | def _find_cache(self, words_hyp: List[str]) -> Tuple[int, List[List]]: 460 | """Find the already computed rows of the edit distance matrix in cache. 461 | 462 | Returns a partially computed edit distance matrix. 
463 | 464 | :param words_hyp: Translation hypothesis. 465 | :return: Tuple (start position, dist). 466 | """ 467 | node = self._cache 468 | start_position = 0 469 | dist = [self._initial_row] 470 | for word in words_hyp: 471 | if word in node: 472 | start_position += 1 473 | node, row = node[word] 474 | dist.append(row) 475 | else: 476 | break 477 | 478 | return start_position, dist 479 | -------------------------------------------------------------------------------- /sacrebleu/metrics/base.py: -------------------------------------------------------------------------------- 1 | """The base `Score`, `Metric` and `Signature` classes to derive from. 2 | 3 | `Metric` is an abstract class that enforces the implementation of a set 4 | of abstract methods. This way, a correctly implemented metric will work 5 | seamlessly with the rest of the codebase. 6 | """ 7 | 8 | import json 9 | import logging 10 | import statistics 11 | from abc import ABCMeta, abstractmethod 12 | from typing import Any, Dict, List, Optional, Sequence 13 | 14 | from ..version import __version__ 15 | 16 | sacrelogger = logging.getLogger("sacrebleu") 17 | 18 | 19 | class Score: 20 | """A base score class to derive from. 21 | 22 | :param name: The name of the underlying metric. 23 | :param score: A floating point number for the final metric. 24 | """ 25 | 26 | def __init__(self, name: str, score: float): 27 | """`Score` initializer.""" 28 | self.name = name 29 | self.score = score 30 | 31 | # Statistical test related fields 32 | self._mean = -1.0 33 | self._ci = -1.0 34 | 35 | # More info can be added right after the score 36 | self._verbose = "" 37 | 38 | def format( 39 | self, 40 | width: int = 2, 41 | score_only: bool = False, 42 | signature: str = "", 43 | is_json: bool = False, 44 | ) -> str: 45 | """Returns a pretty representation of the score. 46 | :param width: Floating point decimal precision width. 47 | :param score_only: If `True`, and the format is not `json`, 48 | returns a single score string. 49 | :param signature: A string representation of the given `Signature` 50 | instance. 51 | :param is_json: If `True`, will output the score in JSON string. 52 | :return: A plain or JSON-formatted string representation. 53 | """ 54 | d = { 55 | "name": self.name, 56 | "score": float(f"{self.score:.{width}f}"), 57 | "signature": signature, 58 | } 59 | 60 | sc = f"{self.score:.{width}f}" 61 | 62 | if self._mean > 0: 63 | confidence_mean = f"{self._mean:.{width}f}" 64 | confidence_var = f"{self._ci:.{width}f}" 65 | confidence_str = f"μ = {confidence_mean} ± {confidence_var}" 66 | 67 | sc += f" ({confidence_str})" 68 | if is_json: 69 | d["confidence_mean"] = float(confidence_mean) 70 | d["confidence_var"] = float(confidence_var) 71 | d["confidence"] = confidence_str 72 | 73 | # Construct full score line 74 | full_score = f"{self.name}|{signature}" if signature else self.name 75 | full_score = f"{full_score} = {sc}" 76 | if self._verbose: 77 | full_score += f" {self._verbose}" 78 | d["verbose_score"] = self._verbose 79 | 80 | if score_only: 81 | return sc 82 | 83 | if is_json: 84 | for param in signature.split("|"): 85 | key, value = param.split(":") 86 | d[key] = value 87 | return json.dumps(d, indent=1, ensure_ascii=False) 88 | 89 | return full_score 90 | 91 | def estimate_ci(self, scores: List["Score"]): 92 | """Takes a list of scores and stores mean, stdev and 95% confidence 93 | interval around the mean. 94 | 95 | :param scores: A list of `Score` objects obtained from bootstrap 96 | resampling for example. 
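        A small numeric illustration with toy scores 0..99 (the bounds are the
        2.5th/97.5th percentile entries of the sorted scores):

            >>> from sacrebleu.metrics.base import Score
            >>> s = Score("BLEU", 25.0)
            >>> s.estimate_ci([Score("BLEU", float(x)) for x in range(100)])
            >>> s._mean, s._ci
            (49.5, 47.5)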
97 | """ 98 | # Sort the scores 99 | raw_scores = sorted([x.score for x in scores]) 100 | n = len(raw_scores) 101 | 102 | # Get CI bounds (95%, i.e. 1/40 from left) 103 | lower_idx = n // 40 104 | upper_idx = n - lower_idx - 1 105 | lower, upper = raw_scores[lower_idx], raw_scores[upper_idx] 106 | self._ci = 0.5 * (upper - lower) 107 | self._mean = statistics.mean(raw_scores) 108 | 109 | def __repr__(self): 110 | """Returns a human readable score string.""" 111 | return self.format() 112 | 113 | 114 | class Signature: 115 | """A convenience class to represent sacreBLEU reproducibility signatures. 116 | 117 | :param args: key-value dictionary passed from the actual metric instance. 118 | """ 119 | 120 | def __init__(self, args: dict): 121 | """`Signature` initializer.""" 122 | # Global items that are shared across all metrics 123 | self._abbr = { 124 | "version": "v", 125 | "nrefs": "#", 126 | "test": "t", 127 | "lang": "l", 128 | "subset": "S", 129 | "origlang": "o", 130 | "bs": "bs", # Bootstrap resampling trials 131 | "ar": "ar", # Approximate randomization trials 132 | "seed": "rs", # RNG's seed 133 | } 134 | 135 | if "num_refs" not in args: 136 | raise ValueError( 137 | "Number of references unknown, please evaluate the metric first." 138 | ) 139 | 140 | num_refs = args["num_refs"] 141 | if num_refs == -1: 142 | # Detect variable number of refs 143 | num_refs = "var" 144 | 145 | # Global items that are shared across all metrics 146 | # None's will be ignored 147 | self.info = { 148 | "version": __version__, 149 | "nrefs": num_refs, 150 | "bs": args.get("n_bootstrap", None), 151 | "ar": None, 152 | "seed": args.get("seed", None), 153 | "test": args.get("test_set", None), 154 | "lang": args.get("langpair", None), 155 | "origlang": args.get("origlang", None), 156 | "subset": args.get("subset", None), 157 | } 158 | 159 | def format(self, short: bool = False) -> str: 160 | """Returns a string representation of the signature. 161 | 162 | :param short: If True, shortened signature is produced. 163 | :return: A string representation of the signature. 164 | """ 165 | pairs = [] 166 | keys = list(self.info.keys()) 167 | # keep version always at end 168 | keys.remove("version") 169 | for name in keys + ["version"]: 170 | value = self.info[name] 171 | if value is not None: 172 | if isinstance(value, bool): 173 | # Replace True/False with yes/no 174 | value = "yes" if value else "no" 175 | final_name = self._abbr[name] if short else name 176 | pairs.append(f"{final_name}:{value}") 177 | 178 | return "|".join(pairs) 179 | 180 | def update(self, key: str, value: Any): 181 | """Add a new item or update an existing one. 182 | 183 | :param key: The key to use in the dictionary. 184 | :param value: The associated value for the `key`. 185 | """ 186 | self.info[key] = value 187 | 188 | def __str__(self): 189 | """Returns a human-readable signature string.""" 190 | return self.format() 191 | 192 | def __repr__(self): 193 | """Returns a human-readable signature string.""" 194 | return self.format() 195 | 196 | 197 | class Metric(metaclass=ABCMeta): 198 | """A base class for all metrics that ensures the implementation of some 199 | methods. 
Much of the common functionality is moved to this base class 200 | from other metrics.""" 201 | 202 | # Each metric should define its Signature class' name here 203 | _SIGNATURE_TYPE = Signature 204 | 205 | def __init__(self): 206 | """`Metric` initializer.""" 207 | # The pre-computed reference cache 208 | self._ref_cache = None 209 | 210 | # only useful for BLEU tokenized warnings. Set to True so that 211 | # warnings are not issued for other metrics. 212 | self._force = True 213 | 214 | # Will be used by the signature when bootstrap resampling 215 | self.n_bootstrap = None 216 | self.seed = None 217 | 218 | def _check_sentence_score_args(self, hyp: str, refs: Sequence[str]): 219 | """Performs sanity checks on `sentence_score` method's arguments. 220 | 221 | :param hyp: A single hypothesis string. 222 | :param refs: A sequence of reference strings. 223 | """ 224 | prefix = self.__class__.__name__ 225 | err_msg = None 226 | 227 | if not isinstance(hyp, str): 228 | err_msg = "The argument `hyp` should be a string." 229 | elif isinstance(refs, str) or not isinstance(refs, Sequence): 230 | err_msg = "The argument `refs` should be a sequence of strings." 231 | elif not isinstance(refs[0], str) and refs[0] is not None: 232 | err_msg = "Each element of `refs` should be a string." 233 | 234 | if err_msg: 235 | raise TypeError(f"{prefix}: {err_msg}") 236 | 237 | def _check_corpus_score_args( 238 | self, hyps: Sequence[str], refs: Optional[Sequence[Sequence[str]]] 239 | ): 240 | """Performs sanity checks on `corpus_score` method's arguments. 241 | 242 | :param hypses: A sequence of hypothesis strings. 243 | :param refs: A sequence of reference documents with document being 244 | defined as a sequence of reference strings. If `None`, cached references 245 | will be used. 246 | """ 247 | 248 | prefix = self.__class__.__name__ 249 | err_msg = None 250 | 251 | if not isinstance(hyps, Sequence): 252 | err_msg = "`hyps` should be a sequence of strings." 253 | elif not isinstance(hyps[0], str): 254 | err_msg = "Each element of `hyps` should be a string." 255 | elif any(line is None for line in hyps): 256 | err_msg = "Undefined line in hypotheses stream!" 257 | 258 | if refs is not None: 259 | if not isinstance(refs, Sequence): 260 | err_msg = "`refs` should be a sequence of sequence of strings." 261 | elif not isinstance(refs[0], Sequence): 262 | err_msg = "Each element of `refs` should be a sequence of strings." 263 | elif not isinstance(refs[0][0], str) and refs[0][0] is not None: 264 | err_msg = "`refs` should be a sequence of sequence of strings." 265 | 266 | if err_msg: 267 | raise TypeError(f"{prefix}: {err_msg}") 268 | 269 | @abstractmethod 270 | def _aggregate_and_compute(self, stats: List[List[Any]]) -> Any: 271 | """Computes the final score given the pre-computed match statistics. 272 | 273 | :param stats: A list of segment-level statistics. 274 | :return: A `Score` instance. 275 | """ 276 | pass 277 | 278 | @abstractmethod 279 | def _compute_score_from_stats(self, stats: List[Any]) -> Any: 280 | """Computes the final score from already aggregated statistics. 281 | 282 | :param stats: A list or numpy array of segment-level statistics. 283 | :return: A `Score` object. 284 | """ 285 | pass 286 | 287 | @abstractmethod 288 | def _preprocess_segment(self, sent: str) -> str: 289 | """A wrapper around the metric's tokenization and pre-processing logic. 290 | This should be implemented for reference caching to work correctly. 291 | 292 | :param sent: The input sentence. 
293 | :return: The pre-processed output sentence. 294 | """ 295 | pass 296 | 297 | @abstractmethod 298 | def _extract_reference_info(self, refs: Sequence[str]) -> Dict[str, Any]: 299 | """Given a list of reference segments, extract the required 300 | information (such as n-grams for BLEU and chrF). This should be implemented 301 | for the generic `_cache_references()` to work across all metrics. 302 | 303 | :param refs: A sequence of strings. 304 | """ 305 | pass 306 | 307 | @abstractmethod 308 | def _compute_segment_statistics( 309 | self, hypothesis: str, ref_kwargs: Dict 310 | ) -> List[Any]: 311 | """Given a (pre-processed) hypothesis sentence and already computed 312 | reference info, returns the best match statistics across the 313 | references. The return type is usually a List of ints or floats. 314 | 315 | :param hypothesis: A pre-processed hypothesis sentence. 316 | :param ref_kwargs: A dictionary with reference-related information 317 | within. This is formulated as a dictionary as different metrics may 318 | require different information regarding a reference segment. 319 | """ 320 | pass 321 | 322 | def _cache_references(self, references: Sequence[Sequence[str]]) -> List[Any]: 323 | """Given the full set of document references, extract segment n-grams 324 | (or other necessary information) for caching purposes. 325 | 326 | :param references: A sequence of reference documents with document being 327 | defined as a sequence of reference strings. A particular reference 328 | segment can be '' or `None` to allow the use of variable number 329 | of references per segment. 330 | :return: A list where each element is a tuple of segment n-grams and 331 | reference lengths, as returned by `_extract_reference_info()`. 332 | """ 333 | ref_cache = [] 334 | 335 | # Decide on final number of refs here as well 336 | num_refs = set() 337 | 338 | for refs in zip(*references): 339 | # Remove undefined references 340 | lines = [x for x in refs if x is not None] 341 | 342 | # Keep track of reference counts to allow variable reference 343 | # info in the signature 344 | num_refs.add(len(lines)) 345 | 346 | lines = [self._preprocess_segment(x) for x in lines] 347 | 348 | # Get n-grams 349 | ref_cache.append(self._extract_reference_info(lines)) 350 | 351 | if len(num_refs) == 1: 352 | self.num_refs = list(num_refs)[0] 353 | else: 354 | # A variable number of refs exist 355 | self.num_refs = -1 356 | 357 | return ref_cache 358 | 359 | def _extract_corpus_statistics( 360 | self, hypotheses: Sequence[str], references: Optional[Sequence[Sequence[str]]] 361 | ) -> Any: 362 | """Reads the corpus and returns sentence-level match statistics for 363 | faster re-computations esp. during statistical tests. 364 | 365 | :param hypotheses: A sequence of hypothesis strings. 366 | :param references: A sequence of reference documents with document being 367 | defined as a sequence of reference strings. If `None`, cached references 368 | will be used. 369 | :return: A list where each sublist corresponds to segment statistics. 
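        As an illustration of the caching pattern this enables (shown here
        with the concrete `BLEU` metric; passing `None` as references reuses
        the cache built in the constructor):

            >>> from sacrebleu.metrics import BLEU
            >>> refs = [["the cat is on the mat"], ["there is a cat on the mat"]]
            >>> bleu = BLEU(references=refs)  # reference n-grams cached once
            >>> round(bleu.corpus_score(["the cat is on the mat"], None).score, 2)
            100.0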
370 | """ 371 | # Pre-compute references 372 | # Don't store the cache as the user is explicitly passing refs 373 | if references: 374 | ref_cache = self._cache_references(references) 375 | elif self._ref_cache: 376 | ref_cache = self._ref_cache 377 | else: 378 | raise RuntimeError("No references provided and the cache is empty.") 379 | 380 | stats = [] 381 | tok_count = 0 382 | 383 | for hyp, ref_kwargs in zip(hypotheses, ref_cache): 384 | # Check for already-tokenized input problem (only for BLEU) 385 | if not self._force and hyp.endswith(" ."): 386 | tok_count += 1 387 | 388 | hyp = self._preprocess_segment(hyp) 389 | 390 | # Collect stats 391 | stats.append(self._compute_segment_statistics(hyp, ref_kwargs)) 392 | 393 | if tok_count >= 100: 394 | sacrelogger.warning("That's 100 lines that end in a tokenized period ('.')") 395 | sacrelogger.warning( 396 | "It looks like you forgot to detokenize your test data, which may hurt your score." 397 | ) 398 | sacrelogger.warning( 399 | "If you insist your data is detokenized, or don't care, you can suppress this message with the `force` parameter." 400 | ) 401 | 402 | return stats 403 | 404 | def sentence_score(self, hypothesis: str, references: Sequence[str]) -> Any: 405 | """Compute the metric for a single sentence against a single (or multiple) reference(s). 406 | 407 | :param hypothesis: A single hypothesis string. 408 | :param references: A sequence of reference strings. 409 | :return: A `Score` object. 410 | """ 411 | self._check_sentence_score_args(hypothesis, references) 412 | 413 | stats = self._extract_corpus_statistics( 414 | [hypothesis], [[refs] for refs in references] 415 | ) 416 | return self._aggregate_and_compute(stats) 417 | 418 | def corpus_score( 419 | self, 420 | hypotheses: Sequence[str], 421 | references: Optional[Sequence[Sequence[str]]], 422 | n_bootstrap: int = 1, 423 | ) -> Any: 424 | """Compute the metric for a corpus against a single (or multiple) reference(s). 425 | 426 | :param hypotheses: A sequence of hypothesis strings. 427 | :param references: A sequence of reference documents with document being 428 | defined as a sequence of reference strings. If `None`, cached references 429 | will be used. 430 | :param n_bootstrap: If > 1, provides 95% confidence interval around true mean 431 | using bootstrap resampling with `n_bootstrap` samples. 432 | :return: A `Score` object. 433 | """ 434 | self._check_corpus_score_args(hypotheses, references) 435 | 436 | # Collect corpus stats 437 | stats = self._extract_corpus_statistics(hypotheses, references) 438 | 439 | # Compute the actual system score 440 | actual_score = self._aggregate_and_compute(stats) 441 | 442 | if n_bootstrap > 1: 443 | # Compute bootstrap estimate as well 444 | # Delayed import is to escape from numpy import if bootstrap 445 | # is not requested. 446 | from ..significance import _bootstrap_resample 447 | 448 | self.n_bootstrap = n_bootstrap 449 | self.seed, bs_scores = _bootstrap_resample(stats, self, n_bootstrap) 450 | actual_score.estimate_ci(bs_scores) 451 | 452 | return actual_score 453 | 454 | def get_signature(self) -> Signature: 455 | """Creates and returns the signature for the metric. 
The creation 456 | of signatures is delayed as the number of references is resolved 457 | only at the point of reference caching.""" 458 | return self._SIGNATURE_TYPE(self.__dict__) 459 | -------------------------------------------------------------------------------- /sacrebleu/metrics/bleu.py: -------------------------------------------------------------------------------- 1 | """The implementation of the BLEU metric (Papineni et al., 2002).""" 2 | 3 | import math 4 | import logging 5 | from importlib import import_module 6 | from typing import List, Sequence, Optional, Dict, Any 7 | 8 | from ..utils import my_log, sum_of_lists 9 | 10 | from .base import Score, Signature, Metric 11 | from .helpers import extract_all_word_ngrams 12 | 13 | sacrelogger = logging.getLogger('sacrebleu') 14 | 15 | # The default for the maximum n-gram order when computing precisions 16 | MAX_NGRAM_ORDER = 4 17 | 18 | _TOKENIZERS = { 19 | 'none': 'tokenizer_none.NoneTokenizer', 20 | 'zh': 'tokenizer_zh.TokenizerZh', 21 | '13a': 'tokenizer_13a.Tokenizer13a', 22 | 'intl': 'tokenizer_intl.TokenizerV14International', 23 | 'char': 'tokenizer_char.TokenizerChar', 24 | 'ja-mecab': 'tokenizer_ja_mecab.TokenizerJaMecab', 25 | 'ko-mecab': 'tokenizer_ko_mecab.TokenizerKoMecab', 26 | 'spm': 'tokenizer_spm.TokenizerSPM', 27 | 'flores101': 'tokenizer_spm.Flores101Tokenizer', 28 | 'flores200': 'tokenizer_spm.Flores200Tokenizer', 29 | ### Added for spBLEU-1K tokenizer by AbdelRahim Elmadany 30 | 'spBLEU-1K':'tokenizer_spm.spBLEU1KTokenizer', 31 | } 32 | 33 | 34 | def _get_tokenizer(name: str): 35 | """Dynamically import tokenizer as importing all is slow.""" 36 | module_name, class_name = _TOKENIZERS[name].rsplit('.', 1) 37 | return getattr( 38 | import_module(f'.tokenizers.{module_name}', 'sacrebleu'), 39 | class_name) 40 | 41 | 42 | class BLEUSignature(Signature): 43 | """A convenience class to represent the reproducibility signature for BLEU. 44 | 45 | :param args: key-value dictionary passed from the actual metric instance. 46 | """ 47 | def __init__(self, args: dict): 48 | """`BLEUSignature` initializer.""" 49 | super().__init__(args) 50 | 51 | self._abbr.update({ 52 | 'case': 'c', 53 | 'eff': 'e', 54 | 'tok': 'tok', 55 | 'smooth': 's', 56 | }) 57 | 58 | # Construct a combined string for smoothing method and value 59 | smooth_str = args['smooth_method'] 60 | smooth_def = BLEU.SMOOTH_DEFAULTS[smooth_str] 61 | 62 | # If the method requires a parameter, add it within brackets 63 | if smooth_def is not None: 64 | # the following can be None if the user wants to use the default 65 | smooth_val = args['smooth_value'] 66 | 67 | if smooth_val is None: 68 | smooth_val = smooth_def 69 | 70 | smooth_str += f'[{smooth_val:.2f}]' 71 | 72 | self.info.update({ 73 | 'case': 'lc' if args['lowercase'] else 'mixed', 74 | 'eff': 'yes' if args['effective_order'] else 'no', 75 | 'tok': args['tokenizer_signature'], 76 | 'smooth': smooth_str, 77 | }) 78 | 79 | 80 | class BLEUScore(Score): 81 | """A convenience class to represent BLEU scores. 82 | 83 | :param score: The BLEU score. 84 | :param counts: List of counts of correct ngrams, 1 <= n <= max_ngram_order 85 | :param totals: List of counts of total ngrams, 1 <= n <= max_ngram_order 86 | :param precisions: List of precisions, 1 <= n <= max_ngram_order 87 | :param bp: The brevity penalty. 88 | :param sys_len: The cumulative system length. 89 | :param ref_len: The cumulative reference length. 
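    The derived fields follow directly from the constructor arguments; the
    numbers below are purely illustrative:

        >>> from sacrebleu.metrics import BLEUScore
        >>> s = BLEUScore(27.65, [10, 8, 6, 4], [20, 19, 18, 17],
        ...               [50.0, 42.1, 33.3, 23.5], 0.969, 20, 21)
        >>> s.prec_str
        '50.0/42.1/33.3/23.5'
        >>> round(s.ratio, 3)
        0.952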
90 | """ 91 | def __init__(self, score: float, counts: List[int], totals: List[int], 92 | precisions: List[float], bp: float, 93 | sys_len: int, ref_len: int): 94 | """`BLEUScore` initializer.""" 95 | super().__init__('BLEU', score) 96 | self.bp = bp 97 | self.counts = counts 98 | self.totals = totals 99 | self.sys_len = sys_len 100 | self.ref_len = ref_len 101 | self.precisions = precisions 102 | 103 | self.prec_str = "/".join([f"{p:.1f}" for p in self.precisions]) 104 | self.ratio = self.sys_len / self.ref_len if self.ref_len else 0 105 | 106 | # The verbose part of BLEU 107 | self._verbose = f"{self.prec_str} (BP = {self.bp:.3f} " 108 | self._verbose += f"ratio = {self.ratio:.3f} hyp_len = {self.sys_len:d} " 109 | self._verbose += f"ref_len = {self.ref_len:d})" 110 | 111 | 112 | class BLEU(Metric): 113 | """Computes the BLEU metric given hypotheses and references. 114 | 115 | :param lowercase: If True, lowercased BLEU is computed. 116 | :param force: Ignore data that looks already tokenized. 117 | :param tokenize: The tokenizer to use. If None, defaults to language-specific tokenizers with '13a' as the fallback default. 118 | :param smooth_method: The smoothing method to use ('floor', 'add-k', 'exp' or 'none'). 119 | :param smooth_value: The smoothing value for `floor` and `add-k` methods. `None` falls back to default value. 120 | :param max_ngram_order: If given, it overrides the maximum n-gram order (default: 4) when computing precisions. 121 | :param effective_order: If `True`, stop including n-gram orders for which precision is 0. This should be 122 | `True`, if sentence-level BLEU will be computed. 123 | :param trg_lang: An optional language code to raise potential tokenizer warnings. 124 | :param references: A sequence of reference documents with document being 125 | defined as a sequence of reference strings. If given, the reference n-grams 126 | and lengths will be pre-computed and cached for faster BLEU computation 127 | across many systems. 
128 | """ 129 | 130 | SMOOTH_DEFAULTS: Dict[str, Optional[float]] = { 131 | # The defaults for `floor` and `add-k` are obtained from the following paper 132 | # A Systematic Comparison of Smoothing Techniques for Sentence-Level BLEU 133 | # Boxing Chen and Colin Cherry 134 | # http://aclweb.org/anthology/W14-3346 135 | 'none': None, # No value is required 136 | 'floor': 0.1, 137 | 'add-k': 1, 138 | 'exp': None, # No value is required 139 | } 140 | 141 | TOKENIZERS = _TOKENIZERS.keys() 142 | 143 | # mteval-v13a.pl tokenizer unless Chinese or Japanese is provided 144 | TOKENIZER_DEFAULT = '13a' 145 | 146 | # Some language specific mappings to use if `trg_lang` is given 147 | # and the tokenizer is not explicitly specified 148 | _TOKENIZER_MAP = { 149 | 'zh': 'zh', 150 | 'ja': 'ja-mecab', 151 | 'ko': 'ko-mecab', 152 | } 153 | 154 | _SIGNATURE_TYPE = BLEUSignature 155 | 156 | def __init__(self, lowercase: bool = False, 157 | force: bool = False, 158 | tokenize: Optional[str] = None, 159 | smooth_method: str = 'exp', 160 | smooth_value: Optional[float] = None, 161 | max_ngram_order: int = MAX_NGRAM_ORDER, 162 | effective_order: bool = False, 163 | trg_lang: str = '', 164 | references: Optional[Sequence[Sequence[str]]] = None): 165 | """`BLEU` initializer.""" 166 | super().__init__() 167 | 168 | self._force = force 169 | self.trg_lang = trg_lang 170 | self.lowercase = lowercase 171 | self.smooth_value = smooth_value 172 | self.smooth_method = smooth_method 173 | self.max_ngram_order = max_ngram_order 174 | self.effective_order = effective_order 175 | 176 | # Sanity check 177 | assert self.smooth_method in self.SMOOTH_DEFAULTS.keys(), \ 178 | "Unknown smooth_method {self.smooth_method!r}" 179 | 180 | # If the tokenizer wasn't specified, choose it according to the 181 | # following logic. We use 'v13a' except for ZH and JA. Note that 182 | # this logic can only be applied when sacrebleu knows the target 183 | # language, which is only the case for builtin datasets. 184 | if tokenize is None: 185 | best_tokenizer = self.TOKENIZER_DEFAULT 186 | 187 | # Set `zh` or `ja-mecab` or `ko-mecab` if target language is provided 188 | if self.trg_lang in self._TOKENIZER_MAP: 189 | best_tokenizer = self._TOKENIZER_MAP[self.trg_lang] 190 | else: 191 | best_tokenizer = tokenize 192 | if self.trg_lang == 'zh' and best_tokenizer != 'zh': 193 | sacrelogger.warning( 194 | "Consider using the 'zh' or 'spm' tokenizer for Chinese.") 195 | if self.trg_lang == 'ja' and best_tokenizer != 'ja-mecab': 196 | sacrelogger.warning( 197 | "Consider using the 'ja-mecab' or 'spm' tokenizer for Japanese.") 198 | if self.trg_lang == 'ko' and best_tokenizer != 'ko-mecab': 199 | sacrelogger.warning( 200 | "Consider using the 'ko-mecab' or 'spm' tokenizer for Korean.") 201 | 202 | # Create the tokenizer 203 | self.tokenizer = _get_tokenizer(best_tokenizer)() 204 | 205 | # Build the signature 206 | self.tokenizer_signature = self.tokenizer.signature() 207 | 208 | if references is not None: 209 | # Pre-compute reference ngrams and lengths 210 | self._ref_cache = self._cache_references(references) 211 | 212 | @staticmethod 213 | def compute_bleu(correct: List[int], 214 | total: List[int], 215 | sys_len: int, 216 | ref_len: int, 217 | smooth_method: str = 'none', 218 | smooth_value=None, 219 | effective_order: bool = False, 220 | max_ngram_order: int = MAX_NGRAM_ORDER) -> BLEUScore: 221 | """Computes BLEU score from its sufficient statistics with smoothing. 
222 | 
223 |         Smoothing methods (citing "A Systematic Comparison of Smoothing Techniques for Sentence-Level BLEU",
224 |         Boxing Chen and Colin Cherry, WMT 2014: http://aclweb.org/anthology/W14-3346)
225 | 
226 |         - none: No smoothing.
227 |         - floor: Method 1 (requires small positive value (0.1 in the paper) to be set)
228 |         - add-k: Method 2 (Generalizing Lin and Och, 2004)
229 |         - exp: Method 3 (NIST smoothing method, i.e. the one used in mteval-v13a.pl)
230 | 
231 |         :param correct: List of counts of correct ngrams, 1 <= n <= max_ngram_order
232 |         :param total: List of counts of total ngrams, 1 <= n <= max_ngram_order
233 |         :param sys_len: The cumulative system length
234 |         :param ref_len: The cumulative reference length
235 |         :param smooth_method: The smoothing method to use ('floor', 'add-k', 'exp' or 'none')
236 |         :param smooth_value: The smoothing value for `floor` and `add-k` methods. `None` falls back to default value.
237 |         :param effective_order: If `True`, stop including n-gram orders for which precision is 0. This should be
238 |             `True` if sentence-level BLEU will be computed.
239 |         :param max_ngram_order: If given, it overrides the maximum n-gram order (default: 4) when computing precisions.
240 |         :return: A `BLEUScore` instance.
241 |         """
242 |         assert smooth_method in BLEU.SMOOTH_DEFAULTS.keys(), \
243 |             f"Unknown smooth_method {smooth_method!r}"
244 | 
245 |         # Fetch the default value for floor and add-k
246 |         if smooth_value is None:
247 |             smooth_value = BLEU.SMOOTH_DEFAULTS[smooth_method]
248 | 
249 |         # Compute brevity penalty
250 |         bp = 1.0
251 |         if sys_len < ref_len:
252 |             bp = math.exp(1 - ref_len / sys_len) if sys_len > 0 else 0.0
253 | 
254 |         # n-gram precisions
255 |         precisions = [0.0] * max_ngram_order
256 | 
257 |         # Early stop if there are no matches (#141)
258 |         if not any(correct):
259 |             return BLEUScore(0.0, correct, total, precisions, bp, sys_len, ref_len)
260 | 
261 |         smooth_mteval = 1.
262 |         eff_order = max_ngram_order
263 |         for n in range(1, len(precisions) + 1):
264 |             if smooth_method == 'add-k' and n > 1:
265 |                 correct[n - 1] += smooth_value
266 |                 total[n - 1] += smooth_value
267 | 
268 |             if total[n - 1] == 0:
269 |                 break
270 | 
271 |             # If the system guesses no i-grams, 1 <= i <= max_ngram_order,
272 |             # the BLEU score is 0 (technically undefined). This is a problem for sentence
273 |             # level BLEU or a corpus of short sentences, where systems will get
274 |             # no credit if sentence lengths fall under the max_ngram_order threshold.
275 |             # This fix scales max_ngram_order to the observed maximum order.
276 |             # It is only available through the API and off by default
277 |             if effective_order:
278 |                 eff_order = n
279 | 
280 |             if correct[n - 1] == 0:
281 |                 if smooth_method == 'exp':
282 |                     smooth_mteval *= 2
283 |                     precisions[n - 1] = 100. / (smooth_mteval * total[n - 1])
284 |                 elif smooth_method == 'floor':
285 |                     precisions[n - 1] = 100. * smooth_value / total[n - 1]
286 |             else:
287 |                 precisions[n - 1] = 100. * correct[n - 1] / total[n - 1]
288 | 
289 |         # Compute BLEU score
290 |         score = bp * math.exp(
291 |             sum([my_log(p) for p in precisions[:eff_order]]) / eff_order)
292 | 
293 |         return BLEUScore(score, correct, total, precisions, bp, sys_len, ref_len)
294 | 
295 |     def _preprocess_segment(self, sent: str) -> str:
296 |         """Given a sentence, lowercases (optionally) and tokenizes it.
297 |         :param sent: The input sentence string.
298 |         :return: The pre-processed output string.
299 | """ 300 | if self.lowercase: 301 | sent = sent.lower() 302 | return self.tokenizer(sent.rstrip()) 303 | 304 | def _compute_score_from_stats(self, stats: List[int]) -> BLEUScore: 305 | """Computes the final score from already aggregated statistics. 306 | 307 | :param stats: A list or numpy array of segment-level statistics. 308 | :return: A `BLEUScore` object. 309 | """ 310 | return self.compute_bleu( 311 | correct=stats[2: 2 + self.max_ngram_order], 312 | total=stats[2 + self.max_ngram_order:], 313 | sys_len=int(stats[0]), ref_len=int(stats[1]), 314 | smooth_method=self.smooth_method, smooth_value=self.smooth_value, 315 | effective_order=self.effective_order, 316 | max_ngram_order=self.max_ngram_order 317 | ) 318 | 319 | def _aggregate_and_compute(self, stats: List[List[int]]) -> BLEUScore: 320 | """Computes the final BLEU score given the pre-computed corpus statistics. 321 | 322 | :param stats: A list of segment-level statistics 323 | :return: A `BLEUScore` instance. 324 | """ 325 | return self._compute_score_from_stats(sum_of_lists(stats)) 326 | 327 | def _get_closest_ref_len(self, hyp_len: int, ref_lens: List[int]) -> int: 328 | """Given a hypothesis length and a list of reference lengths, returns 329 | the closest reference length to be used by BLEU. 330 | 331 | :param hyp_len: The hypothesis length. 332 | :param ref_lens: A list of reference lengths. 333 | :return: The closest reference length. 334 | """ 335 | closest_diff, closest_len = -1, -1 336 | 337 | for ref_len in ref_lens: 338 | diff = abs(hyp_len - ref_len) 339 | if closest_diff == -1 or diff < closest_diff: 340 | closest_diff = diff 341 | closest_len = ref_len 342 | elif diff == closest_diff and ref_len < closest_len: 343 | closest_len = ref_len 344 | 345 | return closest_len 346 | 347 | def _extract_reference_info(self, refs: Sequence[str]) -> Dict[str, Any]: 348 | """Given a list of reference segments, extract the n-grams and reference lengths. 349 | The latter will be useful when comparing hypothesis and reference lengths for BLEU. 350 | 351 | :param refs: A sequence of strings. 352 | :return: A dictionary that will be passed to `_compute_segment_statistics()` 353 | through keyword arguments. 354 | """ 355 | ngrams = None 356 | ref_lens = [] 357 | 358 | for ref in refs: 359 | # extract n-grams for this ref 360 | this_ngrams, ref_len = extract_all_word_ngrams(ref, 1, self.max_ngram_order) 361 | ref_lens.append(ref_len) 362 | 363 | if ngrams is None: 364 | # Set it directly for first set of refs 365 | ngrams = this_ngrams 366 | else: 367 | # Merge counts across multiple references 368 | # The below loop is faster than `ngrams |= this_ngrams` 369 | for ngram, count in this_ngrams.items(): 370 | ngrams[ngram] = max(ngrams[ngram], count) 371 | 372 | return {'ref_ngrams': ngrams, 'ref_lens': ref_lens} 373 | 374 | def _compute_segment_statistics(self, hypothesis: str, 375 | ref_kwargs: Dict) -> List[int]: 376 | """Given a (pre-processed) hypothesis sentence and already computed 377 | reference n-grams & lengths, returns the best match statistics across the 378 | references. 379 | 380 | :param hypothesis: Hypothesis sentence. 381 | :param ref_kwargs: A dictionary with `refs_ngrams`and `ref_lens` keys 382 | that denote the counter containing all n-gram counts and reference lengths, 383 | respectively. 384 | :return: A list of integers with match statistics. 
385 | """ 386 | 387 | ref_ngrams, ref_lens = ref_kwargs['ref_ngrams'], ref_kwargs['ref_lens'] 388 | 389 | # Extract n-grams for the hypothesis 390 | hyp_ngrams, hyp_len = extract_all_word_ngrams( 391 | hypothesis, 1, self.max_ngram_order) 392 | 393 | ref_len = self._get_closest_ref_len(hyp_len, ref_lens) 394 | 395 | # Count the stats 396 | # Although counter has its internal & and | operators, this is faster 397 | correct = [0 for i in range(self.max_ngram_order)] 398 | total = correct[:] 399 | for hyp_ngram, hyp_count in hyp_ngrams.items(): 400 | # n-gram order 401 | n = len(hyp_ngram) - 1 402 | # count hypothesis n-grams 403 | total[n] += hyp_count 404 | # count matched n-grams 405 | if hyp_ngram in ref_ngrams: 406 | correct[n] += min(hyp_count, ref_ngrams[hyp_ngram]) 407 | 408 | # Return a flattened list for efficient computation 409 | return [hyp_len, ref_len] + correct + total 410 | 411 | def sentence_score(self, hypothesis: str, references: Sequence[str]) -> BLEUScore: 412 | """Compute the metric for a single sentence against a single (or multiple) reference(s). 413 | 414 | :param hypothesis: A single hypothesis string. 415 | :param references: A sequence of reference strings. 416 | :return: a `BLEUScore` object. 417 | """ 418 | if not self.effective_order: 419 | sacrelogger.warning( 420 | 'It is recommended to enable `effective_order` for sentence-level BLEU.') 421 | return super().sentence_score(hypothesis, references) 422 | -------------------------------------------------------------------------------- /sacrebleu/significance.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import multiprocessing as mp 4 | from typing import Sequence, Dict, Optional, Tuple, List, Union, Any, Mapping 5 | 6 | import numpy as np 7 | 8 | from .metrics.base import Metric, Score, Signature 9 | 10 | IS_WINDOWS = os.name == 'nt' 11 | 12 | 13 | sacrelogger = logging.getLogger('sacrebleu') 14 | 15 | 16 | class Result: 17 | """A container to represent results from a particular statistical 18 | significance test. 19 | :param score: The floating point score for the system at hand. 20 | :param p_value: If exists, represents the p-value when the system at 21 | hand is compared to a baseline using a paired test. 22 | :param mean: When paired bootstrap test is applied, this represents 23 | the true mean score estimated from bootstrap resamples of the system. 24 | :param ci: When paired bootstrap test is applied, this represents 25 | the 95% confidence interval around the true mean score `sys_mean`. 26 | """ 27 | def __init__(self, score: float, p_value: Optional[float] = None, 28 | mean: Optional[float] = None, ci: Optional[float] = None): 29 | self.score = score 30 | self.p_value = p_value 31 | self.mean = mean 32 | self.ci = ci 33 | 34 | def __repr__(self): 35 | return ','.join([f'{k}={str(v)}' for k, v in self.__dict__.items()]) 36 | 37 | 38 | def estimate_ci(scores: np.ndarray) -> Tuple[float, float]: 39 | """Takes a list of scores and returns mean and 95% confidence 40 | interval around the mean. 41 | 42 | :param scores: A list of floating point scores. 43 | :return: A tuple of mean and the 95% CI. 44 | """ 45 | # Sort the scores 46 | scores = np.sort(scores) 47 | n = len(scores) 48 | 49 | # Get CI bounds (95%, i.e. 
1/40 from left) 50 | lower_idx = n // 40 51 | upper_idx = n - lower_idx - 1 52 | lower, upper = scores[lower_idx], scores[upper_idx] 53 | ci = 0.5 * (upper - lower) 54 | return (scores.mean(), ci) 55 | 56 | 57 | def _bootstrap_resample(stats: List[List[Union[int, float]]], 58 | metric: Metric, n_samples: int = 1000) -> Tuple[str, List[Score]]: 59 | """Performs bootstrap resampling for a single system to estimate 60 | a confidence interval around the true mean. 61 | :param stats: A list of statistics extracted from the system's hypotheses. 62 | :param metric: The `Metric` instance to be used for score computation. 63 | :n_samples: Number of bootstrap resamples to use. 64 | 65 | :return: A tuple of the seed choice as string and the list of `Score` 66 | instances for all bootstrap resamples. 67 | """ 68 | 69 | # Set numpy RNG's seed 70 | # If given -> Fix to the given value 71 | # If given but =='[Nn]one', don't fix the seed i.e. pull entropy from OS 72 | seed = os.environ.get('SACREBLEU_SEED', '12345') 73 | _seed = None if seed.lower() == 'none' else int(seed) 74 | rng = np.random.default_rng(_seed) 75 | 76 | # The indices that'll produce all bootstrap resamples at once 77 | idxs = rng.choice(len(stats), size=(n_samples, len(stats)), replace=True) 78 | 79 | # convert to numpy array. float32 is more efficient 80 | stats_np = np.array(stats, dtype='float32') 81 | 82 | # recompute scores for all resamples 83 | scores = [ 84 | metric._compute_score_from_stats(_s.sum(0)) for _s in stats_np[idxs]] 85 | 86 | return str(seed).lower(), scores 87 | 88 | 89 | def _compute_p_value(stats: np.ndarray, real_difference: float) -> float: 90 | """Computes the p-value given the sample statistics and the real statistic. 91 | :param stats: A numpy array with the sample statistics. 92 | :real_difference: The real statistic. 93 | :return: The p-value. 94 | """ 95 | # Taken from: significance/StratifiedApproximateRandomizationTest.java 96 | # https://github.com/jhclark/multeval.git 97 | 98 | # "the != is important. if we want to score the same system against itself 99 | # having a zero difference should not be attributed to chance." 100 | 101 | c = np.sum(stats > real_difference).item() 102 | 103 | # "+1 applies here, though it only matters for small numbers of shufflings, 104 | # which we typically never do. it's necessary to ensure the probability of 105 | # falsely rejecting the null hypothesis is no greater than the rejection 106 | # level of the test (see william and morgan on significance tests) 107 | p = (c + 1) / (len(stats) + 1) 108 | 109 | return p 110 | 111 | 112 | def _paired_ar_test(baseline_info: Dict[str, Tuple[np.ndarray, Result]], 113 | sys_name: str, 114 | hypotheses: Sequence[str], 115 | references: Optional[Sequence[Sequence[str]]], 116 | metrics: Dict[str, Metric], 117 | n_samples: int = 10000, 118 | n_ar_confidence: int = -1, 119 | seed: Optional[int] = None) -> Tuple[str, Dict[str, Result]]: 120 | """Paired two-sided approximate randomization (AR) test for MT evaluation. 121 | 122 | :param baseline_info: A dictionary with `Metric` instances as the keys, 123 | that contains sufficient statistics and a `Result` instance for the baseline system. 124 | :param sys_name: The name of the system to be evaluated. 125 | :param hypotheses: A sequence of string hypotheses for the system. 126 | :param references: A sequence of reference documents with document being 127 | defined as a sequence of reference strings. If `None`, references 128 | will be used through each metric's internal cache. 
129 | :param metrics: A dictionary of `Metric` instances that will be computed 130 | for each system. 131 | :param n_samples: The number of AR trials. 132 | :param n_ar_confidence: The number of bootstrap resamples to use for 133 | confidence estimation. A value of -1 disables confidence estimation. 134 | :param seed: The seed value for the RNG. If `None`, the RNG will not be 135 | fixed to a particular seed. 136 | 137 | :return: A tuple with first element being the system name and the second 138 | being a `Result` namedtuple. 139 | """ 140 | # Seed the RNG 141 | rng = np.random.default_rng(seed) 142 | 143 | # Generate indices that'll select stats 144 | pos_sel = rng.integers(2, size=(n_samples, len(hypotheses)), dtype=bool) 145 | 146 | # Flip mask to obtain selectors for system hypotheses 147 | neg_sel = ~pos_sel 148 | 149 | if n_ar_confidence > 0: 150 | # Perform confidence estimation as well 151 | bs_idxs = rng.choice( 152 | len(hypotheses), size=(n_ar_confidence, len(hypotheses)), replace=True) 153 | 154 | results = {} 155 | 156 | for name, metric in metrics.items(): 157 | # Use pre-computed match stats for the baseline 158 | bl_stats, bl_result = baseline_info[name] 159 | 160 | # Compute system's stats and score 161 | sacrelogger.info(f'Computing {name} for {sys_name!r} and extracting sufficient statistics') 162 | sys_stats = metric._extract_corpus_statistics(hypotheses, references) 163 | sys_score = metric._aggregate_and_compute(sys_stats) 164 | 165 | # original test statistic: absolute difference between baseline and the system 166 | diff = abs(bl_result.score - sys_score.score) 167 | 168 | sacrelogger.info(f' > Performing approximate randomization test (# trials: {n_samples})') 169 | # get shuffled pseudo systems 170 | shuf_a = pos_sel @ bl_stats + neg_sel @ sys_stats 171 | shuf_b = neg_sel @ bl_stats + pos_sel @ sys_stats 172 | 173 | # Aggregate trial stats and compute scores for each 174 | scores_a = np.array( 175 | [metric._aggregate_and_compute(x).score for x in shuf_a[:, None]]) 176 | scores_b = np.array( 177 | [metric._aggregate_and_compute(x).score for x in shuf_b[:, None]]) 178 | 179 | # Count the statistical difference and compute the p-value 180 | p = _compute_p_value( 181 | np.abs(np.array(scores_a) - np.array(scores_b)), diff) 182 | 183 | res = Result(sys_score.score, p) 184 | 185 | if n_ar_confidence > 0: 186 | sacrelogger.info(f' > Performing bootstrap resampling for confidence interval (# resamples: {n_ar_confidence})') 187 | sys_stats = np.array(sys_stats, dtype='float32') 188 | # recompute scores for all resamples 189 | sys_scores = np.array([ 190 | metric._compute_score_from_stats(_s.sum(0)).score for _s in sys_stats[bs_idxs] 191 | ]) 192 | res.mean, res.ci = estimate_ci(sys_scores) 193 | 194 | # Store the result 195 | results[name] = res 196 | 197 | return sys_name, results 198 | 199 | 200 | def _paired_bs_test(baseline_info: Dict[str, Tuple[np.ndarray, Result]], 201 | sys_name: str, 202 | hypotheses: Sequence[str], 203 | references: Optional[Sequence[Sequence[str]]], 204 | metrics: Dict[str, Metric], 205 | n_samples: int = 1000, 206 | n_ar_confidence: int = -1, 207 | seed: Optional[int] = None) -> Tuple[str, Dict[str, Result]]: 208 | """Paired bootstrap resampling test for MT evaluation. This function 209 | replicates the behavior of the Moses script called 210 | `bootstrap-hypothesis-difference-significance.pl`. 
211 | 
212 |     :param baseline_info: A dictionary with `Metric` instances as the keys,
213 |         that contains sufficient statistics and a `Result` instance for the baseline system.
214 |     :param sys_name: The name of the system to be evaluated.
215 |     :param hypotheses: A sequence of string hypotheses for the system.
216 |     :param references: A sequence of reference documents with document being
217 |         defined as a sequence of reference strings. If `None`, references
218 |         will be used through each metric's internal cache.
219 |     :param metrics: A dictionary of `Metric` instances that will be computed
220 |         for each system.
221 |     :param n_samples: The number of bootstrap resamples.
222 |     :param n_ar_confidence: This parameter is not used for this function but
223 |         is there for signature compatibility in the API.
224 |     :param seed: The seed value for the RNG. If `None`, the RNG will not be
225 |         fixed to a particular seed.
226 | 
227 |     :return: A tuple with first element being the system name and the second
228 |         being a `Result` namedtuple.
229 |     """
230 |     # Seed the RNG
231 |     rng = np.random.default_rng(seed)
232 | 
233 |     results = {}
234 | 
235 |     # It takes ~10ms to generate the indices
236 |     idxs = rng.choice(
237 |         len(hypotheses), size=(n_samples, len(hypotheses)), replace=True)
238 | 
239 |     for name, metric in metrics.items():
240 |         # Use pre-computed match stats for the baseline
241 |         bl_stats, bl_result = baseline_info[name]
242 | 
243 |         # Compute system's stats and score
244 |         sacrelogger.info(f'Computing {name} for {sys_name!r} and extracting sufficient statistics')
245 |         sys_stats = metric._extract_corpus_statistics(hypotheses, references)
246 |         sys_score = metric._aggregate_and_compute(sys_stats)
247 | 
248 |         # Convert to numpy arrays for efficient indexing
249 |         sys_stats = np.array(sys_stats, dtype='float32')
250 |         bl_stats = np.array(bl_stats, dtype='float32')
251 | 
252 |         # original test statistic: absolute difference between baseline and the system
253 |         diff = abs(bl_result.score - sys_score.score)
254 | 
255 |         sacrelogger.info(f' > Performing paired bootstrap resampling test (# resamples: {n_samples})')
256 |         scores_bl = np.array(
257 |             [metric._compute_score_from_stats(_s.sum(0)).score for _s in bl_stats[idxs]])
258 |         scores_sys = np.array(
259 |             [metric._compute_score_from_stats(_s.sum(0)).score for _s in sys_stats[idxs]])
260 | 
261 |         # Compute CI as well
262 |         sys_mean, sys_ci = estimate_ci(scores_sys)
263 | 
264 |         # Compute the statistics
265 |         sample_diffs = np.abs(scores_sys - scores_bl)
266 |         stats = sample_diffs - sample_diffs.mean()
267 | 
268 |         # Count the statistical difference and compute the p-value
269 |         p = _compute_p_value(stats, diff)
270 | 
271 |         results[name] = Result(sys_score.score, p, sys_mean, sys_ci)
272 | 
273 |     return sys_name, results
274 | 
275 | 
276 | class PairedTest:
277 |     """This is the manager class that will call the actual standalone implementation
278 |     for approximate randomization or paired bootstrap resampling, based on the
279 |     `test_type` argument.
280 | 
281 |     :param named_systems: A list of (system_name, system_hypotheses) tuples on
282 |         which the test will be applied.
283 |     :param metrics: A dictionary of `Metric` instances that will be computed
284 |         for each system.
285 |     :param references: A sequence of reference documents with document being
286 |         defined as a sequence of reference strings. If `None`, already cached references
287 |         will be used through each metric's internal cache.
288 | :param test_type: `ar` for approximate randomization, `bs` for paired bootstrap. 289 | :param n_samples: The number of AR trials (for `ar`) or bootstrap resamples (for `bs`). 290 | The defaults (10000 or 1000 respectively) will be used if 0 is passed. 291 | :param n_ar_confidence: If `approximate randomization` is selected, the number 292 | of bootstrap resamples to use for confidence estimation. A value of -1 disables 293 | confidence estimation. 0 will use the default of 1000. 294 | :param n_jobs: If 0, a worker process will be spawned for each system variant. 295 | If > 0, the number of workers will be set accordingly. The default of 1 296 | does not use multi-processing. 297 | """ 298 | _DEFAULT_SAMPLES = { 299 | 'ar': 10000, 300 | 'bs': 1000, 301 | } 302 | 303 | def __init__(self, named_systems: List[Tuple[str, Sequence[str]]], 304 | metrics: Mapping[str, Metric], 305 | references: Optional[Sequence[Sequence[str]]], 306 | test_type: str = 'ar', 307 | n_samples: int = 0, 308 | n_ar_confidence: int = -1, 309 | n_jobs: int = 1): 310 | assert test_type in ('ar', 'bs'), f"Unknown test type {test_type!r}" 311 | self.test_type = test_type 312 | 313 | # Set method 314 | if self.test_type == 'ar': 315 | self._fn = _paired_ar_test 316 | elif self.test_type == 'bs': 317 | self._fn = _paired_bs_test 318 | 319 | # Set numpy RNG's seed 320 | # If given -> Fix to the given value 321 | # If given but =='[Nn]one', don't fix the seed i.e. pull entropy from OS 322 | seed = os.environ.get('SACREBLEU_SEED', '12345') 323 | self._seed = None if seed.lower() == 'none' else int(seed) 324 | self.n_jobs = n_jobs 325 | self.references = references 326 | self.named_systems = named_systems 327 | 328 | # Set the defaults if requested 329 | self.n_ar_confidence = n_ar_confidence if n_ar_confidence != 0 else \ 330 | self._DEFAULT_SAMPLES['bs'] 331 | 332 | self.n_samples = n_samples if n_samples > 0 else \ 333 | self._DEFAULT_SAMPLES[self.test_type] 334 | 335 | # Number of systems (excluding the baseline) 336 | self.n_systems = len(named_systems) - 1 337 | 338 | # Decide on number of workers 339 | if IS_WINDOWS: 340 | sacrelogger.warning('Parallel tests are not supported on Windows.') 341 | self.n_jobs = 1 342 | elif self.n_jobs == 0: 343 | # Decide automatically 344 | # Divide by two to ignore hyper-threading 345 | n_max_jobs = mp.cpu_count() // 2 346 | if n_max_jobs == 0: 347 | self.n_jobs = 1 348 | else: 349 | # Don't use more workers than the number of CPUs 350 | self.n_jobs = min(n_max_jobs, self.n_systems) 351 | 352 | self._signatures: Dict[str, Signature] = {} 353 | self._baseline_info: Dict[str, Tuple[Any, Result]] = {} 354 | 355 | ################################################## 356 | # Pre-compute and cache baseline system statistics 357 | ################################################## 358 | self.metrics = {} 359 | 360 | bl_name, bl_hyps = self.named_systems[0] 361 | 362 | for name, metric in metrics.items(): 363 | sacrelogger.info(f'Pre-computing {name} statistics for {bl_name!r}') 364 | bl_stats = metric._extract_corpus_statistics(bl_hyps, self.references) 365 | bl_score = metric._aggregate_and_compute(bl_stats) 366 | 367 | # Compute CI for the baseline here once 368 | confidence_n = self.n_samples if self.test_type == 'bs' \ 369 | else self.n_ar_confidence 370 | 371 | bl_mean, bl_ci = None, None 372 | if confidence_n > 0: 373 | _, bl_scores = _bootstrap_resample(bl_stats, metric, confidence_n) 374 | bl_mean, bl_ci = estimate_ci(np.array([x.score for x in bl_scores])) 375 | 376 | result = 
Result(bl_score.score, mean=bl_mean, ci=bl_ci) 377 | # Use updated name for the metric 378 | self._baseline_info[bl_score.name] = (bl_stats, result) 379 | self.metrics[bl_score.name] = metric 380 | 381 | # Update metric signature as well 382 | sig = metric.get_signature() 383 | sig.update('seed', str(self._seed).lower()) 384 | 385 | # Num samples for bs, num trials for AR 386 | sig.update(self.test_type, self.n_samples) 387 | if self.n_ar_confidence > 0: 388 | # Bootstrap is used for AR CI as well 389 | sig.update('bs', self.n_ar_confidence) 390 | self._signatures[bl_score.name] = sig 391 | 392 | def __call__(self) -> Tuple[Dict[str, Signature], Dict[str, List[Union[str, Result]]]]: 393 | """Runs the paired test either on single or multiple worker processes.""" 394 | tasks = [] 395 | scores: Dict[str, List[Union[str, Result]]] = {} 396 | 397 | # Add the name column 398 | scores['System'] = [ns[0] for ns in self.named_systems] 399 | 400 | # Store baseline results as the first position 401 | for metric, (_, result) in self._baseline_info.items(): 402 | scores[metric] = [result] 403 | 404 | # Prepare list of arguments for each comparison 405 | # Skip the baseline (pos: 0) 406 | for idx, (name, hyps) in enumerate(self.named_systems[1:]): 407 | seed = self._seed if self._seed else None 408 | 409 | tasks.append( 410 | (self._baseline_info, name, hyps, self.references, 411 | self.metrics, self.n_samples, self.n_ar_confidence, seed)) 412 | 413 | # Run the test(s) 414 | if self.n_jobs == 1: 415 | results = [self._fn(*args) for args in tasks] 416 | else: 417 | # NOTE: The overhead of worker creation is not negligible 418 | # but if you have many systems and TER enabled, this significantly 419 | # speeds up the test. 420 | # NOTE: This only works on Linux/Mac OS X but not Windows. Windows only 421 | # supports `spawn` backend which requires things to be called 422 | # from within __main__. 423 | sacrelogger.info(f'Launching {self.n_jobs} parallel workers.') 424 | with mp.get_context('fork').Pool(self.n_jobs) as pool: 425 | jobs = [pool.apply_async(self._fn, args) for args in tasks] 426 | 427 | # wait for completion 428 | results = [j.get() for j in jobs] 429 | 430 | # Keep the order deterministic 431 | for sys_name, sys_results in results: 432 | for metric, _result in sys_results.items(): 433 | scores[metric].append(_result) 434 | 435 | return self._signatures, scores 436 | --------------------------------------------------------------------------------
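A minimal usage sketch tying the pieces above together: `BLEU.corpus_score()` with `n_bootstrap > 1` attaches a 95% confidence interval to the returned score, and `PairedTest` compares additional systems against a baseline with either the paired bootstrap ('bs') or approximate randomization ('ar') test. The toy sentences, system names and variable names below are invented for illustration; the calls follow the signatures shown in `metrics/base.py`, `metrics/bleu.py` and `significance.py` above.

from sacrebleu.metrics import BLEU
from sacrebleu.significance import PairedTest

# Illustrative toy data: one hypothesis per source sentence;
# references[i] is the i-th reference *stream* (here, a single stream
# with one reference per hypothesis).
hyps_baseline = ['the cat sat on the mat .', 'it is raining today']
hyps_system = ['the cat is on the mat .', 'it rains today']
references = [['The cat sat on the mat.', 'It is raining today.']]

# Corpus-level BLEU; n_bootstrap > 1 adds a bootstrap confidence interval.
bleu = BLEU()
score = bleu.corpus_score(hyps_baseline, references, n_bootstrap=1000)
print(score, bleu.get_signature())

# Paired bootstrap significance test of 'system' against the baseline.
# Passing references=None would instead use each metric's cached references.
# The RNG seed is read from the SACREBLEU_SEED environment variable
# (default '12345'; set it to 'none' to draw entropy from the OS).
pt = PairedTest(
    named_systems=[('baseline', hyps_baseline), ('system', hyps_system)],
    metrics={'BLEU': BLEU()},
    references=references,
    test_type='bs',   # 'ar' selects approximate randomization instead
    n_samples=0,      # 0 falls back to the default (1000 resamples for 'bs')
)
signatures, scores = pt()  # scores['BLEU'] holds one Result per system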