├── .circleci └── config.yml ├── .editorconfig ├── .github └── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── .gitignore ├── CHANGELOG.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── config ├── bert.yaml ├── data │ ├── wmt19.qe.en_de.yaml │ ├── wmt20.qe.en_de.yaml │ └── wmt20.qe.en_zh.yaml ├── evaluate.yaml ├── nuqe.yaml ├── predict.yaml ├── predictor.yaml ├── predictor_estimator.yaml ├── search.yaml ├── xlm.yaml └── xlmroberta.yaml ├── docs ├── _static │ └── img │ │ ├── openkiwi-logo-horizontal-dark.svg │ │ ├── openkiwi-logo-horizontal.svg │ │ ├── openkiwi-logo-icon-dark.svg │ │ ├── openkiwi-logo-icon.ico │ │ ├── openkiwi-logo-icon.svg │ │ ├── openkiwi-logo-vertical-dark.svg │ │ └── openkiwi-logo-vertical.svg ├── conf.py ├── configuration │ ├── evaluate.rst │ ├── index.rst │ ├── predict.rst │ ├── search.rst │ └── train.rst ├── index.rst ├── installation.rst ├── systems │ └── index.rst ├── test.json └── usage.rst ├── kiwi ├── __init__.py ├── __main__.py ├── assets │ ├── __init__.py │ └── config │ │ ├── __init__.py │ │ ├── evaluate.yaml │ │ ├── predict.yaml │ │ ├── pretrain.yaml │ │ ├── search.yaml │ │ └── train.yaml ├── cli.py ├── constants.py ├── data │ ├── __init__.py │ ├── batch.py │ ├── datasets │ │ ├── __init__.py │ │ ├── parallel_dataset.py │ │ └── wmt_qe_dataset.py │ ├── encoders │ │ ├── __init__.py │ │ ├── base.py │ │ ├── field_encoders.py │ │ ├── parallel_data_encoder.py │ │ └── wmt_qe_data_encoder.py │ ├── tokenizers.py │ └── vocabulary.py ├── lib │ ├── __init__.py │ ├── evaluate.py │ ├── predict.py │ ├── pretrain.py │ ├── search.py │ ├── train.py │ └── utils.py ├── loggers.py ├── metrics │ ├── __init__.py │ ├── functions.py │ └── metrics.py ├── modules │ ├── __init__.py │ ├── common │ │ ├── __init__.py │ │ ├── activations.py │ │ ├── attention.py │ │ ├── distributions.py │ │ ├── feedforward.py │ │ ├── layer_norm.py │ │ ├── positional_encoding.py │ │ ├── scalar_mix.py │ │ └── scorer.py │ ├── sentence_level_output.py │ ├── 
token_embeddings.py │ └── word_level_output.py ├── runner.py ├── systems │ ├── __init__.py │ ├── _meta_module.py │ ├── bert.py │ ├── decoders │ │ ├── __init__.py │ │ ├── estimator.py │ │ ├── linear.py │ │ └── nuqe.py │ ├── encoders │ │ ├── __init__.py │ │ ├── bert.py │ │ ├── predictor.py │ │ ├── quetch.py │ │ ├── xlm.py │ │ └── xlmroberta.py │ ├── nuqe.py │ ├── outputs │ │ ├── __init__.py │ │ ├── quality_estimation.py │ │ └── translation_language_model.py │ ├── predictor.py │ ├── predictor_estimator.py │ ├── qe_system.py │ ├── tlm_system.py │ ├── xlm.py │ └── xlmroberta.py ├── training │ ├── __init__.py │ ├── callbacks.py │ └── optimizers.py └── utils │ ├── __init__.py │ ├── data_structures.py │ ├── io.py │ └── tensors.py ├── poetry.lock ├── pyproject.toml ├── scripts └── merge_target_and_gaps_preds.py ├── tests ├── conftest.py ├── mocks │ ├── mock_vocab.py │ └── simple_model.py ├── test_bert.py ├── test_cli.py ├── test_data.py ├── test_metrics.py ├── test_nuqe.py ├── test_predict_and_evaluate.py ├── test_predictor.py ├── test_search.py ├── test_utils.py ├── test_xlm.py ├── test_xlmr.py └── toy-data │ ├── WMT18 │ └── word_level │ │ └── en_de.nmt │ │ ├── dev.hter │ │ ├── dev.mt │ │ ├── dev.pe │ │ ├── dev.ref │ │ ├── dev.src │ │ ├── dev.src-mt.alignments │ │ ├── dev.src_tags │ │ ├── dev.tags │ │ ├── train.hter │ │ ├── train.mt │ │ ├── train.pe │ │ ├── train.ref │ │ ├── train.src │ │ ├── train.src-mt.alignments │ │ ├── train.src_tags │ │ └── train.tags │ └── models │ ├── bert │ └── vocab.txt │ └── nuqe.ckpt └── tox.ini /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 4 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | end_of_line = lf 12 | max_line_length = 88 13 | 14 | [*.json] 15 | indent_size = 2 16 | 17 | [*.bat] 18 | indent_style = tab 19 | end_of_line = crlf 20 | 21 | 
[LICENSE] 22 | insert_final_newline = false 23 | 24 | [Makefile] 25 | indent_style = tab 26 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Use data '...' 17 | 3. Run '....' with arguments '...' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Environment (please complete the following information):** 27 | - OS: [e.g. Linux] 28 | - OpenKiwi version [e.g. 0.1.0] 29 | - Python version 30 | 31 | **Additional context** 32 | Add any other context about the problem here. 33 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 
18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Kiwi related stuff 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | /lib/ 20 | /lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 
106 | .mypy_cache/ 107 | 108 | # vscode 109 | .vscode 110 | 111 | # Emacs Autosave 112 | *~ 113 | \#*\# 114 | .\#* 115 | 116 | # Data Folder 117 | /data/ 118 | 119 | # Logs Folder 120 | log/ 121 | 122 | # torchtext vectors cache 123 | .vector_cache 124 | 125 | # MLFlow directory 126 | mlruns/ 127 | 128 | # local run directory 129 | runs/ 130 | 131 | # Pycharm project 132 | .idea/ 133 | 134 | # Mac ignore 135 | .DS_Store 136 | 137 | # Folder used for saved predictions 138 | /predictions/ 139 | 140 | # Local shell scripts 141 | run_pipeline*.sh 142 | 143 | # Runners Directory Containing slurm output 144 | runners/ 145 | 146 | #vim files 147 | *.swp 148 | 149 | # Generated docs 150 | /docs/html/ 151 | /docs/.doctrees 152 | 153 | # Wheel Metadata 154 | pip-wheel-metadata/ 155 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | 2 | # Changelog 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 7 | 8 | ## [2.1.0](https://github.com/Unbabel/OpenKiwi/compare/2.0.0...2.1.0) - 2020-11-12 9 | 10 | ### Added 11 | - Hyperparameter search pipeline `kiwi search` built on [Optuna](https://optuna.readthedocs.io/) 12 | - Docs for the search pipeline 13 | - The `--example` flag that has an example config from `kiwi/assets/config/` printed to terminal for each pipeline 14 | - Tests to increase coverage 15 | - Readme link to the new [OpenKiwiTasting](https://github.com/Unbabel/OpenKiwiTasting) demo.
16 | 17 | ### Changed 18 | - Example configs in `conf/` so that they are clean, consistent, and have good defaults 19 | - Moved function `feedforward` from `kiwi.tensors` to `kiwi.modules.common.feedforward` where it makes more sense 20 | 21 | ### Fixed 22 | - The broken relative links in the docs 23 | - Evaluation pipeline by adding missing `quiet` and `verbose` in the evaluate configuration 24 | 25 | ### Deprecated 26 | - Migration of models from a previous OpenKiwi version, by removing the (never fully working) code in `kiwi.utils.migrations` entirely 27 | 28 | ### Removed 29 | - Unused code in `kiwi.training.optimizers`, `kiwi.modules.common.scorer`, `kiwi.modules.common.layer_norm`, `kiwi.modules.sentence_level_output`, `kiwi.metrics.metrics`, `kiwi.modules.common.attention`, `kiwi.modules.token_embeddings` 30 | - _All_ code that was already commented out 31 | - The `systems.encoder.(predictor|bert|xlm|xlmroberta).encode_source` option that is both _confusing_ as well as _never used_ 32 | 33 | ## [2.0.0](https://github.com/Unbabel/OpenKiwi/compare/0.1.3...2.0.0) 34 | 35 | ### Added 36 | - XLMR, XLM, BERT encoder models 37 | - New pooling methods for xlmr-encoder [mixed, mean, ll_mean] 38 | - `freeze_for_number_of_steps` allows freezing of xlmr-encoder for a specific number of training steps 39 | - `encoder_learning_rate` allows setting a specific learning rate to be used on the encoder (different from the rest of the system) 40 | - Dataloaders now use a RandomBucketSampler which groups sentences of the same size together to minimize padding 41 | - fp16 support 42 | - Support for HuggingFace's transformers models 43 | - Pytorch-Lightning as a training framework 44 | - This changelog 45 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution guide 2 | 3 | ## Overview 4 | 5 | OpenKiwi is an Open Source Quality Estimation
toolkit aimed at implementing state of the art models in an efficient and unified fashion. While we do welcome contributions, in order to guarantee their quality and usefulness, it is necessary that we follow basic guidelines in order to ease development, collaboration and readability. 6 | 7 | ## Basic guidelines 8 | 9 | * The project must fully support Python 3.5 or later. 10 | * Code is linted with [flake8](http://flake8.pycqa.org/en/latest/user/error-codes.html), please run `flake8 kiwi` and fix remaining errors before pushing any code. 11 | * Code formatting must stick to the Facebook style, 80 columns and single quotes. For Python 3.6+, the [black](https://github.com/ambv/black) formatter can be used by running `black kiwi`. For python 3.5, [YAPF](https://github.com/google/yapf) should get most of the job done, although some manual changes might be necessary. 12 | * Imports are sorted with [isort](https://github.com/timothycrosley/isort). 13 | * Filenames must be in lowercase. 14 | * Tests are run with [pytest](https://docs.pytest.org/en/latest/) which is commonly referred to as the best unittesting framework out there. Pytest implements a standard test discovery which means that it will only search for `test_*.py` or `*_test.py` files. We do not enforce a minimum code coverage but it is preferable to have even very basic tests running for critical pieces of code. Always test functions that take/return tensor arguments to document the sizes. 15 | * The `kiwi` folder contains core features. Any script calling these features must be placed into the `scripts` folder. 16 | 17 | ## Contributing 18 | 19 | * Keep track of everything by creating issues and editing them with reference to the code! Explain succinctly the problem you are trying to solve and your solution. 20 | * Contributions to `master` should be made through GitHub pull requests. 21 | * Dependencies are managed using `Poetry`.
Although we would rather err on the side of less rather than more dependencies, if needed they are managed through the `pyproject.toml` file. 22 | * Work in a clean environment (`virtualenv` is nice). 23 | * Your commit message must start with an infinitive verb (Add, Fix, Remove, ...). 24 | * If your change is based on a paper, please include a clear comment and reference in the code and in the related issue. 25 | * In order to test your local changes, install OpenKiwi following the instructions on the [documentation](https://unbabel.github.io/openkiwi) 26 | -------------------------------------------------------------------------------- /config/bert.yaml: -------------------------------------------------------------------------------- 1 | run: 2 | ################################################### 3 | # Generic configurations options related to 4 | # handling of experiments 5 | experiment_name: BERT WMT20 EN-DE 6 | seed: 42 7 | use_mlflow: false 8 | 9 | trainer: 10 | ################################################### 11 | # Generic options related to the training process 12 | # that apply to all models 13 | deterministic: true 14 | gpus: -1 15 | epochs: 10 16 | 17 | main_metric: 18 | - WMT19_MCC 19 | - PEARSON 20 | 21 | gradient_max_norm: 1. 
22 | gradient_accumulation_steps: 1 23 | 24 | # Control the model precision, see 25 | # https://pytorch-lightning.readthedocs.io/en/stable/amp.html 26 | # for more info on the configuration options 27 | # Fast mixed precision by default: 28 | amp_level: O2 29 | precision: 16 30 | 31 | log_interval: 100 32 | checkpoint: 33 | validation_steps: 0.2 34 | early_stop_patience: 10 35 | 36 | 37 | defaults: 38 | ################################################### 39 | # Example of composition of configuration files 40 | # this config is sourced from /config/data/wmt20.qe.en_de.yaml 41 | - data: wmt20.qe.en_de 42 | 43 | system: 44 | ################################################### 45 | # System configs are responsible for all the system 46 | # specific configurations. From model settings to 47 | # optimizers and specific processing options. 48 | 49 | # All configs must have either `class_name` or `load` 50 | class_name: Bert 51 | 52 | batch_size: 8 53 | num_data_workers: 4 54 | 55 | model: 56 | ################################################ 57 | # Modeling options. These can change a lot about 58 | # the architecture of the system. With many configuration 59 | # options adding (or removing) layers. 60 | encoder: 61 | model_name: bert-base-multilingual-cased 62 | use_mlp: false 63 | freeze: false 64 | 65 | decoder: 66 | hidden_size: 768 67 | bottleneck_size: 768 68 | dropout: 0.1 69 | 70 | outputs: 71 | #################################################### 72 | # Output options configure the downstream tasks the 73 | # model will be trained on by adding specific layers 74 | # responsible for transforming decoder features into 75 | # predictions. 
76 | word_level: 77 | target: true 78 | gaps: true 79 | source: true 80 | class_weights: 81 | target_tags: 82 | BAD: 3.0 83 | gap_tags: 84 | BAD: 5.0 85 | source_tags: 86 | BAD: 3.0 87 | sentence_level: 88 | hter: true 89 | use_distribution: false 90 | binary: false 91 | n_layers_output: 2 92 | sentence_loss_weight: 1 93 | 94 | tlm_outputs: 95 | fine_tune: false 96 | 97 | optimizer: 98 | class_name: adamw 99 | learning_rate: 1e-5 100 | warmup_steps: 0.1 101 | training_steps: 12000 102 | 103 | data_processing: 104 | share_input_fields_encoders: true 105 | -------------------------------------------------------------------------------- /config/data/wmt19.qe.en_de.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train: 3 | input: 4 | source: data/WMT19/en_de.nmt/train.src 5 | target: data/WMT19/en_de.nmt/train.mt 6 | alignments: data/WMT19/en_de.nmt/train.src-mt.alignments 7 | post_edit: data/WMT19/en_de.nmt/train.pe 8 | 9 | output: 10 | source_tags: data/WMT19/en_de.nmt/train.source_tags 11 | target_tags: data/WMT19/en_de.nmt/train.tags 12 | sentence_scores: data/WMT19/en_de.nmt/train.hter 13 | 14 | valid: 15 | input: 16 | source: data/WMT19/en_de.nmt/dev.src 17 | target: data/WMT19/en_de.nmt/dev.mt 18 | alignments: data/WMT19/en_de.nmt/dev.src-mt.alignments 19 | post_edit: data/WMT19/en_de.nmt/dev.pe 20 | 21 | output: 22 | source_tags: data/WMT19/en_de.nmt/dev.source_tags 23 | target_tags: data/WMT19/en_de.nmt/dev.tags 24 | sentence_scores: data/WMT19/en_de.nmt/dev.hter 25 | 26 | test: 27 | input: 28 | source: data/WMT19/en_de.nmt/test.src 29 | target: data/WMT19/en_de.nmt/test.mt 30 | alignments: data/WMT19/en_de.nmt/test.src-mt.alignments 31 | post_edit: data/WMT19/en_de.nmt/test.pe 32 | -------------------------------------------------------------------------------- /config/data/wmt20.qe.en_de.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train: 3 | input: 4 
| source: data/WMT20/en-de/train/train.src 5 | target: data/WMT20/en-de/train/train.mt 6 | alignments: data/WMT20/en-de/train/train.src-mt.alignments 7 | post_edit: data/WMT20/en-de/train/train.pe 8 | 9 | output: 10 | source_tags: data/WMT20/en-de/train/train.source_tags 11 | target_tags: data/WMT20/en-de/train/train.tags 12 | sentence_scores: data/WMT20/en-de/train/train.hter 13 | 14 | valid: 15 | input: 16 | source: data/WMT20/en-de/dev/dev.src 17 | target: data/WMT20/en-de/dev/dev.mt 18 | alignments: data/WMT20/en-de/dev/dev.src-mt.alignments 19 | post_edit: data/WMT20/en-de/dev/dev.pe 20 | 21 | output: 22 | source_tags: data/WMT20/en-de/dev/dev.source_tags 23 | target_tags: data/WMT20/en-de/dev/dev.tags 24 | sentence_scores: data/WMT20/en-de/dev/dev.hter 25 | 26 | test: 27 | input: 28 | source: data/WMT20/en-de/test-blind/test.src 29 | target: data/WMT20/en-de/test-blind/test.mt 30 | alignments: data/WMT20/en-de/test-blind/test.src-mt.alignments 31 | post_edit: data/WMT20/en-de/test-blind/test-blind.pe 32 | -------------------------------------------------------------------------------- /config/data/wmt20.qe.en_zh.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train: 3 | input: 4 | source: data/WMT20/en-zh/train/train.src 5 | target: data/WMT20/en-zh/train/train.mt 6 | alignments: data/WMT20/en-zh/train/train.src-mt.alignments 7 | post_edit: data/WMT20/en-zh/train/train.pe 8 | 9 | output: 10 | source_tags: data/WMT20/en-zh/train/train.source_tags 11 | target_tags: data/WMT20/en-zh/train/train.tags 12 | sentence_scores: data/WMT20/en-zh/train/train.hter 13 | 14 | valid: 15 | input: 16 | source: data/WMT20/en-zh/dev/dev.src 17 | target: data/WMT20/en-zh/dev/dev.mt 18 | alignments: data/WMT20/en-zh/dev/dev.src-mt.alignments 19 | post_edit: data/WMT20/en-zh/dev/dev.pe 20 | 21 | output: 22 | source_tags: data/WMT20/en-zh/dev/dev.source_tags 23 | target_tags: data/WMT20/en-zh/dev/dev.tags 24 | sentence_scores: 
data/WMT20/en-zh/dev/dev.hter 25 | 26 | test: 27 | input: 28 | source: data/WMT20/en-zh/test-blind/test.src 29 | target: data/WMT20/en-zh/test-blind/test.mt 30 | alignments: data/WMT20/en-zh/test-blind/test.src-mt.alignments 31 | post_edit: data/WMT20/en-zh/test-blind/test-blind.pe 32 | -------------------------------------------------------------------------------- /config/evaluate.yaml: -------------------------------------------------------------------------------- 1 | gold_files: 2 | source_tags: data/WMT20/sentence_level/en_de.nmt/dev.source_tags 3 | target_tags: data/WMT20/sentence_level/en_de.nmt/dev.tags 4 | sentence_scores: data/WMT20/sentence_level/en_de.nmt/dev.hter 5 | 6 | # Two configuration options: 7 | # 1. (Recommended) Pass the root folders where the predictions live, 8 | # with the standard file names 9 | predicted_dir: 10 | # The evaluation pipeline supports evaluating multiple predictions at the same time 11 | # by passing the folders as a list 12 | - runs/0/4aa891368ff4402fa69a4b081ea2ba62 13 | - runs/0/e9200ada6dc84bfea807b3b02b9c7212 14 | 15 | # 2. 
Configure each predicted file separately 16 | # predicted_files: 17 | # source_tags: 18 | # - runs/0/4aa891368ff4402fa69a4b081ea2ba62/source_tags 19 | # - runs/0/e9200ada6dc84bfea807b3b02b9c7212/source_tags 20 | # # (Recommended) Pass the predicted `targetgaps_tags` file as `target_tags`; 21 | # # the target and gap tags will be separated and evaluated separately as well as jointly 22 | # target_tags: 23 | # - runs/0/4aa891368ff4402fa69a4b081ea2ba62/targetgaps_tags 24 | # - runs/0/e9200ada6dc84bfea807b3b02b9c7212/targetgaps_tags 25 | # # Alternatively: 26 | # # target_tags: 27 | # # - runs/0/4aa891368ff4402fa69a4b081ea2ba62/target_tags 28 | # # - runs/0/e9200ada6dc84bfea807b3b02b9c7212/target_tags 29 | # # gap_tags: 30 | # # - runs/0/4aa891368ff4402fa69a4b081ea2ba62/gap_tags 31 | # # - runs/0/e9200ada6dc84bfea807b3b02b9c7212/gap_tags 32 | # sentence_scores: 33 | # - runs/0/4aa891368ff4402fa69a4b081ea2ba62/sentence_scores 34 | # - runs/0/e9200ada6dc84bfea807b3b02b9c7212/sentence_scores 35 | -------------------------------------------------------------------------------- /config/nuqe.yaml: -------------------------------------------------------------------------------- 1 | run: 2 | experiment_name: NuQE WMT20 EN-DE 3 | seed: 42 4 | use_mlflow: false 5 | 6 | trainer: 7 | gpus: -1 8 | epochs: 10 9 | 10 | main_metric: 11 | - WMT19_MCC 12 | - PEARSON 13 | 14 | log_interval: 100 15 | checkpoint: 16 | validation_steps: 0.2 17 | early_stop_patience: 10 18 | 19 | defaults: 20 | - data: wmt20.qe.en_de 21 | 22 | system: 23 | class_name: NuQE 24 | 25 | batch_size: 64 26 | num_data_workers: 4 27 | 28 | model: 29 | encoder: 30 | window_size: 3 31 | embeddings: 32 | source: 33 | dim: 50 34 | dropout: 0.5 35 | freeze: false 36 | target: 37 | dim: 50 38 | dropout: 0.5 39 | freeze: false 40 | 41 | decoder: 42 | source: 43 | hidden_sizes: [400, 200, 100, 50] 44 | dropout: 0. 45 | target: 46 | hidden_sizes: [400, 200, 100, 50] 47 | dropout: 0. 
48 | 49 | outputs: 50 | word_level: 51 | target: true 52 | gaps: true 53 | source: false 54 | class_weights: 55 | target_tags: 56 | BAD: 3.0 57 | gap_tags: 58 | BAD: 5.0 59 | source_tags: 60 | BAD: 3.0 61 | sentence_level: 62 | hter: true 63 | use_distribution: true 64 | binary: false 65 | 66 | data_processing: 67 | share_input_fields_encoders: false 68 | vocab: 69 | min_frequency: 2 70 | 71 | optimizer: 72 | class_name: adam 73 | learning_rate: 0.001 74 | -------------------------------------------------------------------------------- /config/predict.yaml: -------------------------------------------------------------------------------- 1 | use_gpu: true 2 | 3 | run: 4 | output_dir: predictions 5 | predict_on_data_partition: test 6 | 7 | defaults: 8 | - data: wmt20.qe.en_de 9 | 10 | system: 11 | load: best_model.torch 12 | 13 | batch_size: 16 14 | num_data_workers: 0 15 | 16 | model: 17 | outputs: 18 | word_level: 19 | target: true 20 | gaps: true 21 | source: true 22 | sentence_level: 23 | hter: true 24 | -------------------------------------------------------------------------------- /config/predictor.yaml: -------------------------------------------------------------------------------- 1 | run: 2 | experiment_name: Predictor WMT20 EN-DE 3 | seed: 42 4 | use_mlflow: false 5 | 6 | trainer: 7 | deterministic: true 8 | gpus: -1 9 | epochs: 10 10 | 11 | log_interval: 100 12 | checkpoint: 13 | validation_steps: 0.2 14 | early_stop_patience: 10 15 | 16 | defaults: 17 | - data: wmt20.qe.en_de 18 | 19 | system: 20 | class_name: Predictor 21 | 22 | num_data_workers: 4 23 | batch_size: 24 | train: 32 25 | valid: 32 26 | 27 | model: 28 | encoder: 29 | encode_source: false 30 | hidden_size: 400 31 | rnn_layers: 2 32 | embeddings: 33 | source: 34 | dim: 200 35 | target: 36 | dim: 200 37 | out_embeddings_dim: 200 38 | share_embeddings: false 39 | dropout: 0.5 40 | use_mismatch_features: false 41 | 42 | optimizer: 43 | class_name: adam 44 | learning_rate: 0.001 45 | 
learning_rate_decay: 0.6 46 | learning_rate_decay_start: 2 47 | 48 | data_processing: 49 | vocab: 50 | min_frequency: 1 51 | max_size: 60_000 52 | -------------------------------------------------------------------------------- /config/predictor_estimator.yaml: -------------------------------------------------------------------------------- 1 | run: 2 | experiment_name: PredictorEstimator WMT20 EN-DE 3 | seed: 42 4 | use_mlflow: false 5 | 6 | trainer: 7 | gpus: -1 8 | epochs: 10 9 | 10 | main_metric: 11 | - WMT19_MCC 12 | - PEARSON 13 | 14 | log_interval: 100 15 | checkpoint: 16 | validation_steps: 0.2 17 | early_stop_patience: 10 18 | 19 | defaults: 20 | - data: wmt20.qe.en_de 21 | 22 | system: 23 | class_name: PredictorEstimator 24 | 25 | batch_size: 32 26 | num_data_workers: 4 27 | 28 | load_encoder: best_model.torch 29 | 30 | model: 31 | encoder: 32 | encode_source: false 33 | hidden_size: 400 34 | rnn_layers: 2 35 | embeddings: 36 | source: 37 | dim: 200 38 | target: 39 | dim: 200 40 | out_embeddings_dim: 200 41 | share_embeddings: false 42 | dropout: 0.5 43 | use_mismatch_features: false 44 | 45 | decoder: 46 | hidden_size: 125 47 | rnn_layers: 1 48 | use_mlp: true 49 | dropout: 0.0 50 | 51 | outputs: 52 | word_level: 53 | target: true 54 | gaps: false 55 | source: false 56 | class_weights: 57 | target_tags: 58 | BAD: 5.0 59 | gap_tags: 60 | BAD: 5.0 61 | source_tags: 62 | BAD: 3.0 63 | sentence_level: 64 | hter: true 65 | use_distribution: true 66 | binary: false 67 | 68 | tlm_outputs: 69 | fine_tune: true 70 | 71 | optimizer: 72 | class_name: adam 73 | learning_rate: 0.001 74 | learning_rate_decay: 0.6 75 | learning_rate_decay_start: 2 76 | 77 | data_processing: 78 | vocab: 79 | min_frequency: 1 80 | max_size: 60_000 81 | -------------------------------------------------------------------------------- /config/search.yaml: -------------------------------------------------------------------------------- 1 | base_config: config/bert.yaml 2 | 3 | options: 4 | 
search_method: random 5 | # Search the model architecture 6 | # You can specify a list of values... 7 | hidden_size: 8 | - 768 9 | - 324 10 | # ...or you can specify a discrete range... 11 | bottleneck_size: 12 | lower: 100 13 | upper: 500 14 | step: 100 15 | search_mlp: true 16 | # Search optimizer 17 | # ...or you can specify a continuous interval. 18 | learning_rate: 19 | lower: 1e-7 20 | upper: 1e-5 21 | distribution: loguniform # recommended for the learning rate 22 | # Search weights for the tag loss 23 | class_weights: 24 | target_tags: 25 | lower: 1 26 | upper: 10 27 | step: 1 28 | gap_tags: 29 | lower: 1 30 | upper: 20 31 | step: 1 32 | source_tags: null 33 | # Search the sentence level objective 34 | search_hter: true 35 | sentence_loss_weight: 36 | lower: 1 37 | upper: 10 38 | -------------------------------------------------------------------------------- /config/xlm.yaml: -------------------------------------------------------------------------------- 1 | run: 2 | experiment_name: XLM WMT20 EN-DE 3 | seed: 42 4 | use_mlflow: false 5 | 6 | trainer: 7 | deterministic: true 8 | gpus: -1 9 | epochs: 10 10 | 11 | main_metric: 12 | - WMT19_MCC 13 | - PEARSON 14 | 15 | gradient_max_norm: 1. 
16 | gradient_accumulation_steps: 1 17 | 18 | amp_level: O2 19 | precision: 16 20 | 21 | log_interval: 100 22 | checkpoint: 23 | validation_steps: 0.2 24 | early_stop_patience: 10 25 | 26 | defaults: 27 | - data: wmt20.qe.en_de 28 | 29 | system: 30 | class_name: XLM 31 | 32 | batch_size: 8 33 | num_data_workers: 4 34 | 35 | model: 36 | encoder: 37 | model_name: xlm-mlm-tlm-xnli15-1024 38 | interleave_input: false 39 | freeze: false 40 | use_mlp: false 41 | 42 | decoder: 43 | hidden_size: 768 44 | bottleneck_size: 768 45 | dropout: 0.1 46 | 47 | outputs: 48 | word_level: 49 | target: true 50 | gaps: true 51 | source: true 52 | class_weights: 53 | target_tags: 54 | BAD: 3.0 55 | gap_tags: 56 | BAD: 5.0 57 | source_tags: 58 | BAD: 3.0 59 | sentence_level: 60 | hter: true 61 | use_distribution: false 62 | binary: false 63 | n_layers_output: 2 64 | sentence_loss_weight: 2 65 | 66 | tlm_outputs: 67 | fine_tune: false 68 | 69 | optimizer: 70 | class_name: adamw 71 | learning_rate: 1e-05 72 | warmup_steps: 0.1 73 | training_steps: 12000 74 | 75 | data_processing: 76 | share_input_fields_encoders: true 77 | -------------------------------------------------------------------------------- /config/xlmroberta.yaml: -------------------------------------------------------------------------------- 1 | run: 2 | experiment_name: XLM-Roberta WMT20 EN-DE 3 | seed: 42 4 | use_mlflow: false 5 | 6 | trainer: 7 | deterministic: true 8 | gpus: -1 9 | epochs: 10 10 | 11 | main_metric: 12 | - WMT19_MCC 13 | - PEARSON 14 | 15 | gradient_max_norm: 1. 
16 | gradient_accumulation_steps: 1 17 | 18 | amp_level: O2 19 | precision: 16 20 | 21 | log_interval: 100 22 | checkpoint: 23 | validation_steps: 0.2 24 | early_stop_patience: 10 25 | 26 | defaults: 27 | - data: wmt20.qe.en_de 28 | 29 | system: 30 | class_name: XLMRoberta 31 | 32 | batch_size: 8 33 | num_data_workers: 4 34 | 35 | model: 36 | encoder: 37 | model_name: xlm-roberta-base 38 | interleave_input: false 39 | freeze: false 40 | use_mlp: false 41 | pooling: mixed 42 | freeze_for_number_of_steps: 1000 43 | 44 | decoder: 45 | hidden_size: 768 46 | bottleneck_size: 768 47 | dropout: 0.1 48 | 49 | outputs: 50 | word_level: 51 | target: true 52 | gaps: true 53 | source: true 54 | class_weights: 55 | target_tags: 56 | BAD: 3.0 57 | gap_tags: 58 | BAD: 5.0 59 | source_tags: 60 | BAD: 3.0 61 | sentence_level: 62 | hter: true 63 | use_distribution: false 64 | binary: false 65 | n_layers_output: 2 66 | sentence_loss_weight: 1 67 | 68 | tlm_outputs: 69 | fine_tune: false 70 | 71 | optimizer: 72 | class_name: adamw 73 | learning_rate: 1e-05 74 | warmup_steps: 0.1 75 | training_steps: 12000 76 | 77 | data_processing: 78 | share_input_fields_encoders: true 79 | -------------------------------------------------------------------------------- /docs/_static/img/openkiwi-logo-icon-dark.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /docs/_static/img/openkiwi-logo-icon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unbabel/OpenKiwi/07d7cfed880457d95daf189dd8282e5b02ac2954/docs/_static/img/openkiwi-logo-icon.ico -------------------------------------------------------------------------------- /docs/_static/img/openkiwi-logo-icon.svg: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 50 | 51 | 52 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /docs/configuration/evaluate.rst: -------------------------------------------------------------------------------- 1 | Evaluation configuration 2 | ======================== 3 | 4 | An example configuration file is provided in ``config/evaluate.yaml``: 5 | 6 | .. literalinclude:: ../../config/evaluate.yaml 7 | :language: yaml 8 | 9 | 10 | Configuration class 11 | ------------------- 12 | 13 | Full API reference: :class:`kiwi.lib.evaluate.Configuration` 14 | 15 | .. autoclass:: kiwi.lib.evaluate.Configuration 16 | :noindex: 17 | -------------------------------------------------------------------------------- /docs/configuration/index.rst: -------------------------------------------------------------------------------- 1 | .. _configuration: 2 | 3 | Configuration 4 | ============= 5 | 6 | .. toctree:: 7 | :maxdepth: 5 8 | :hidden: 9 | 10 | train 11 | predict 12 | evaluate 13 | search 14 | 15 | 16 | Kiwi can be configured essentially by using dictionaries. To persist configuration 17 | dictionaries, we recommend using YAML. You can find standard configuration files 18 | in ``config/``. 19 | 20 | 21 | CLI overrides 22 | ------------- 23 | 24 | To run Kiwi with a configuration file but overriding a value, use the format 25 | ``key=value``, where nested keys can be encoded by using a dot notation. 26 | For example:: 27 | 28 | kiwi train config/bert.yaml trainer.gpus=0 system.batch_size=16 29 | 30 | 31 | Configuration Composing 32 | ----------------------- 33 | 34 | Kiwi uses Hydra/OmegaConf to compose configuration coming from different places. 35 | This makes it possible to split configuration across multiple files. 36 | 37 | In most files in ``config/``, like ``config/predict.yaml``, you'll notice this: 38 | 39 | .. 
code-block:: yaml 40 | 41 | defaults: 42 | - data: wmt19.qe.en_de 43 | 44 | This means the file ``config/data/wmt19.qe.en_de.yaml`` will be loaded into the 45 | configuration found in ``config/predict.yaml``. **Notice** that ``wmt19.qe.en_de.yaml`` 46 | must use fully qualified keys levels, that is, the full nesting of keys. 47 | 48 | The nice use case for this is allowing dynamically changing parts of the configuration. 49 | For example, we can use:: 50 | 51 | kiwi train config/bert.yaml data=unbabel.qe.en_pt 52 | 53 | to use a different dataset (where ``config/data/unbabel.qe.en_pt.yaml`` contains the 54 | configuration for the data files). 55 | -------------------------------------------------------------------------------- /docs/configuration/predict.rst: -------------------------------------------------------------------------------- 1 | Prediction configuration 2 | ======================== 3 | 4 | An example configuration file is provided in ``config/predict.yaml``: 5 | 6 | .. literalinclude:: ../../config/predict.yaml 7 | :language: yaml 8 | 9 | 10 | Configuration class 11 | ------------------- 12 | 13 | Full API reference: :class:`kiwi.lib.predict.Configuration` 14 | 15 | .. autoclass:: kiwi.lib.predict.Configuration 16 | :noindex: 17 | -------------------------------------------------------------------------------- /docs/configuration/search.rst: -------------------------------------------------------------------------------- 1 | Search configuration 2 | ====================== 3 | 4 | There is an example search configuration file available in ``config/`` for the ``BERT`` model: 5 | 6 | .. literalinclude:: ../../config/search.yaml 7 | :language: yaml 8 | 9 | 10 | Configuration class 11 | ------------------- 12 | 13 | Full API reference: :class:`kiwi.lib.search.Configuration` 14 | 15 | 16 | .. autosummary: 17 | :toctree: stubs 18 | 19 | .. 
kiwi.lib.search.Configuration 20 | kiwi.lib.search.SearchOptions 21 | kiwi.lib.search.ClassWeightsConfig 22 | kiwi.lib.search.RangeConfig 23 | 24 | 25 | .. autoclass:: kiwi.lib.search.Configuration 26 | :noindex: 27 | -------------------------------------------------------------------------------- /docs/configuration/train.rst: -------------------------------------------------------------------------------- 1 | Training configuration 2 | ====================== 3 | 4 | There are training configuration files available in ``config/`` for all supported models. 5 | 6 | As an example, here is the configuration for the ``Bert`` model: 7 | 8 | .. literalinclude:: ../../config/bert.yaml 9 | :language: yaml 10 | 11 | 12 | Configuration class 13 | ------------------- 14 | 15 | Full API reference: :class:`kiwi.lib.train.Configuration` 16 | 17 | 18 | .. autosummary: 19 | :toctree: stubs 20 | 21 | .. kiwi.lib.train.Configuration 22 | kiwi.lib.train.RunConfig 23 | kiwi.lib.train.TrainerConfig 24 | kiwi.data.datasets.wmt_qe_dataset.WMTQEDataset.Config 25 | kiwi.systems.qe_system.QESystem.Config 26 | 27 | 28 | .. autoclass:: kiwi.lib.train.Configuration 29 | :noindex: 30 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | UnbabelKiwi's documentation 2 | =========================== 3 | 4 | .. image:: _static/img/openkiwi-logo-horizontal.svg 5 | :width: 400 6 | :alt: OpenKiwi by Unbabel 7 | 8 | ---- 9 | 10 | .. mdinclude:: ../README.md 11 | :start-line: 3 12 | 13 | Documentation 14 | ------------- 15 | 16 | .. toctree:: 17 | :maxdepth: 5 18 | 19 | installation 20 | usage 21 | configuration/index 22 | systems/index 23 | autoapi/index 24 | 25 | .. 
code/index 26 | 27 | Indices and tables 28 | ------------------ 29 | 30 | * :ref:`genindex` 31 | * :ref:`modindex` 32 | * :ref:`search` 33 | -------------------------------------------------------------------------------- /docs/installation.rst: -------------------------------------------------------------------------------- 1 | .. _installation: 2 | 3 | Installation 4 | ============ 5 | 6 | OpenKiwi can be installed in two ways, depending on how it's going to be used. 7 | 8 | As a library 9 | ------------ 10 | 11 | Simply run:: 12 | 13 | pip install openkiwi 14 | 15 | You can now:: 16 | 17 | import kiwi 18 | 19 | inside your project or run in the command line:: 20 | 21 | kiwi 22 | 23 | 24 | As a local package 25 | ------------------ 26 | 27 | OpenKiwi's configuration is in ``pyproject.toml`` (as defined by PEP-518). 28 | We use `Poetry `_ as the build system 29 | and the dependency manager. All dependencies are specified in that file. 30 | 31 | Install Poetry via the recommended way:: 32 | 33 | curl -sSL https://raw.githubusercontent.com/sdispater/poetry/master/get-poetry.py | python 34 | 35 | It's also possible to use pip (but not recommended as it might mess up local dependencies):: 36 | 37 | pip install poetry 38 | 39 | In your virtualenv just run:: 40 | 41 | poetry install 42 | 43 | to install all dependencies. 44 | 45 | That's it! Now, running:: 46 | 47 | python kiwi -h 48 | 49 | or:: 50 | 51 | kiwi -h 52 | 53 | should show you a help message. 


MLflow integration
------------------

**Optionally**, to take advantage of our `MLflow <https://mlflow.org/>`_ integration, install Kiwi with::

    pip install openkiwi[mlflow]


**Or**::

    poetry install -E mlflow


Hyperparameter search with Optuna
---------------------------------

**Optionally**, to use the hyperparameter search pipeline with `Optuna <https://optuna.org/>`_,
install Kiwi with::

    pip install openkiwi[search]


**Or**::

    poetry install -E search
--------------------------------------------------------------------------------
/docs/systems/index.rst:
--------------------------------------------------------------------------------
Systems
=======

There are two types of systems: **QE** and **TLM**. The first is the regular one used
for Quality Estimation. The second, which stands for *Translation Language Model*, is
used for pre-training the *encoder* component of a QE system.


All systems, regardless of them being **QE** or **TLM** systems, are constructed on top
of PytorchLightning's (PTL) `LightningModule` class. This means they respect a certain design
philosophy that can be consulted in PTL's documentation_.

.. _documentation: https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html

Furthermore, all of our QE systems share a similar architecture. They are composed of

- Encoder
- Decoder
- Output
- (Optionally) TLM Output

TLM systems on the other hand are only composed of the Encoder + TLM Output and their
goal is to pre-train the encoder so it can be plugged into a QE system.

The systems are divided into the 3 blocks that have the major responsibilities in both
word and sentence level QE tasks:

Encoder: Embedding and creating features to be used for downstream tasks. i.e.
Predictor, 29 | BERT, etc 30 | 31 | Decoder: Responsible for learning feature transformations better suited for the 32 | downstream task. i.e. MLP, LSTM, etc 33 | 34 | Output: Simple feedforwards that take decoder features and transform them into the 35 | prediction required by the downstream task. Something in the same line as the common 36 | "classification heads" being used with transformers. 37 | 38 | TLM Output: A simple output layer that trains for the specific TLM objective. It can be 39 | useful to continue finetuning the predictor during training of the complete QE system. 40 | 41 | 42 | QE --- :mod:`kiwi.systems.qe_system` 43 | ------------------------------------ 44 | 45 | All QE systems inherit from :class:`kiwi.systems.qe_system.QESystem`. 46 | 47 | Use ``kiwi train`` to train these systems. 48 | 49 | Currently available are: 50 | 51 | +--------------------------------------------------------------+ 52 | | :class:`kiwi.systems.nuqe.NuQE` | 53 | +--------------------------------------------------------------+ 54 | | :class:`kiwi.systems.predictor_estimator.PredictorEstimator` | 55 | +--------------------------------------------------------------+ 56 | | :class:`kiwi.systems.bert.Bert` | 57 | +--------------------------------------------------------------+ 58 | | :class:`kiwi.systems.xlm.XLM` | 59 | +--------------------------------------------------------------+ 60 | | :class:`kiwi.systems.xlmroberta.XLMRoberta` | 61 | +--------------------------------------------------------------+ 62 | 63 | 64 | TLM --- :mod:`kiwi.systems.tlm_system` 65 | -------------------------------------- 66 | 67 | All TLM systems inherit from :class:`kiwi.systems.tlm_system.TLMSystem`. 68 | 69 | Use ``kiwi pretrain`` to train these systems. These systems can then be used as the 70 | encoder part in QE systems by using the `load_encoder` flag. 
71 | 72 | Currently available are: 73 | 74 | +-------------------------------------------+ 75 | | :class:`kiwi.systems.predictor.Predictor` | 76 | +-------------------------------------------+ 77 | -------------------------------------------------------------------------------- /kiwi/__init__.py: -------------------------------------------------------------------------------- 1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation 2 | # Copyright (C) 2020 Unbabel 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published 6 | # by the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 16 | # 17 | from kiwi.lib.train import train_from_file # NOQA isort:skip 18 | from kiwi.lib.predict import load_system # NOQA isort:skip 19 | 20 | __version__ = '2.1.0' 21 | __copyright__ = '2019-2020 Unbabel. All rights reserved.' 22 | -------------------------------------------------------------------------------- /kiwi/__main__.py: -------------------------------------------------------------------------------- 1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation 2 | # Copyright (C) 2020 Unbabel 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published 6 | # by the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
import kiwi.cli


def main():
    """Console entry point: delegate to the docopt-based CLI dispatcher."""
    return kiwi.cli.cli()


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/kiwi/assets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Unbabel/OpenKiwi/07d7cfed880457d95daf189dd8282e5b02ac2954/kiwi/assets/__init__.py
--------------------------------------------------------------------------------
/kiwi/assets/config/__init__.py:
--------------------------------------------------------------------------------
from pathlib import Path


def file_path(name):
    """Return the path of a packaged config asset living next to this module.

    Arguments:
        name: file name of the asset (e.g. ``'train.yaml'``).

    Returns:
        a ``pathlib.Path`` pointing at the asset inside the package directory.
    """
    package_directory = Path(__file__).parent
    return package_directory / name
--------------------------------------------------------------------------------
/kiwi/assets/config/evaluate.yaml:
--------------------------------------------------------------------------------
gold_files:
  source_tags: data/WMT20/sentence_level/en_de.nmt/dev.source_tags
  target_tags: data/WMT20/sentence_level/en_de.nmt/dev.tags
  sentence_scores: data/WMT20/sentence_level/en_de.nmt/dev.hter

# Two configuration options:
# 1.
(Recommended) Pass the root folders where the predictions live, 8 | # with the standard file names 9 | predicted_dir: 10 | # The evaluation pipeline supports evaluating multiple predictions at the same time 11 | # by passing the folders as a list 12 | - runs/0/4aa891368ff4402fa69a4b081ea2ba62 13 | - runs/0/e9200ada6dc84bfea807b3b02b9c7212 14 | 15 | # 2. Configure each predicted file separately 16 | # predicted_files: 17 | # source_tags: 18 | # - runs/0/4aa891368ff4402fa69a4b081ea2ba62/source_tags 19 | # - runs/0/e9200ada6dc84bfea807b3b02b9c7212/source_tags 20 | # # (Recommended) Pass the predicted `targetgaps_tags` file as `target_tags`; 21 | # # the target and gap tags will be separated and evaluated separately as well as jointly 22 | # target_tags: 23 | # - runs/0/4aa891368ff4402fa69a4b081ea2ba62/targetgaps_tags 24 | # - runs/0/e9200ada6dc84bfea807b3b02b9c7212/targetgaps_tags 25 | # # Alternatively: 26 | # # target_tags: 27 | # # - runs/0/4aa891368ff4402fa69a4b081ea2ba62/target_tags 28 | # # - runs/0/e9200ada6dc84bfea807b3b02b9c7212/target_tags 29 | # # gap_tags: 30 | # # - runs/0/4aa891368ff4402fa69a4b081ea2ba62/gap_tags 31 | # # - runs/0/e9200ada6dc84bfea807b3b02b9c7212/gap_tags 32 | # sentence_scores: 33 | # - runs/0/4aa891368ff4402fa69a4b081ea2ba62/sentence_scores 34 | # - runs/0/e9200ada6dc84bfea807b3b02b9c7212/sentence_scores 35 | -------------------------------------------------------------------------------- /kiwi/assets/config/predict.yaml: -------------------------------------------------------------------------------- 1 | use_gpu: true 2 | 3 | run: 4 | output_dir: predictions 5 | predict_on_data_partition: test 6 | 7 | defaults: 8 | - data: wmt20.qe.en_de 9 | 10 | system: 11 | load: best_model.torch 12 | 13 | batch_size: 16 14 | num_data_workers: 0 15 | 16 | model: 17 | outputs: 18 | word_level: 19 | target: true 20 | gaps: true 21 | source: true 22 | sentence_level: 23 | hter: true 24 | 
-------------------------------------------------------------------------------- /kiwi/assets/config/pretrain.yaml: -------------------------------------------------------------------------------- 1 | run: 2 | experiment_name: Predictor WMT20 EN-DE 3 | seed: 42 4 | use_mlflow: false 5 | 6 | trainer: 7 | deterministic: true 8 | gpus: -1 9 | epochs: 10 10 | 11 | log_interval: 100 12 | checkpoint: 13 | validation_steps: 0.2 14 | early_stop_patience: 10 15 | 16 | defaults: 17 | - data: wmt20.qe.en_de 18 | 19 | system: 20 | class_name: Predictor 21 | 22 | num_data_workers: 4 23 | batch_size: 24 | train: 32 25 | valid: 32 26 | 27 | model: 28 | encoder: 29 | hidden_size: 400 30 | rnn_layers: 2 31 | embeddings: 32 | source: 33 | dim: 200 34 | target: 35 | dim: 200 36 | out_embeddings_dim: 200 37 | share_embeddings: false 38 | dropout: 0.5 39 | use_mismatch_features: false 40 | 41 | optimizer: 42 | class_name: adam 43 | learning_rate: 0.001 44 | learning_rate_decay: 0.6 45 | learning_rate_decay_start: 2 46 | 47 | data_processing: 48 | vocab: 49 | min_frequency: 1 50 | max_size: 60_000 51 | -------------------------------------------------------------------------------- /kiwi/assets/config/search.yaml: -------------------------------------------------------------------------------- 1 | base_config: config/bert.yaml 2 | 3 | options: 4 | search_method: random 5 | # Search the model architecture 6 | # You can specify a list of values... 7 | hidden_size: 8 | - 768 9 | - 324 10 | # ...or you can specify a discrete range... 11 | bottleneck_size: 12 | lower: 100 13 | upper: 500 14 | step: 100 15 | search_mlp: true 16 | # Search optimizer 17 | # ...or you can specify a continuous interval. 
18 | learning_rate: 19 | lower: 1e-7 20 | upper: 1e-5 21 | distribution: loguniform # recommended for the learning rate 22 | # Search weights for the tag loss 23 | class_weights: 24 | target_tags: 25 | lower: 1 26 | upper: 10 27 | step: 1 28 | gap_tags: 29 | lower: 1 30 | upper: 20 31 | step: 1 32 | source_tags: null 33 | # Search the sentence level objective 34 | search_hter: true 35 | sentence_loss_weight: 36 | lower: 1 37 | upper: 10 38 | -------------------------------------------------------------------------------- /kiwi/assets/config/train.yaml: -------------------------------------------------------------------------------- 1 | run: 2 | ################################################### 3 | # Generic configurations options related to 4 | # handling of experiments 5 | experiment_name: BERT WMT20 EN-DE 6 | seed: 42 7 | use_mlflow: false 8 | 9 | trainer: 10 | ################################################### 11 | # Generic options related to the training process 12 | # that apply to all models 13 | deterministic: true 14 | gpus: -1 15 | epochs: 10 16 | 17 | main_metric: 18 | - WMT19_MCC 19 | - PEARSON 20 | 21 | gradient_max_norm: 1. 22 | gradient_accumulation_steps: 1 23 | 24 | # Control the model precision, see 25 | # https://pytorch-lightning.readthedocs.io/en/stable/amp.html 26 | # for more info on the configuration options 27 | # Fast mixed precision by default: 28 | amp_level: O2 29 | precision: 16 30 | 31 | log_interval: 100 32 | checkpoint: 33 | validation_steps: 0.2 34 | early_stop_patience: 10 35 | 36 | 37 | defaults: 38 | ################################################### 39 | # Example of composition of configuration files 40 | # this config is sourced from /config/data/wmt20.qe.en_de.yaml 41 | - data: wmt20.qe.en_de 42 | 43 | system: 44 | ################################################### 45 | # System configs are responsible for all the system 46 | # specific configurations. 
From model settings to 47 | # optimizers and specific processing options. 48 | 49 | # All configs must have either `class_name` or `load` 50 | class_name: Bert 51 | 52 | batch_size: 8 53 | num_data_workers: 4 54 | 55 | model: 56 | ################################################ 57 | # Modeling options. These can change a lot about 58 | # the architecture of the system. With many configuration 59 | # options adding (or removing) layers. 60 | encoder: 61 | model_name: bert-base-multilingual-cased 62 | use_mlp: false 63 | freeze: false 64 | 65 | decoder: 66 | hidden_size: 768 67 | bottleneck_size: 768 68 | dropout: 0.1 69 | 70 | outputs: 71 | #################################################### 72 | # Output options configure the downstream tasks the 73 | # model will be trained on by adding specific layers 74 | # responsible for transforming decoder features into 75 | # predictions. 76 | word_level: 77 | target: true 78 | gaps: true 79 | source: true 80 | class_weights: 81 | target_tags: 82 | BAD: 3.0 83 | gap_tags: 84 | BAD: 5.0 85 | source_tags: 86 | BAD: 3.0 87 | sentence_level: 88 | hter: true 89 | use_distribution: false 90 | binary: false 91 | n_layers_output: 2 92 | sentence_loss_weight: 1 93 | 94 | tlm_outputs: 95 | fine_tune: false 96 | 97 | optimizer: 98 | class_name: adamw 99 | learning_rate: 1e-5 100 | warmup_steps: 0.1 101 | training_steps: 12000 102 | 103 | data_processing: 104 | share_input_fields_encoders: true 105 | -------------------------------------------------------------------------------- /kiwi/cli.py: -------------------------------------------------------------------------------- 1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation 2 | # Copyright (C) 2020 Unbabel 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published 6 | # by the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later 
version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 16 | # 17 | """ 18 | Kiwi runner 19 | ~~~~~~~~~~~ 20 | 21 | Quality Estimation toolkit. 22 | 23 | Invoke as ``kiwi PIPELINE``. 24 | 25 | Usage: 26 | kiwi [options] (train|pretrain|predict|evaluate|search) CONFIG_FILE [OVERWRITES ...] 27 | kiwi (train|pretrain|predict|evaluate|search) --example 28 | kiwi (-h | --help | --version) 29 | 30 | 31 | Pipelines: 32 | train Train a QE model 33 | pretrain Pretrain a TLM model to be used as an encoder for a QE model 34 | predict Use a pre-trained model for prediction 35 | evaluate Evaluate a model's predictions using popular metrics 36 | search Search training hyper-parameters for a QE model 37 | 38 | Disabled pipelines: 39 | jackknife Jackknife training data with model 40 | 41 | Arguments: 42 | CONFIG_FILE configuration file to use (e.g., config/nuqe.yaml) 43 | OVERWRITES key=value to overwrite values in CONFIG_FILE; use ``key.subkey`` 44 | for nested keys. 
Options:
    -v --verbose     log debug messages
    -q --quiet       log only warning and error messages
    -h --help        show this help message and exit
    --version        show version and exit
    --example        print an example configuration file

"""
import sys

from docopt import docopt

from kiwi import __version__
from kiwi.assets import config
from kiwi.lib import evaluate, predict, pretrain, search, train
from kiwi.lib.utils import arguments_to_configuration


def handle_example(arguments, caller):
    """Print the packaged example config for ``caller`` and exit, if requested.

    If ``--example`` was passed on the command line, print the bundled
    ``<caller>.yaml`` example file plus instructions on how to run it, then
    terminate the process with ``sys.exit(0)``. Otherwise do nothing.

    Arguments:
        arguments: the docopt arguments dictionary.
        caller: name of the invoked pipeline ('train', 'predict', ...).
    """
    if arguments.get('--example'):
        conf_file = f'{caller}.yaml'
        print(config.file_path(conf_file).read_text())
        print(
            f'# Save the above in a file called {conf_file} and then run:\n'
            f'# {" ".join(sys.argv[:-1])} {conf_file}'
        )
        sys.exit(0)


def cli():
    """Parse the command line with docopt and dispatch to the chosen pipeline.

    Each pipeline branch first honors ``--example`` (which exits early),
    then converts the docopt arguments into a configuration dict and calls
    the corresponding ``kiwi.lib`` entry point.
    """
    arguments = docopt(
        __doc__, argv=sys.argv[1:], help=True, version=__version__, options_first=False
    )

    if arguments['train']:
        handle_example(arguments, 'train')
        config_dict = arguments_to_configuration(arguments)
        train.train_from_configuration(config_dict)
    if arguments['predict']:
        handle_example(arguments, 'predict')
        config_dict = arguments_to_configuration(arguments)
        predict.predict_from_configuration(config_dict)
    if arguments['pretrain']:
        handle_example(arguments, 'pretrain')
        config_dict = arguments_to_configuration(arguments)
        pretrain.pretrain_from_configuration(config_dict)
    if arguments['evaluate']:
        handle_example(arguments, 'evaluate')
        config_dict = arguments_to_configuration(arguments)
        evaluate.evaluate_from_configuration(config_dict)
    if arguments['search']:
        handle_example(arguments, 'search')
        config_dict = arguments_to_configuration(arguments)
        search.search_from_configuration(config_dict)
    # Meta Pipelines
    # if options.pipeline == 'jackknife':
    #     jackknife.main(extra_args)


if __name__ == '__main__':  # pragma: no cover
    cli()  # pragma: no cover
--------------------------------------------------------------------------------
/kiwi/constants.py:
--------------------------------------------------------------------------------
# OpenKiwi: Open-Source Machine Translation Quality Estimation
# Copyright (C) 2020 Unbabel
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#

# lowercased special tokens
# NOTE(review): these literals look like extraction stripped angle-bracketed
# text (presumably '<unk>', '<pad>', etc.) — verify against the upstream repo
# before relying on this copy.
UNK = ''
PAD = ''
START = ''
STOP = ''
UNALIGNED = ''

# binary labels
OK = 'OK'
BAD = 'BAD'
LABELS = [OK, BAD]

# Field names for the data sides and word-level tag columns
SOURCE = 'source'
TARGET = 'target'
PE = 'pe'
TARGET_TAGS = 'target_tags'
SOURCE_TAGS = 'source_tags'
GAP_TAGS = 'gap_tags'
TARGETGAPS_TAGS = 'targetgaps_tags'

SOURCE_LOGITS = 'source_logits'
TARGET_LOGITS = 'target_logits'
TARGET_SENTENCE = 'target_sentence'
PE_LOGITS = 'pe_logits'

SENTENCE_SCORES = 'sentence_scores'
BINARY = 'binary'

ALIGNMENTS = 'alignments'
SOURCE_POS = 'source_pos'
TARGET_POS = 'target_pos'

# Constants for model output names
LOSS = 'loss'

# Standard Names for saving files
VOCAB = 'vocab'
CONFIG = 'config'
STATE_DICT = 'state_dict'
--------------------------------------------------------------------------------
/kiwi/data/__init__.py:
--------------------------------------------------------------------------------
# OpenKiwi: Open-Source Machine Translation Quality Estimation
# Copyright (C) 2020 Unbabel
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
--------------------------------------------------------------------------------
/kiwi/data/batch.py:
--------------------------------------------------------------------------------
# OpenKiwi: Open-Source Machine Translation Quality Estimation
# Copyright (C) 2020 Unbabel
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
from dataclasses import dataclass

import torch
import torchnlp.utils
from torchnlp.encoders.text import BatchedSequences


@dataclass
class BatchedSentence:
    """A batch of encoded sentences plus the bookkeeping tensors that go
    with it (lengths, token bounds, masks).

    Implements ``pin_memory``/``to`` so instances can be moved around by
    PyTorch's ``DataLoader`` machinery like plain tensors.
    """

    # Padded tensor of encoded token ids for the whole batch.
    tensor: torch.Tensor
    # Length of each sequence in ``tensor``.
    lengths: torch.Tensor
    # Boundary indices per sequence — assumes these delimit original tokens
    # within subword sequences; confirm against the encoders.
    bounds: torch.Tensor
    # Number of bounds per sequence.
    bounds_lengths: torch.Tensor
    # Boolean masks selecting content (non-padding) positions.
    strict_masks: torch.Tensor
    # Token count per sequence.
    number_of_tokens: torch.Tensor

    def pin_memory(self):
        """Pin every member tensor in page-locked memory, in place.

        Returns ``self`` to follow PyTorch's custom-batch pinning contract.
        """
        self.tensor = self.tensor.pin_memory()
        self.lengths = self.lengths.pin_memory()
        self.bounds = self.bounds.pin_memory()
        self.bounds_lengths = self.bounds_lengths.pin_memory()
        self.strict_masks = self.strict_masks.pin_memory()
        self.number_of_tokens = self.number_of_tokens.pin_memory()
        return self

    def to(self, *args, **kwargs):
        """Move/cast every member tensor (same signature as ``Tensor.to``).

        Mutates ``self`` in place and returns it.
        """
        self.tensor = self.tensor.to(*args, **kwargs)
        self.lengths = self.lengths.to(*args, **kwargs)
        self.bounds = self.bounds.to(*args, **kwargs)
        self.bounds_lengths = self.bounds_lengths.to(*args, **kwargs)
        self.strict_masks = self.strict_masks.to(*args, **kwargs)
        self.number_of_tokens = self.number_of_tokens.to(*args, **kwargs)
        return self


class MultiFieldBatch(dict):
    """A dict of field name -> batched data that supports ``pin_memory``
    and ``to`` across all of its values."""

    def __init__(self, batch: dict):
        super().__init__()
        self.update(batch)

    def pin_memory(self):
        """Pin all field values; ``BatchedSequences`` is a NamedTuple, so a
        new instance is built from its pinned tensors instead of mutating."""
        for field, data in self.items():
            if isinstance(data, BatchedSequences):
                tensor = data.tensor.pin_memory()
                lengths = data.lengths.pin_memory()
                self[field] = BatchedSequences(tensor=tensor, lengths=lengths)
            else:
                self[field] = data.pin_memory()
        return self

    def to(self, *args, **kwargs):
        """Move/cast all field values (same signature as ``Tensor.to``)."""
        for field, data in self.items():
            self[field] = data.to(*args, **kwargs)
        return self


def tensors_to(tensors, *args, **kwargs):
    """Recursively apply ``.to(*args, **kwargs)`` to a (possibly nested)
    structure of batches/tensors.

    Kiwi batch types handle their own ``to``; plain dicts are recursed into;
    anything else is delegated to ``torchnlp.utils.tensors_to``.
    """
    if isinstance(tensors, (MultiFieldBatch, BatchedSentence)):
        return tensors.to(*args, **kwargs)
    elif isinstance(tensors, dict):
        return {k: tensors_to(v, *args, **kwargs) for k, v in tensors.items()}
    else:
        return torchnlp.utils.tensors_to(tensors, *args, **kwargs)
--------------------------------------------------------------------------------
/kiwi/data/datasets/__init__.py:
--------------------------------------------------------------------------------
# OpenKiwi: Open-Source Machine Translation Quality Estimation
# Copyright (C) 2020 Unbabel
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 16 | # 17 | -------------------------------------------------------------------------------- /kiwi/data/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation 2 | # Copyright (C) 2020 Unbabel 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published 6 | # by the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 16 | # 17 | -------------------------------------------------------------------------------- /kiwi/data/encoders/base.py: -------------------------------------------------------------------------------- 1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation 2 | # Copyright (C) 2020 Unbabel 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published 6 | # by the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
class DataEncoders:
    """Base class for dataset encoders.

    Concrete encoders override :attr:`vocabularies` to expose the
    vocabularies they build from the data.
    """

    def __init__(self):
        # The base class keeps no state of its own.
        pass

    @property
    def vocabularies(self):
        """Vocabularies built by this encoder; subclasses must override."""
        raise NotImplementedError
#
def tokenize(sentence):
    """Implement your own tokenize procedure."""
    stripped = sentence.strip()
    return stripped.split()


def detokenize(tokens):
    """Join tokens back into a whitespace-separated sentence."""
    return ' '.join(tokens)


def align_tokenize(s):
    """Return a list of pair of integers for each sentence."""
    pairs = []
    for pair in s.strip().split():
        pairs.append(tuple(int(idx) for idx in pair.split('-')))
    return pairs


def bert_tokenizer(sentence):
    """Placeholder for a wordpiece tokenizer; intentionally unimplemented."""
    raise NotImplementedError
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
import logging

from kiwi.lib.train import Configuration as TrainConfig
from kiwi.lib.train import TrainRunInfo, run
from kiwi.lib.utils import file_to_configuration
from kiwi.systems.tlm_system import TLMSystem

logger = logging.getLogger(__name__)


class Configuration(TrainConfig):
    # Reuse the full training configuration but constrain the system to a
    # TLM (translation language model) system, which is what pretraining uses.
    system: TLMSystem.Config


def pretrain_from_file(filename) -> TrainRunInfo:
    """Load options from a config file and call the pretraining procedure.

    Arguments:
        filename: of the configuration file.

    Return:
        object with training information.
    """
    config = file_to_configuration(filename)
    return pretrain_from_configuration(config)


def pretrain_from_configuration(configuration_dict) -> TrainRunInfo:
    """Run the entire pretraining pipeline using the configuration options received.

    Arguments:
        configuration_dict: dictionary with config options.

    Return:
        object with training information.
    """
    config = Configuration(**configuration_dict)
    # Delegate to the shared training entry point, parameterized by TLMSystem
    # so the same loop is reused for pretraining.
    train_info = run(config, TLMSystem)
    return train_info
#
from kiwi.metrics import metrics

# Re-export the metric classes at package level so callers can write
# ``from kiwi.metrics import PearsonMetric`` instead of reaching into the
# ``metrics`` submodule.
F1MultMetric = metrics.F1MultMetric
ExpectedErrorMetric = metrics.ExpectedErrorMetric
PerplexityMetric = metrics.PerplexityMetric
CorrectMetric = metrics.CorrectMetric
RMSEMetric = metrics.RMSEMetric
PearsonMetric = metrics.PearsonMetric
SpearmanMetric = metrics.SpearmanMetric
MatthewsMetric = metrics.MatthewsMetric
# ThresholdCalibrationMetric = metrics.ThresholdCalibrationMetric
16 | # 17 | -------------------------------------------------------------------------------- /kiwi/modules/common/__init__.py: -------------------------------------------------------------------------------- 1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation 2 | # Copyright (C) 2020 Unbabel 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published 6 | # by the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 16 | # 17 | -------------------------------------------------------------------------------- /kiwi/modules/common/activations.py: -------------------------------------------------------------------------------- 1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation 2 | # Copyright (C) 2020 Unbabel 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published 6 | # by the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 
#
import math

import torch


def gelu(x):
    """GELU activation function (exact erf form), copied from
    pytorch-pretrained-BERT."""
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def swish(x):
    """Swish activation function (x * sigmoid(x)), copied from
    pytorch-pretrained-BERT."""
    return x * torch.sigmoid(x)


# --- kiwi/modules/common/attention.py ---
import torch
from torch import nn
from torch.nn import functional as F

from kiwi.utils.tensors import unsqueeze_as


class Attention(nn.Module):
    """Generic Attention Implementation.

    1. Use `query` and `keys` to compute scores (energies)
    2. Apply softmax to get attention probabilities
    3. Perform a dot product between `values` and probabilities (outputs)

    Arguments:
        scorer (kiwi.modules.common.Scorer): a scorer object
        dropout (float): dropout rate after softmax (default: 0.)
    """

    def __init__(self, scorer, dropout=0):
        super().__init__()
        self.scorer = scorer
        self.dropout = nn.Dropout(p=dropout)
        # Large negative value used as a finite stand-in for -inf so that
        # masked positions get ~0 probability after softmax without NaNs.
        self.NEG_INF = -1e9  # for masking attention scores before softmax

    def forward(self, query, keys, values=None, mask=None):
        """Compute the attention between query, keys and values.

        Arguments:
            query (torch.Tensor): set of query vectors with shape of
                (batch_size, ..., target_len, hidden_size)
            keys (torch.Tensor): set of keys vectors with shape of
                (batch_size, ..., source_len, hidden_size)
            values (torch.Tensor, optional): set of values vectors with
                shape of: (batch_size, ..., source_len, hidden_size).
                If None, keys are treated as values. Default: None
            mask (torch.ByteTensor, optional): Tensor representing valid
                positions. If None, all positions are considered valid.
                Shape of (batch_size, target_len)

        Return:
            torch.Tensor: combination of values and attention probabilities.
                Shape of (batch_size, ..., target_len, hidden_size)
            torch.Tensor: attention probabilities between query and keys.
                Shape of (batch_size, ..., target_len, source_len)
        """
        if values is None:
            values = keys

        # get scores (aka energies)
        scores = self.scorer(query, keys)

        # mask out scores to (approximately) -inf before softmax
        if mask is not None:
            # broadcast in keys' timestep dim many times as needed
            mask = unsqueeze_as(mask, scores, dim=-2)
            scores = scores.masked_fill(mask == 0, self.NEG_INF)

        # apply softmax to get probs
        p_attn = F.softmax(scores, dim=-1)

        # apply dropout - used in Transformer (default: 0)
        p_attn = self.dropout(p_attn)

        # dot product between p_attn and values; einsum keeps any number of
        # extra leading dims (e.g. attention heads) intact.
        # o_attn = torch.matmul(p_attn, values)
        o_attn = torch.einsum('b...ts,b...sm->b...tm', [p_attn, values])
        return o_attn, p_attn
#
from torch.distributions import (
    Normal,
    TransformedDistribution,
    constraints,
    identity_transform,
)


class TruncatedNormal(TransformedDistribution):
    """Normal distribution truncated to ``[lower_bound, upper_bound]``.

    Truncation is implemented by renormalizing the base Normal by the
    probability mass inside the bounds (the partition function) rather than
    by an actual transform: the base distribution is wrapped with the
    identity transform.
    """

    arg_constraints = {
        'loc': constraints.real,
        'scale': constraints.positive,
        'lower_bound': constraints.real,
        'upper_bound': constraints.real,
    }
    support = constraints.real
    has_rsample = True

    def __init__(
        self, loc, scale, lower_bound=0.0, upper_bound=1.0, validate_args=None
    ):
        base_dist = Normal(loc, scale)
        super(TruncatedNormal, self).__init__(
            base_dist, identity_transform, validate_args=validate_args
        )
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound

    def partition_function(self):
        # Mass of the base Normal inside [lower_bound, upper_bound]. The
        # epsilon guards against division by zero, and detach() stops
        # gradients from flowing through the normalizer.
        return (
            self.base_dist.cdf(self.upper_bound) - self.base_dist.cdf(self.lower_bound)
        ).detach() + 1e-12

    @property
    def scale(self):
        # Scale of the underlying (untruncated) Normal.
        return self.base_dist.scale

    @property
    def mean(self):
        r"""
        :math:`pdf = f(x; \mu, \sigma, a, b) = \frac{\phi(\xi)}{\sigma Z}`

        :math:`\xi=\frac{x-\mu}{\sigma}`

        :math:`\alpha=\frac{a-\mu}{\sigma}`

        :math:`\beta=\frac{b-\mu}{\sigma}`

        :math:`Z=\Phi(\beta)-\Phi(\alpha)`

        Return:
            :math:`\mu + \frac{\phi(\alpha)-\phi(\beta)}{Z}\sigma`

        """
        mean = self.base_dist.mean + (
            (
                self.base_dist.scale ** 2
                * (
                    self.base_dist.log_prob(self.lower_bound).exp()
                    - self.base_dist.log_prob(self.upper_bound).exp()
                )
            )
            / self.partition_function()
        )
        return mean

    @property
    def variance(self):
        # Standard truncated-normal variance: sigma^2 * (1 + term1 - term2^2),
        # with term1 and term2 built from the standardized bounds and the
        # base density evaluated at the bounds.
        pdf_a = self.base_dist.log_prob(self.lower_bound).exp()
        pdf_b = self.base_dist.log_prob(self.upper_bound).exp()
        alpha = (
            self.lower_bound - self.base_dist.mean
        ) * self.base_dist.scale.reciprocal()
        beta = (
            self.upper_bound - self.base_dist.mean
        ) * self.base_dist.scale.reciprocal()
        z = self.partition_function()
        term1 = (alpha * pdf_a - beta * pdf_b) / z
        term2 = (pdf_a - pdf_b) / z
        return self.base_dist.scale ** 2 * (1 + term1 - term2 ** 2)

    def log_prob(self, value):
        # NOTE(review): values outside [lower_bound, upper_bound] are not
        # clamped or rejected, so densities are only meaningful for values
        # inside the bounds -- confirm callers guarantee that.
        log_value = self.base_dist.log_prob(value)
        log_prob = log_value - self.partition_function().log()
        return log_prob

    def cdf(self, value):
        # NOTE(review): the scalar comparisons below assume `value` is a
        # Python number (or 0-d tensor); a multi-element tensor would make
        # `<=` ambiguous -- verify against callers.
        if value <= self.lower_bound:
            return 0.0
        if value >= self.upper_bound:
            return 1.0
        unnormalized_cdf = self.base_dist.cdf(value) - self.base_dist.cdf(
            self.lower_bound
        )
        return unnormalized_cdf / self.partition_function()

    def icdf(self, value):
        # Inverse CDF: map the truncated quantile back through the base
        # Normal's inverse CDF after undoing the renormalization.
        if self._validate_args:
            self._validate_sample(value)
        return self.base_dist.icdf(
            self.cdf(self.lower_bound) + value * self.partition_function()
        )
#
from collections import OrderedDict

from torch import nn


def feedforward(
    in_dim,
    n_layers,
    shrink=2,
    out_dim=None,
    activation=nn.Tanh,
    final_activation=False,
    dropout=0.0,
):
    """Build a feed-forward stack of Linear layers.

    Each of the first ``n_layers - 1`` layers shrinks the dimension by
    ``shrink`` and is followed by an activation and dropout; the final layer
    maps to ``out_dim`` (or to half the current dimension when not given),
    optionally followed by one last activation.
    """
    modules = OrderedDict()
    current_dim = in_dim
    for layer_i in range(n_layers - 1):
        reduced_dim = current_dim // shrink
        modules[f'linear_{layer_i}'] = nn.Linear(current_dim, reduced_dim)
        modules[f'activation_{layer_i}'] = activation()
        modules[f'dropout_{layer_i}'] = nn.Dropout(dropout)
        current_dim = reduced_dim
    last = n_layers - 1
    final_dim = out_dim if out_dim else current_dim // 2
    modules[f'linear_{last}'] = nn.Linear(current_dim, final_dim)
    if final_activation:
        modules[f'activation_{last}'] = activation()
    return nn.Sequential(modules)
#
import torch
from torch import nn


class TFLayerNorm(nn.Module):
    """Construct a layer normalization module with epsilon inside the
    square root (tensorflow style).

    This is equivalent to HuggingFace's BertLayerNorm module.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # gamma/beta naming follows the TF/BERT convention, which matters
        # when loading pretrained checkpoints by parameter name.
        self.gamma = nn.Parameter(torch.ones(hidden_size))
        self.beta = nn.Parameter(torch.zeros(hidden_size))
        self.eps = eps

    def forward(self, x):
        # Normalize over the last dim: (x - mean) / sqrt(var + eps), with
        # eps added to the variance before the square root, then apply the
        # learned affine transform.
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        return self.gamma * x + self.beta


# --- kiwi/modules/common/positional_encoding.py ---
import math

import torch
from torch import nn


class PositionalEncoding(nn.Module):
    """Absolute positional encoding mechanism.

    Arguments:
        max_seq_len: hypothetical maximum sequence length (usually 1000).
        hidden_size: embeddings size.
    """

    def __init__(self, max_seq_len: int, hidden_size: int):
        super().__init__()

        # Sinusoidal table: even dims get sin(pos * w_i), odd dims get
        # cos(pos * w_i), with w_i = 1 / 10000^(2i / hidden_size).
        position = torch.arange(0.0, max_seq_len).unsqueeze(1)
        neg_log_term = -math.log(10000.0) / hidden_size
        div_term = torch.exp(torch.arange(0.0, hidden_size, 2) * neg_log_term)

        pe = torch.zeros(max_seq_len, hidden_size, requires_grad=False)
        pe[:, 0::2] = torch.sin(position * div_term)

        # handle cases when hidden size is odd (cos will have one less than sin)
        pe_cos = torch.cos(position * div_term)
        if hidden_size % 2 == 1:
            pe_cos = pe_cos[:, :-1]
        pe[:, 1::2] = pe_cos

        pe = pe.unsqueeze(0)  # add batch dimension
        # Buffer (not a Parameter): moves with the module and is saved in the
        # state dict, but receives no gradient updates.
        self.register_buffer('pe', pe)
        self.hidden_size = hidden_size

    def forward(self, emb):
        # Add the precomputed encodings; sequences longer than max_seq_len
        # are rejected by the assert below.
        # self.pe = self.pe.to(emb.device)
        assert emb.shape[1] <= self.pe.shape[1]
        return emb + self.pe[:, : emb.shape[1]]


if __name__ == '__main__':
    # Visual sanity check of the encodings on random, embedded, and zero
    # inputs; matplotlib/numpy are imported lazily so the module itself does
    # not depend on them.
    from matplotlib import pyplot as plt
    import numpy as np

    batch_size = 8
    vocab_size = 1000
    emb_size = 20
    seq_len = 100
    max_seq_len = 5000
    d_i, d_j = 4, 10

    x_emb = torch.randint(vocab_size, size=(batch_size, seq_len)).long()
    x_rand = torch.randn(batch_size, seq_len, emb_size)
    x_zero = torch.zeros(batch_size, seq_len, emb_size)

    embed = nn.Embedding(vocab_size, emb_size)
    torch.nn.init.xavier_normal_(embed.weight)
    pe = PositionalEncoding(max_seq_len, emb_size)

    x_rand = pe(x_rand)
    x_emb = pe(embed(x_emb)).data
    x_zero = pe(x_zero)

    plt.figure(figsize=(15, 5))
    plt.title('Random input')
    plt.plot(np.arange(seq_len), x_rand[0, :, d_i:d_j].numpy())
    plt.legend(['dim %d' % d for d in range(d_i, d_j)])
    plt.show()

    plt.figure(figsize=(15, 5))
    plt.title('Embedding input')
    plt.plot(np.arange(seq_len), x_emb[0, :, d_i:d_j].numpy())
    plt.legend(['dim %d' % d for d in range(d_i, d_j)])
    plt.show()

    plt.figure(figsize=(15, 5))
    plt.title('Zero input')
    plt.plot(np.arange(seq_len), x_zero[0, :, d_i:d_j].numpy())
    plt.legend(['dim %d' % d for d in range(d_i, d_j)])
    plt.show()


# --- kiwi/modules/common/scorer.py ---
from math import sqrt

import torch
from torch import nn

from kiwi.utils.tensors import make_mergeable_tensors


class Scorer(nn.Module):
    """Score function for attention module.

    Arguments:
        scaled: whether to scale scores by `sqrt(hidden_size)` as proposed by the
            "Attention is All You Need" paper.
    """

    def __init__(self, scaled: bool = True):
        super().__init__()
        self.scaled = scaled

    def scale(self, hidden_size: int) -> float:
        """Denominator for scaling the scores.

        Arguments:
            hidden_size: max hidden size between query and keys.

        Return:
            sqrt(hidden_size) if `scaled` is True, 1 otherwise.
        """
        if self.scaled:
            return sqrt(hidden_size)
        return 1

    def forward(
        self, query: torch.FloatTensor, keys: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Compute scores for each key of size n given the queries of size m.

        The three dots (...) represent any other dimensions, such as the
        number of heads (useful if you use a multi head attention).

        Arguments:
            query: query matrix ``(bs, ..., target_len, m)``.
            keys: keys matrix ``(bs, ..., source_len, n)``.

        Return:
            matrix representing scores between source words and target words
            ``(bs, ..., target_len, source_len)``
        """
        raise NotImplementedError


class MLPScorer(Scorer):
    """MultiLayerPerceptron Scorer with variable nb of layers and neurons."""

    def __init__(
        self, query_size, key_size, layer_sizes=None, activation=nn.Tanh, **kwargs
    ):
        super().__init__(**kwargs)
        if layer_sizes is None:
            # Default to a single hidden layer sized as the mean of the inputs.
            layer_sizes = [(query_size + key_size) // 2]
        input_size = query_size + key_size  # concatenate query and keys
        output_size = 1  # produce a scalar for each alignment
        layer_sizes = [input_size] + layer_sizes + [output_size]
        sizes = zip(layer_sizes[:-1], layer_sizes[1:])
        layers = []
        for n_in, n_out in sizes:
            layers.append(nn.Sequential(nn.Linear(n_in, n_out), activation()))
        self.layers = nn.ModuleList(layers)

    def forward(self, query, keys):
        # make_mergeable_tensors presumably broadcasts query/keys to a common
        # shape before concatenation -- verify against its definition.
        x_query, x_keys = make_mergeable_tensors(query, keys)
        x = torch.cat((x_query, x_keys), dim=-1)
        for layer in self.layers:
            x = layer(x)
        return x.squeeze(-1)  # remove last dimension
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
from torch import nn as nn

from kiwi.modules.common.distributions import TruncatedNormal
from kiwi.modules.common.feedforward import feedforward
from kiwi.utils.tensors import sequence_mask


class SentenceFromLogits(nn.Module):
    """Predict a sentence-level logit from word-level logits.

    Averages the masked word logits over time and maps the 2-dim average to
    a single logit with a linear layer.
    """

    def __init__(self):
        super().__init__()
        self.linear_out = nn.Linear(2, 1)
        nn.init.xavier_uniform_(self.linear_out.weight)
        nn.init.constant_(self.linear_out.bias, 0.0)

        # Sum reduction; any normalization is left to the caller.
        self._loss_fn = nn.BCEWithLogitsLoss(reduction='sum')

    def forward(self, inputs, lengths):
        # Zero out padding positions before averaging over the time dim.
        mask = sequence_mask(lengths, max_len=inputs.size(1)).float()
        average = (inputs * mask[..., None]).sum(1) / lengths[:, None].float()
        h = self.linear_out(average)
        return h

    def loss_fn(self, predicted, target):
        # Add a trailing dim so target matches the (batch, 1) predictions.
        target = target[..., None]
        return self._loss_fn(predicted, target)


class SentenceScoreRegression(nn.Module):
    """Regress a single sentence score from features via an MLP."""

    def __init__(
        self,
        input_size,
        dropout=0.0,
        activation=nn.Tanh,
        final_activation=False,
        num_layers=3,
    ):
        super().__init__()

        self.sentence_pred = feedforward(
            input_size,
            n_layers=num_layers,
            out_dim=1,
            dropout=dropout,
            activation=activation,
            final_activation=final_activation,
        )

        self.loss_fn = nn.MSELoss(reduction='sum')

        # Xavier-init all weight matrices; 1-dim params (biases) keep their
        # default initialization.
        for p in self.parameters():
            if len(p.shape) > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, features, batch_inputs):
        # NOTE(review): .squeeze() with batch size 1 also collapses the batch
        # dim to a 0-d tensor -- confirm callers handle that case.
        sentence_scores = self.sentence_pred(features).squeeze()
        return sentence_scores


class SentenceScoreDistribution(nn.Module):
    """Predict a truncated-Normal distribution over the sentence score.

    Returns the mean of a Normal truncated to [0, 1] as the point prediction
    and is trained by maximizing the truncated likelihood.
    """

    def __init__(self, input_size):
        super().__init__()

        self.sentence_pred = feedforward(input_size, n_layers=3, out_dim=1)
        # Predict truncated Gaussian distribution
        self.sentence_sigma = feedforward(
            input_size,
            n_layers=3,
            out_dim=1,
            activation=nn.Sigmoid,
            final_activation=True,
        )

        self.loss_fn = self._loss_fn

        for p in self.parameters():
            if len(p.shape) > 1:
                nn.init.xavier_uniform_(p)

    @staticmethod
    def _loss_fn(predicted, target):
        # `predicted` is the (scores, (mu, sigma)) tuple produced by forward.
        _, (mu, sigma) = predicted
        # Compute log-likelihood of x given mu, sigma
        dist = TruncatedNormal(mu, sigma + 1e-12, 0.0, 1.0)
        nll = -dist.log_prob(target)
        return nll.sum()

    def forward(self, features, batch_inputs):
        mu = self.sentence_pred(features).squeeze()
        sigma = self.sentence_sigma(features).squeeze()
        # Compute the mean of the truncated Gaussian as sentence score prediction
        # (detached copies so the point prediction adds no gradient paths).
        dist = TruncatedNormal(
            mu.clone().detach(), sigma.clone().detach() + 1e-12, 0.0, 1.0
        )
        sentence_scores = dist.mean
        return sentence_scores.squeeze(), (mu, sigma)


class BinarySentenceScore(nn.Module):
    """Classify a sentence into one of two classes (2-logit output)."""

    def __init__(self, input_size):
        super().__init__()
        self.sentence_pred = feedforward(
            input_size, n_layers=3, out_dim=2, activation=nn.Tanh
        )

        self.loss_fn = nn.CrossEntropyLoss(reduction='sum')

        for p in self.parameters():
            if len(p.shape) > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, features, batch_inputs):
        # NOTE(review): same batch-size-1 .squeeze() caveat as the regression
        # head above.
        return self.sentence_pred(features).squeeze()
import math

from torch import nn

from kiwi.data.batch import BatchedSentence
from kiwi.modules.common.layer_norm import TFLayerNorm
from kiwi.modules.common.positional_encoding import PositionalEncoding
from kiwi.utils.io import BaseConfig


class TokenEmbeddings(nn.Module):
    """Embed one type of tokens, with optional scaling, positional encodings,
    layer normalization, and dropout."""

    class Config(BaseConfig):
        dim: int = 50
        freeze: bool = False
        dropout: float = 0.0
        use_position_embeddings: bool = False
        max_position_embeddings: int = 4000
        sparse_embeddings: bool = False
        scale_embeddings: bool = False
        input_layer_norm: bool = False

    def __init__(self, num_embeddings: int, pad_idx: int, config: Config, vectors=None):
        """A model for embedding a single type of tokens.

        Arguments:
            num_embeddings: vocabulary size.
            pad_idx: index of the padding token.
            config: embedding hyper-parameters.
            vectors: optional pretrained embedding matrix with shape
                (num_embeddings, config.dim).
        """
        super().__init__()
        self.pad_idx = pad_idx

        if vectors is None:
            self.embedding = nn.Embedding(
                num_embeddings=num_embeddings,
                embedding_dim=config.dim,
                padding_idx=pad_idx,
                sparse=config.sparse_embeddings,
            )
            # NOTE(review): this re-initializes the whole matrix, including
            # the padding row that nn.Embedding zeroed out — confirm intended.
            nn.init.xavier_uniform_(self.embedding.weight)
        else:
            assert num_embeddings == vectors.size(0)
            self.embedding = nn.Embedding(
                num_embeddings=num_embeddings,
                embedding_dim=config.dim,
                padding_idx=pad_idx,
                sparse=config.sparse_embeddings,
                _weight=vectors,
            )

        self._size = config.dim
        self._pe = config.max_position_embeddings

        if config.freeze:
            self.embedding.weight.requires_grad = False

        self.dropout = nn.Dropout(config.dropout)

        self.embeddings_scale_factor = (
            math.sqrt(self._size) if config.scale_embeddings else 1
        )

        self.positional_encoding = (
            PositionalEncoding(self._pe, self._size)
            if config.use_position_embeddings
            else None
        )

        self.layer_norm = TFLayerNorm(self._size) if config.input_layer_norm else None

    @property
    def num_embeddings(self):
        return self.embedding.num_embeddings

    def size(self):
        return self._size

    def forward(self, batch_input, *args):
        assert isinstance(batch_input, BatchedSentence)
        token_ids = batch_input.tensor

        embedded = self.embeddings_scale_factor * self.embedding(token_ids)

        if self.positional_encoding is not None:
            embedded = self.positional_encoding(embedded)

        if self.layer_norm is not None:
            embedded = self.layer_norm(embedded)

        return self.dropout(embedded)
import torch
from torch import nn


class WordLevelOutput(nn.Module):
    """Linear projection from features to word-level tag logits.

    Optionally drops the first and/or last time step of the output (e.g. to
    discard special boundary tokens added by the encoder).
    """

    def __init__(
        self,
        input_size,
        output_size,
        pad_idx,
        class_weights=None,
        remove_first=False,
        remove_last=False,
    ):
        super().__init__()

        self.pad_idx = pad_idx

        # Explicit check to avoid using 0 as False
        self.start_pos = None if remove_first is False or remove_first is None else 1
        self.stop_pos = None if remove_last is False or remove_last is None else -1

        self.linear = nn.Linear(input_size, output_size)

        self.loss_fn = nn.CrossEntropyLoss(
            reduction='sum', ignore_index=pad_idx, weight=class_weights
        )

        nn.init.xavier_uniform_(self.linear.weight)
        nn.init.constant_(self.linear.bias, 0.0)

    def forward(self, features_tensor, batch_inputs=None):
        """Return logits of shape (batch, time[, minus removed steps], output_size)."""
        logits = self.linear(features_tensor)
        logits = logits[:, self.start_pos : self.stop_pos]
        return logits


class GapTagsOutput(WordLevelOutput):
    """Logits for gap tags: one tag between each pair of adjacent tokens.

    Each gap is represented by concatenating the features of the two
    surrounding tokens (hence the doubled input size). The sequence borders
    are zero-padded (unless the corresponding boundary token was already
    removed upstream) so leading/trailing gaps are also predicted.
    """

    def __init__(
        self,
        input_size,
        output_size,
        pad_idx,
        class_weights=None,
        remove_first=False,
        remove_last=False,
    ):
        super().__init__(
            input_size=2 * input_size,
            output_size=output_size,
            pad_idx=pad_idx,
            class_weights=class_weights,
            remove_first=False,
            remove_last=False,
        )
        # Pad a border only when the corresponding boundary token is present.
        self.add_pad_start = 1 if remove_first is False or remove_first is None else 0
        self.add_pad_stop = 1 if remove_last is False or remove_last is None else 0

    def forward(self, features_tensor, batch_inputs=None):
        h_gaps = features_tensor
        if self.add_pad_start or self.add_pad_stop:
            # Pad the time dimension (dim=1). F.pad takes (before, after)
            # pairs starting from the LAST dimension, so padding dim=1 of an
            # N-d tensor needs N-2 leading (0, 0) pairs.
            # Fix: the previous code used the *number of pads*
            # (add_pad_start + add_pad_stop) in that computation, which padded
            # the batch dimension whenever only one border needed padding.
            h_gaps = nn.functional.pad(
                h_gaps,
                pad=[0, 0] * (h_gaps.dim() - 2)
                + [self.add_pad_start, self.add_pad_stop],
                value=0,
            )
        # Represent each gap by the pair of tokens around it.
        h_gaps = torch.cat((h_gaps[:, :-1], h_gaps[:, 1:]), dim=-1)
        logits = super().forward(h_gaps, batch_inputs)
        return logits
20 | """ 21 | from importlib import import_module 22 | from pathlib import Path 23 | from pkgutil import iter_modules 24 | 25 | # iterate through the modules in the current package 26 | package_dir = Path(__file__).resolve().parent 27 | for (_, module_name, _) in iter_modules([package_dir]): 28 | # import the module and iterate through its attributes 29 | module = import_module(f"{__name__}.{module_name}") 30 | # for attribute_name in dir(module): 31 | # attribute = getattr(module, attribute_name) 32 | # 33 | # try: 34 | # if isclass(attribute) and issubclass(attribute, QESystem): 35 | # # Add the class to this package's variables 36 | # globals()[attribute_name] = attribute 37 | # except TypeError: 38 | # pass 39 | -------------------------------------------------------------------------------- /kiwi/systems/_meta_module.py: -------------------------------------------------------------------------------- 1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation 2 | # Copyright (C) 2020 Unbabel 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published 6 | # by the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 
import json
import logging
from abc import ABCMeta, abstractmethod
from collections import OrderedDict

import torch
import torch.nn as nn

import kiwi
from kiwi import constants as const
from kiwi.utils.io import BaseConfig, load_torch_file

logger = logging.getLogger(__name__)


class Serializable(metaclass=ABCMeta):
    """Mixin providing subclass registration and (de)serialization to disk.

    Subclasses register themselves under their class name, so a saved
    dictionary (which stores ``class_name``) can be routed back to the
    correct class when loading.
    """

    subclasses = {}

    @classmethod
    def register_subclass(cls, subclass):
        """Register ``subclass`` under its name; usable as a class decorator."""
        cls.subclasses[subclass.__name__] = subclass
        return subclass

    @classmethod
    def retrieve_subclass(cls, subclass_name):
        """Return the registered subclass.

        Raises:
            KeyError: if ``subclass_name`` was never registered.
        """
        subclass = cls.subclasses.get(subclass_name)
        if subclass is None:
            raise KeyError(
                f'{subclass_name} is not a registered subclass of {cls.__name__}'
            )
        return subclass

    @classmethod
    def load(cls, path):
        """Deserialize an instance from the torch file at ``path``."""
        model_dict = load_torch_file(path)
        return cls.from_dict(model_dict)

    def save(self, path):
        """Serialize this instance (config + weights) to ``path``."""
        torch.save(self.to_dict(), path)

    @classmethod
    @abstractmethod
    def from_dict(cls, *args, **kwargs):
        pass

    @abstractmethod
    def to_dict(self, include_state=True):
        # Fixed inconsistency: this was declared as a classmethod, but it is
        # implemented (MetaModule.to_dict) and called (``self.to_dict()`` in
        # ``save``) as an instance method.
        pass


class MetaModule(nn.Module, Serializable, metaclass=ABCMeta):
    """Base module used for several model layers and modules.

    Couples an ``nn.Module`` with a pydantic ``Config`` and the
    ``Serializable`` machinery so modules can be saved and loaded together
    with their configuration.
    """

    class Config(BaseConfig, metaclass=ABCMeta):
        pass

    def __init__(self, config: Config):
        """Base module used for several model layers and modules.

        Arguments:
            config: a ``MetaModule.Config`` object.
        """
        super().__init__()

        self.config = config

    @classmethod
    def from_dict(cls, module_dict, **kwargs):
        """Instantiate the registered subclass encoded in ``module_dict``.

        When a state dict is present, it is loaded into the new module.
        """
        module_cls = cls.retrieve_subclass(module_dict['class_name'])
        config = module_cls.Config(**module_dict[const.CONFIG])
        module = module_cls(config=config, **kwargs)

        state_dict = module_dict.get(const.STATE_DICT)
        if state_dict:
            not_loaded_keys = module.load_state_dict(state_dict)
            logger.debug(f'Loaded encoder; extraneous keys: {not_loaded_keys}')

        return module

    def to_dict(self, include_state=True):
        """Serialize version, class name, config, and (optionally) weights."""
        module_dict = OrderedDict(
            {
                '__version__': kiwi.__version__,
                'class_name': self.__class__.__name__,
                const.CONFIG: json.loads(self.config.json()),
                const.STATE_DICT: self.state_dict() if include_state else None,
            }
        )
        return module_dict
import logging
from typing import Any, Dict

from kiwi.data.datasets.wmt_qe_dataset import WMTQEDataset
from kiwi.data.encoders.wmt_qe_data_encoder import WMTQEDataEncoder
from kiwi.systems.decoders.linear import LinearDecoder
from kiwi.systems.encoders.bert import BertEncoder
from kiwi.systems.outputs.quality_estimation import QEOutputs
from kiwi.systems.outputs.translation_language_model import TLMOutputs
from kiwi.systems.qe_system import QESystem
from kiwi.utils.io import BaseConfig

logger = logging.getLogger(__name__)


class ModelConfig(BaseConfig):
    encoder: BertEncoder.Config = BertEncoder.Config()
    decoder: LinearDecoder.Config = LinearDecoder.Config()
    outputs: QEOutputs.Config = QEOutputs.Config()
    tlm_outputs: TLMOutputs.Config = TLMOutputs.Config()


@QESystem.register_subclass
class Bert(QESystem):
    """BERT-based Predictor-Estimator model."""

    class Config(QESystem.Config):
        model: ModelConfig

    def __init__(
        self,
        config,
        data_config: WMTQEDataset.Config = None,
        module_dict: Dict[str, Any] = None,
    ):
        super().__init__(config, data_config=data_config)

        model_config = self.config.model

        if module_dict:
            # Restore previously serialized modules and weights.
            self._load_dict(module_dict)
        elif self.config.load_encoder:
            # Reuse a pre-trained encoder.
            self._load_encoder(self.config.load_encoder)
        else:
            # Fresh start: build the data-processing pipeline from scratch.
            self.data_encoders = WMTQEDataEncoder(
                config=self.config.data_processing,
                field_encoders=BertEncoder.input_data_encoders(model_config.encoder),
            )

        # Add possibly missing fields, like outputs.
        if self.config.load_vocabs:
            self.data_encoders.load_vocabularies(self.config.load_vocabs)
        if self.train_dataset:
            self.data_encoders.fit_vocabularies(self.train_dataset)

        # Input to features.
        if not self.encoder:
            self.encoder = BertEncoder(
                vocabs=self.data_encoders.vocabularies,
                config=model_config.encoder,
            )

        # Features to output.
        if not self.decoder:
            self.decoder = LinearDecoder(
                inputs_dims=self.encoder.size(), config=model_config.decoder
            )

        # Output layers.
        if not self.outputs:
            self.outputs = QEOutputs(
                inputs_dims=self.decoder.size(),
                vocabs=self.data_encoders.vocabularies,
                config=model_config.outputs,
            )

        if not self.tlm_outputs:
            self.tlm_outputs = TLMOutputs(
                inputs_dims=self.encoder.size(),
                vocabs=self.data_encoders.vocabularies,
                config=model_config.tlm_outputs,
            )
from collections import OrderedDict
from typing import Dict

import torch
from pydantic import confloat
from torch import nn

from kiwi import constants as const
from kiwi.data.batch import MultiFieldBatch
from kiwi.systems._meta_module import MetaModule
from kiwi.utils.io import BaseConfig


@MetaModule.register_subclass
class LinearDecoder(MetaModule):
    """Transform each available input field with a linear + Tanh block.

    Word-level fields (target/source) go through a single linear-Tanh layer;
    the sentence-level field goes through a bottleneck followed by a second
    linear-Tanh layer.
    """

    class Config(BaseConfig):
        hidden_size: int = 250
        'Size of hidden layer'

        dropout: confloat(ge=0.0, le=1.0) = 0.0

        bottleneck_size: int = 100

    def __init__(self, inputs_dims, config):
        """Build one transformation block per available input field.

        Arguments:
            inputs_dims: mapping from field name to its input feature size.
            config: decoder hyper-parameters.
        """
        super().__init__(config=config)

        self.features_dims = inputs_dims

        # Build Model #
        self.linear_outs = nn.ModuleDict()
        self._sizes = {}
        if const.TARGET in self.features_dims:
            self.linear_outs[const.TARGET] = nn.Sequential(
                nn.Linear(self.features_dims[const.TARGET], self.config.hidden_size),
                nn.Tanh(),
            )
            self._sizes[const.TARGET] = self.config.hidden_size
        if const.SOURCE in self.features_dims:
            self.linear_outs[const.SOURCE] = nn.Sequential(
                nn.Linear(self.features_dims[const.SOURCE], self.config.hidden_size),
                nn.Tanh(),
            )
            self._sizes[const.SOURCE] = self.config.hidden_size
        if const.TARGET_SENTENCE in self.features_dims:
            linear_layers = [
                nn.Linear(
                    self.features_dims[const.TARGET_SENTENCE],
                    self.config.bottleneck_size,
                ),
                nn.Tanh(),
                nn.Dropout(self.config.dropout),
                nn.Linear(self.config.bottleneck_size, self.config.hidden_size),
                nn.Tanh(),
                nn.Dropout(self.config.dropout),
            ]
            self.linear_outs[const.TARGET_SENTENCE] = nn.Sequential(*linear_layers)
            self._sizes[const.TARGET_SENTENCE] = self.config.hidden_size

        self.dropout = nn.Dropout(self.config.dropout)

        # Initialize only after *all* submodules are built. Previously this
        # loop ran before the target-sentence block was created, so those
        # layers silently kept PyTorch's default initialization instead of
        # Xavier like the rest of the decoder.
        for p in self.parameters():
            if len(p.shape) > 1:
                nn.init.xavier_uniform_(p)

    def size(self, field=None):
        """Return the output size for ``field``, or the full mapping if None."""
        if field:
            return self._sizes[field]
        return self._sizes

    def forward(self, features: Dict[str, torch.Tensor], batch_inputs: MultiFieldBatch):
        """Apply the per-field blocks.

        Dropout is applied to the word-level inputs before their blocks;
        the sentence-level block already contains its own dropout.
        """
        output_features = OrderedDict()

        if const.TARGET in features:
            features_tensor = self.dropout(features[const.TARGET])
            output_features[const.TARGET] = self.linear_outs[const.TARGET](
                features_tensor
            )
        if const.SOURCE in features:
            features_tensor = self.dropout(features[const.SOURCE])
            output_features[const.SOURCE] = self.linear_outs[const.SOURCE](
                features_tensor
            )
        if const.TARGET_SENTENCE in features:
            output_features[const.TARGET_SENTENCE] = self.linear_outs[
                const.TARGET_SENTENCE
            ](features[const.TARGET_SENTENCE])

        return output_features
-------------------------------------------------------------------------------- /kiwi/systems/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation 2 | # Copyright (C) 2020 Unbabel 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published 6 | # by the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 16 | # 17 | -------------------------------------------------------------------------------- /kiwi/systems/nuqe.py: -------------------------------------------------------------------------------- 1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation 2 | # Copyright (C) 2020 Unbabel 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published 6 | # by the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 
import logging
from typing import Any, Dict

from kiwi.data.datasets.wmt_qe_dataset import WMTQEDataset
from kiwi.data.encoders.wmt_qe_data_encoder import WMTQEDataEncoder
from kiwi.systems.decoders.nuqe import NuQEDecoder
from kiwi.systems.encoders.quetch import QUETCHEncoder
from kiwi.systems.outputs.quality_estimation import QEOutputs
from kiwi.systems.qe_system import QESystem
from kiwi.utils.io import BaseConfig

logger = logging.getLogger(__name__)


class ModelConfig(BaseConfig):
    encoder: QUETCHEncoder.Config
    decoder: NuQEDecoder.Config
    outputs: QEOutputs.Config


@QESystem.register_subclass
class NuQE(QESystem):
    """Neural Quality Estimation (NuQE) model for word level quality estimation."""

    class Config(QESystem.Config):
        model: ModelConfig

    def __init__(
        self,
        config,
        data_config: WMTQEDataset.Config = None,
        module_dict: Dict[str, Any] = None,
    ):
        super().__init__(config, data_config=data_config)

        model_config = self.config.model

        if module_dict:
            # Restore previously serialized modules and weights.
            self._load_dict(module_dict)
        elif self.config.load_encoder:
            logger.warning(
                f'NuQE does not support loading the encoder; ignoring option '
                f'`load_encoder={self.config.load_encoder}`'
            )
        else:
            # Fresh start: build the data-processing pipeline from scratch.
            self.data_encoders = WMTQEDataEncoder(
                config=self.config.data_processing,
                field_encoders=QUETCHEncoder.input_data_encoders(
                    model_config.encoder
                ),
            )

        # Add possibly missing fields, like outputs.
        if self.config.load_vocabs:
            self.data_encoders.load_vocabularies(self.config.load_vocabs)
        if self.train_dataset:
            self.data_encoders.fit_vocabularies(self.train_dataset)

        # Input to features.
        if not self.encoder:
            self.encoder = QUETCHEncoder(
                vocabs=self.data_encoders.vocabularies,
                config=model_config.encoder,
            )

        # Features to output.
        if not self.decoder:
            self.decoder = NuQEDecoder(
                inputs_dims=self.encoder.size(), config=model_config.decoder
            )

        # Output layers.
        if not self.outputs:
            self.outputs = QEOutputs(
                inputs_dims=self.decoder.size(),
                vocabs=self.data_encoders.vocabularies,
                config=model_config.outputs,
            )
        # NuQE has no translation-language-model head.
        self.tlm_outputs = None
import logging
from typing import Any, Dict

from kiwi.data.datasets.wmt_qe_dataset import WMTQEDataset
from kiwi.data.encoders.parallel_data_encoder import ParallelDataEncoder
from kiwi.systems.encoders.predictor import PredictorEncoder
from kiwi.systems.outputs.translation_language_model import TLMOutputs
from kiwi.systems.tlm_system import TLMSystem
from kiwi.utils.io import BaseConfig

logger = logging.getLogger(__name__)


class ModelConfig(BaseConfig):
    encoder: PredictorEncoder.Config = PredictorEncoder.Config()
    tlm_outputs: TLMOutputs.Config = TLMOutputs.Config()


@TLMSystem.register_subclass
class Predictor(TLMSystem):
    """Predictor TLM, used for the Predictor-Estimator QE model (proposed in 2017)."""

    class Config(TLMSystem.Config):
        model: ModelConfig

    def __init__(
        self,
        config: Config,
        data_config: WMTQEDataset.Config = None,
        module_dict: Dict[str, Any] = None,
    ):
        super().__init__(config, data_config=data_config)

        model_config = self.config.model

        if module_dict:
            # Restore previously serialized modules and weights.
            self._load_dict(module_dict)
        else:
            # Fresh start: build the data-processing pipeline from scratch.
            self.data_encoders = ParallelDataEncoder(
                config=self.config.data_processing,
                field_encoders=PredictorEncoder.input_data_encoders(
                    model_config.encoder
                ),
            )

        # Add possibly missing fields, like outputs.
        if self.config.load_vocabs:
            self.data_encoders.load_vocabularies(self.config.load_vocabs)
        if self.train_dataset:
            self.data_encoders.fit_vocabularies(self.train_dataset)

        # Input to features.
        if not self.encoder:
            self.encoder = PredictorEncoder(
                vocabs=self.data_encoders.vocabularies,
                config=model_config.encoder,
                pretraining=True,
            )

        # Output.
        if not self.tlm_outputs:
            self.tlm_outputs = TLMOutputs(
                inputs_dims=self.encoder.size(),
                vocabs=self.data_encoders.vocabularies,
                config=model_config.tlm_outputs,
                pretraining=True,
            )
import logging
from typing import Any, Dict

from kiwi.data.datasets.wmt_qe_dataset import WMTQEDataset
from kiwi.data.encoders.wmt_qe_data_encoder import WMTQEDataEncoder
from kiwi.systems.decoders.estimator import EstimatorDecoder
from kiwi.systems.encoders.predictor import PredictorEncoder
from kiwi.systems.outputs.quality_estimation import QEOutputs
from kiwi.systems.outputs.translation_language_model import TLMOutputs
from kiwi.systems.qe_system import QESystem
from kiwi.utils.io import BaseConfig

logger = logging.getLogger(__name__)


class ModelConfig(BaseConfig):
    encoder: PredictorEncoder.Config = PredictorEncoder.Config()
    decoder: EstimatorDecoder.Config = EstimatorDecoder.Config()
    outputs: QEOutputs.Config = QEOutputs.Config()
    tlm_outputs: TLMOutputs.Config = TLMOutputs.Config()


@QESystem.register_subclass
class PredictorEstimator(QESystem):
    """Predictor-Estimator QE model (proposed in 2017)."""

    class Config(QESystem.Config):
        model: ModelConfig

    def __init__(
        self,
        config,
        data_config: WMTQEDataset.Config = None,
        module_dict: Dict[str, Any] = None,
    ):
        super().__init__(config, data_config=data_config)

        model_config = self.config.model

        if module_dict:
            # Restore previously serialized modules and weights.
            self._load_dict(module_dict)
        elif self.config.load_encoder:
            # Reuse a pre-trained (Predictor) encoder.
            self._load_encoder(self.config.load_encoder)
        else:
            # Fresh start: build the data-processing pipeline from scratch.
            self.data_encoders = WMTQEDataEncoder(
                config=self.config.data_processing,
                field_encoders=PredictorEncoder.input_data_encoders(
                    model_config.encoder
                ),
            )

        # Add possibly missing fields, like outputs.
        if self.config.load_vocabs:
            self.data_encoders.load_vocabularies(self.config.load_vocabs)
        if self.train_dataset:
            self.data_encoders.fit_vocabularies(self.train_dataset)

        # Input to features.
        if not self.encoder:
            self.encoder = PredictorEncoder(
                vocabs=self.data_encoders.vocabularies,
                config=model_config.encoder,
            )

        # Features to output.
        if not self.decoder:
            self.decoder = EstimatorDecoder(
                inputs_dims=self.encoder.size(), config=model_config.decoder
            )

        # Output layers.
        if not self.outputs:
            self.outputs = QEOutputs(
                inputs_dims=self.decoder.size(),
                vocabs=self.data_encoders.vocabularies,
                config=model_config.outputs,
            )

        if not self.tlm_outputs:
            self.tlm_outputs = TLMOutputs(
                inputs_dims=self.encoder.size(),
                vocabs=self.data_encoders.vocabularies,
                config=model_config.tlm_outputs,
            )
16 | # 17 | import logging 18 | from typing import Any, Dict 19 | 20 | from kiwi.data.datasets.wmt_qe_dataset import WMTQEDataset 21 | from kiwi.data.encoders.wmt_qe_data_encoder import WMTQEDataEncoder 22 | from kiwi.systems.decoders.linear import LinearDecoder 23 | from kiwi.systems.encoders.xlm import XLMEncoder 24 | from kiwi.systems.outputs.quality_estimation import QEOutputs 25 | from kiwi.systems.outputs.translation_language_model import TLMOutputs 26 | from kiwi.systems.qe_system import QESystem 27 | from kiwi.utils.io import BaseConfig 28 | 29 | logger = logging.getLogger(__name__) 30 | 31 | 32 | class ModelConfig(BaseConfig): 33 | encoder: XLMEncoder.Config = XLMEncoder.Config() 34 | decoder: LinearDecoder.Config = LinearDecoder.Config() 35 | outputs: QEOutputs.Config = QEOutputs.Config() 36 | tlm_outputs: TLMOutputs.Config = TLMOutputs.Config() 37 | 38 | 39 | @QESystem.register_subclass 40 | class XLM(QESystem): 41 | """XLM-based model for word level quality estimation.""" 42 | 43 | class Config(QESystem.Config): 44 | model: ModelConfig 45 | 46 | def __init__( 47 | self, 48 | config, 49 | data_config: WMTQEDataset.Config = None, 50 | module_dict: Dict[str, Any] = None, 51 | ): 52 | super().__init__(config, data_config=data_config) 53 | 54 | if module_dict: 55 | # Load modules and weights 56 | self._load_dict(module_dict) 57 | elif self.config.load_encoder: 58 | self._load_encoder(self.config.load_encoder) 59 | else: 60 | # Initialize data processing 61 | self.data_encoders = WMTQEDataEncoder( 62 | config=self.config.data_processing, 63 | field_encoders=XLMEncoder.input_data_encoders( 64 | self.config.model.encoder 65 | ), 66 | ) 67 | 68 | # Add possibly missing fields, like outputs 69 | if self.config.load_vocabs: 70 | self.data_encoders.load_vocabularies(self.config.load_vocabs) 71 | if self.train_dataset: 72 | self.data_encoders.fit_vocabularies(self.train_dataset) 73 | 74 | # Input to features 75 | if not self.encoder: 76 | self.encoder = XLMEncoder( 
77 | vocabs=self.data_encoders.vocabularies, config=self.config.model.encoder 78 | ) 79 | 80 | # Features to output 81 | if not self.decoder: 82 | self.decoder = LinearDecoder( 83 | inputs_dims=self.encoder.size(), config=self.config.model.decoder 84 | ) 85 | 86 | # Output layers 87 | if not self.outputs: 88 | self.outputs = QEOutputs( 89 | inputs_dims=self.decoder.size(), 90 | vocabs=self.data_encoders.vocabularies, 91 | config=self.config.model.outputs, 92 | ) 93 | 94 | if not self.tlm_outputs: 95 | self.tlm_outputs = TLMOutputs( 96 | inputs_dims=self.encoder.size(), 97 | vocabs=self.data_encoders.vocabularies, 98 | config=self.config.model.tlm_outputs, 99 | ) 100 | -------------------------------------------------------------------------------- /kiwi/systems/xlmroberta.py: -------------------------------------------------------------------------------- 1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation 2 | # Copyright (C) 2020 Unbabel 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published 6 | # by the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 
16 | # 17 | import logging 18 | from typing import Any, Dict 19 | 20 | from kiwi.data.datasets.wmt_qe_dataset import WMTQEDataset 21 | from kiwi.data.encoders.wmt_qe_data_encoder import WMTQEDataEncoder 22 | from kiwi.systems.decoders.linear import LinearDecoder 23 | from kiwi.systems.encoders.xlmroberta import XLMRobertaEncoder 24 | from kiwi.systems.outputs.quality_estimation import QEOutputs 25 | from kiwi.systems.outputs.translation_language_model import TLMOutputs 26 | from kiwi.systems.qe_system import QESystem 27 | from kiwi.utils.io import BaseConfig 28 | 29 | logger = logging.getLogger(__name__) 30 | 31 | 32 | class ModelConfig(BaseConfig): 33 | encoder: XLMRobertaEncoder.Config = XLMRobertaEncoder.Config() 34 | decoder: LinearDecoder.Config = LinearDecoder.Config() 35 | outputs: QEOutputs.Config = QEOutputs.Config() 36 | tlm_outputs: TLMOutputs.Config = TLMOutputs.Config() 37 | 38 | 39 | @QESystem.register_subclass 40 | class XLMRoberta(QESystem): 41 | """XLMRoberta-based Predictor-Estimator model""" 42 | 43 | class Config(QESystem.Config): 44 | model: ModelConfig 45 | 46 | def __init__( 47 | self, 48 | config, 49 | data_config: WMTQEDataset.Config = None, 50 | module_dict: Dict[str, Any] = None, 51 | ): 52 | super().__init__(config, data_config=data_config) 53 | 54 | if module_dict: 55 | # Load modules and weights 56 | self._load_dict(module_dict) 57 | elif self.config.load_encoder: 58 | self._load_encoder(self.config.load_encoder) 59 | else: 60 | # Initialize data processing 61 | self.data_encoders = WMTQEDataEncoder( 62 | config=self.config.data_processing, 63 | field_encoders=XLMRobertaEncoder.input_data_encoders( 64 | self.config.model.encoder 65 | ), 66 | ) 67 | 68 | # Add possibly missing fields, like outputs 69 | if self.config.load_vocabs: 70 | self.data_encoders.load_vocabularies(self.config.load_vocabs) 71 | if self.train_dataset: 72 | self.data_encoders.fit_vocabularies(self.train_dataset) 73 | 74 | # Input to features 75 | if not 
self.encoder: 76 | self.encoder = XLMRobertaEncoder( 77 | vocabs=self.data_encoders.vocabularies, config=self.config.model.encoder 78 | ) 79 | 80 | # Features to output 81 | if not self.decoder: 82 | self.decoder = LinearDecoder( 83 | inputs_dims=self.encoder.size(), config=self.config.model.decoder 84 | ) 85 | 86 | # Output layers 87 | if not self.outputs: 88 | self.outputs = QEOutputs( 89 | inputs_dims=self.decoder.size(), 90 | vocabs=self.data_encoders.vocabularies, 91 | config=self.config.model.outputs, 92 | ) 93 | 94 | if not self.tlm_outputs: 95 | self.tlm_outputs = TLMOutputs( 96 | inputs_dims=self.encoder.size(), 97 | vocabs=self.data_encoders.vocabularies, 98 | config=self.config.model.tlm_outputs, 99 | ) 100 | -------------------------------------------------------------------------------- /kiwi/training/__init__.py: -------------------------------------------------------------------------------- 1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation 2 | # Copyright (C) 2020 Unbabel 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published 6 | # by the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 
16 | # 17 | -------------------------------------------------------------------------------- /kiwi/training/callbacks.py: -------------------------------------------------------------------------------- 1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation 2 | # Copyright (C) 2020 Unbabel 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published 6 | # by the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 16 | # 17 | import logging 18 | import textwrap 19 | 20 | import numpy as np 21 | from pytorch_lightning import Callback 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | class BestMetricsInfo(Callback): 27 | """Class for logging current training metrics along with the best so far.""" 28 | 29 | def __init__( 30 | self, 31 | monitor: str = 'val_loss', 32 | min_delta: float = 0.0, 33 | verbose: bool = True, 34 | mode: str = 'auto', 35 | ): 36 | super().__init__() 37 | 38 | self.monitor = monitor 39 | self.min_delta = min_delta 40 | self.verbose = verbose 41 | 42 | mode_dict = { 43 | 'min': np.less, 44 | 'max': np.greater, 45 | 'auto': np.greater if 'acc' in self.monitor else np.less, 46 | } 47 | 48 | if mode not in mode_dict: 49 | logger.info( 50 | f'BestMetricsInfo mode {mode} is unknown, fallback to auto mode.' 
51 | ) 52 | mode = 'auto' 53 | 54 | self.monitor_op = mode_dict[mode] 55 | self.min_delta *= 1 if self.monitor_op == np.greater else -1 56 | 57 | self.best = np.Inf if self.monitor_op == np.less else -np.Inf 58 | self.best_epoch = -1 59 | self.best_metrics = {} 60 | 61 | def on_train_begin(self, trainer, pl_module): 62 | # Allow instances to be re-used 63 | self.best = np.Inf if self.monitor_op == np.less else -np.Inf 64 | self.best_epoch = -1 65 | self.best_metrics = {} 66 | 67 | def on_train_end(self, trainer, pl_module): 68 | if self.best_epoch > 0 and self.verbose > 0: 69 | metrics_message = textwrap.fill( 70 | ', '.join( 71 | [ 72 | '{}: {:0.4f}'.format(k, v) 73 | for k, v in self.best_metrics.items() 74 | if k.startswith('val_') 75 | ] 76 | ), 77 | width=80, 78 | initial_indent='\t', 79 | subsequent_indent='\t', 80 | ) 81 | best_path = trainer.checkpoint_callback.best_model_path 82 | if not best_path: 83 | best_path = ( 84 | "model was not saved; check flags in Trainer if this is not " 85 | "expected" 86 | ) 87 | logger.info( 88 | f'Epoch {self.best_epoch} had the best validation metric:\n' 89 | f'{metrics_message} \n' 90 | f'\t({best_path})\n' 91 | ) 92 | 93 | def on_validation_end(self, trainer, pl_module): 94 | metrics = trainer.callback_metrics 95 | 96 | current = metrics.get(self.monitor) 97 | if self.monitor_op(current - self.min_delta, self.best): 98 | self.best = current 99 | self.best_epoch = trainer.current_epoch 100 | self.best_metrics = metrics.copy() # Copy or it gets overwritten 101 | if self.verbose > 0: 102 | logger.info('Best validation so far.') 103 | else: 104 | metrics_message = textwrap.fill( 105 | ', '.join( 106 | [ 107 | f'{k}: {v:0.4f}' 108 | for k, v in self.best_metrics.items() 109 | if k.startswith('val_') 110 | ] 111 | ), 112 | width=80, 113 | initial_indent='\t', 114 | subsequent_indent='\t', 115 | ) 116 | logger.info( 117 | f'Best validation so far was in epoch {self.best_epoch}:\n' 118 | f'{metrics_message} \n' 119 | ) 120 | 
-------------------------------------------------------------------------------- /kiwi/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation 2 | # Copyright (C) 2020 Unbabel 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published 6 | # by the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 16 | # 17 | -------------------------------------------------------------------------------- /kiwi/utils/data_structures.py: -------------------------------------------------------------------------------- 1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation 2 | # Copyright (C) 2020 Unbabel 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published 6 | # by the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 
16 | # 17 | from collections import OrderedDict 18 | 19 | from kiwi import constants as const 20 | 21 | 22 | class DefaultFrozenDict(OrderedDict): 23 | def __init__(self, mapping=None, default_key=const.UNK): 24 | if mapping is None: 25 | super().__init__() 26 | else: 27 | super().__init__(mapping) 28 | self._default_key = default_key 29 | 30 | def __getitem__(self, k): 31 | default_id = self.get(self._default_key) 32 | item = self.get(k, default_id) 33 | if item is None: 34 | raise KeyError( 35 | f"'{k}' (and default '{self._default_key}' not found either)" 36 | ) 37 | return item 38 | -------------------------------------------------------------------------------- /kiwi/utils/io.py: -------------------------------------------------------------------------------- 1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation 2 | # Copyright (C) 2020 Unbabel 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published 6 | # by the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 16 | # 17 | import logging 18 | from pathlib import Path 19 | 20 | import torch 21 | from pydantic import BaseModel, Extra 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | class BaseConfig(BaseModel): 27 | """Base class for all pydantic configs. Used to configure base behaviour of configs. 
28 | """ 29 | 30 | class Config: 31 | # Throws an error whenever an extra key is provided, effectively making parsing, 32 | # strict 33 | extra = Extra.forbid 34 | 35 | 36 | def default_map_location(storage, loc): 37 | return storage 38 | 39 | 40 | def load_torch_file(file_path, map_location=None): 41 | file_path = Path(file_path) 42 | if not file_path.exists(): 43 | raise ValueError(f'Torch file not found: {file_path}') 44 | 45 | try: 46 | if map_location is None: 47 | map_location = default_map_location 48 | file_dict = torch.load(file_path, map_location=map_location) 49 | except ModuleNotFoundError as e: 50 | # Caused, e.g., by moving the Vocabulary or DefaultFrozenDict classes 51 | logger.info( 52 | 'Trying to load a slightly outdated file and encountered an issue when ' 53 | 'unpickling; trying to work around it.' 54 | ) 55 | if e.name == 'kiwi.data.utils': 56 | import sys 57 | from kiwi.utils import data_structures 58 | 59 | sys.modules['kiwi.data.utils'] = data_structures 60 | file_dict = torch.load(file_path, map_location=map_location) 61 | del sys.modules['kiwi.data.utils'] 62 | elif e.name == 'torchtext': 63 | import sys 64 | from kiwi.data import vocabulary 65 | 66 | vocabulary.Vocab = vocabulary.Vocabulary 67 | sys.modules['torchtext'] = '' 68 | sys.modules['torchtext.vocab'] = vocabulary 69 | file_dict = torch.load(file_path, map_location=map_location) 70 | del sys.modules['torchtext.vocab'] 71 | del sys.modules['torchtext'] 72 | else: 73 | raise e 74 | 75 | return file_dict 76 | 77 | 78 | def save_file(file_path, data, token_sep=' ', example_sep='\n'): 79 | if data and isinstance(data[0], list): 80 | data = [token_sep.join(map(str, sentence)) for sentence in data] 81 | else: 82 | data = map(str, data) 83 | example_str = example_sep.join(data) + '\n' 84 | Path(file_path).write_text(example_str) 85 | 86 | 87 | def save_predicted_probabilities(directory, predictions, prefix=''): 88 | directory = Path(directory) 89 | directory.mkdir(parents=True, 
exist_ok=True) 90 | for key, preds in predictions.items(): 91 | if prefix: 92 | key = f'{prefix}.{key}' 93 | output_path = Path(directory, key) 94 | logger.info(f'Saving {key} predictions to {output_path}') 95 | save_file(output_path, preds, token_sep=' ', example_sep='\n') 96 | 97 | 98 | def read_file(path): 99 | """Read a file into a list of lists of words.""" 100 | with Path(path).open('r', encoding='utf8') as f: 101 | return [[token for token in line.strip().split()] for line in f] 102 | 103 | 104 | def target_gaps_to_target(batch): 105 | """Extract target tags from wmt18 format file.""" 106 | return batch[1::2] 107 | 108 | 109 | def target_gaps_to_gaps(batch): 110 | """Extract gap tags from wmt18 format file.""" 111 | return batch[::2] 112 | 113 | 114 | def generate_slug(text, delimiter="-"): 115 | """Convert text to a normalized "slug" without whitespace. 116 | 117 | Borrowed from the nice https://humanfriendly.readthedocs.io, by Peter Odding. 118 | 119 | Arguments: 120 | text: the original text, for example ``Some Random Text!``. 121 | delimiter: the delimiter to use for separating words 122 | (defaults to the ``-`` character). 123 | 124 | Return: 125 | the slug text, for example ``some-random-text``. 126 | 127 | Raise: 128 | :exc:`~exceptions.ValueError` when the provided text is nonempty but results 129 | in an empty slug. 130 | """ 131 | import re 132 | 133 | slug = text.lower() 134 | escaped = delimiter.replace("\\", "\\\\") 135 | slug = re.sub("[^a-z0-9]+", escaped, slug) 136 | slug = slug.strip(delimiter) 137 | if text and not slug: 138 | msg = "The provided text %r results in an empty slug!" 
139 | raise ValueError(format(msg, text)) 140 | return slug 141 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # Configuration file as per PEP 518 2 | # https://www.python.org/dev/peps/pep-0518/ 3 | 4 | [tool.poetry] 5 | name = "openkiwi" 6 | version = "2.1.0" 7 | description = "Machine Translation Quality Estimation Toolkit" 8 | authors = ["AI Research, Unbabel "] 9 | license = "AGPL-3.0" 10 | readme = 'README.md' 11 | homepage = 'https://github.com/Unbabel/OpenKiwi' 12 | repository = 'https://github.com/Unbabel/OpenKiwi' 13 | documentation = 'https://unbabel.github.io/OpenKiwi' 14 | keywords = ['OpenKiwi', 'Quality Estimation', 'Machine Translation', 'Unbabel'] 15 | classifiers = [ 16 | 'Development Status :: 4 - Beta', 17 | 'Environment :: Console', 18 | 'Intended Audience :: Science/Research', 19 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 20 | ] 21 | packages = [ 22 | {include = "kiwi"}, 23 | ] 24 | include = ['pyproject.toml', 'CHANGELOG.md', 'LICENSE', 'CONTRIBUTING.md'] 25 | 26 | 27 | [tool.poetry.scripts] 28 | kiwi = 'kiwi.__main__:main' 29 | 30 | 31 | [tool.poetry.dependencies] 32 | python = "^3.7" 33 | torch = ">=1.4.0, <1.7.0" 34 | tqdm = "^4.29" 35 | numpy = "^1.18" 36 | more-itertools = "^8.0.0" 37 | scipy = "^1.2" 38 | pyyaml = "^5.1.2" 39 | pytorch-nlp = "^0.5.0" 40 | transformers = "^3.0.2" 41 | pydantic = "^1.5" 42 | docopt = "^0.6.2" 43 | omegaconf = "^1.4.1" 44 | typing-extensions = "^3.7.4" 45 | hydra-core = "^0.11.3" 46 | pytorch-lightning = "^0.8.4" 47 | mlflow = {version = "^1.11.0", optional = true, extras = ["mlflow"]} 48 | optuna = {version = "^2.2.0", optional = true, extras = ["search"]} 49 | plotly = {version = "^4.11.0", optional = true, extras = ["search"]} 50 | sklearn = {version = "^0.0", optional = true, extras = ["search"]} 51 | 52 | [tool.poetry.dev-dependencies] 53 | tox = 
"^3.7" 54 | pytest = "^4.1" 55 | flake8 = "^3.8" 56 | isort = "^4.3" 57 | black = {version = "^19.10-beta.0",allow-prereleases = true} 58 | pytest-cov = "^2.8.1" 59 | pytest-sugar = "^0.9.3" 60 | sphinx = "^3.0" 61 | recommonmark = "^0.6.0" 62 | m2r = "^0.2.1" 63 | sphinx-autodoc-typehints = "^1.10.3" 64 | sphinx-autoapi = "^1.3.0" 65 | sphinx-paramlinks = "^0.4.1" 66 | pydata-sphinx-theme = "^0.2.2" 67 | 68 | [tool.poetry.extras] 69 | mlflow = ["mlflow"] 70 | search = ["optuna", "plotly", "sklearn"] 71 | 72 | [tool.black] 73 | skip-string-normalization = true # Don't switch to double quotes 74 | exclude = ''' 75 | /( 76 | \.git 77 | | \.tox 78 | | \.venv 79 | | build 80 | | dist 81 | )/ 82 | ''' 83 | 84 | [tool.isort] 85 | multi_line_output = 3 86 | include_trailing_comma = true 87 | force_grid_wrap = 0 88 | use_parentheses = true 89 | line_length = 88 90 | 91 | [build-system] 92 | requires = ["poetry>=1.1.0"] 93 | build-backend = "poetry.masonry.api" 94 | -------------------------------------------------------------------------------- /scripts/merge_target_and_gaps_preds.py: -------------------------------------------------------------------------------- 1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation 2 | # Copyright (C) 2020 Unbabel 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published 6 | # by the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 
16 | # 17 | 18 | import argparse 19 | from pathlib import Path 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--target-pred', help='Target predictions', type=str) 25 | parser.add_argument('--gaps-pred', help='Gaps predictions', type=str) 26 | parser.add_argument('--output', help='Path to output', type=str) 27 | return parser.parse_args() 28 | 29 | 30 | def main(args): 31 | output_file_path = Path(args.output) 32 | if output_file_path.exists() and output_file_path.is_dir(): 33 | output_file_path = Path(output_file_path, 'predicted.prob') 34 | print('Output is a directory, saving to: {}'.format(output_file_path)) 35 | elif not output_file_path.exists(): 36 | if not output_file_path.parent.exists(): 37 | output_file_path.parent.mkdir(parents=True) 38 | f = output_file_path.open('w', encoding='utf8') 39 | 40 | with open(args.target_pred) as f_target, open(args.gaps_pred) as f_gaps: 41 | for line_target, line_gaps in zip(f_target, f_gaps): 42 | try: 43 | # labels are probs 44 | pred_target = list(map(float, line_target.split())) 45 | pred_gaps = list(map(float, line_gaps.split())) 46 | except ValueError: 47 | # labels are and tags 48 | pred_target = line_target.split() 49 | pred_gaps = line_gaps.split() 50 | new_preds = [] 51 | for i in range(len(pred_gaps)): 52 | new_preds.append(str(pred_gaps[i])) 53 | if i < len(pred_target): 54 | new_preds.append(str(pred_target[i])) 55 | f.write(' '.join(new_preds) + '\n') 56 | f.close() 57 | 58 | 59 | if __name__ == '__main__': 60 | args = parse_args() 61 | main(args) 62 | -------------------------------------------------------------------------------- /tests/mocks/mock_vocab.py: -------------------------------------------------------------------------------- 1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation 2 | # Copyright (C) 2020 Unbabel 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero 
General Public License as published 6 | # by the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 16 | # 17 | class MockVocab: 18 | vectors = None 19 | 20 | def __init__(self, dictionary): 21 | self.stoi = dictionary 22 | self.itos = {idx: token for token, idx in self.stoi.items()} 23 | self.length = max(self.itos.keys()) + 1 24 | 25 | def token_to_id(self, token): 26 | if token in self.stoi: 27 | return self.stoi[token] 28 | else: 29 | raise KeyError 30 | 31 | def __len__(self): 32 | return self.length 33 | -------------------------------------------------------------------------------- /tests/mocks/simple_model.py: -------------------------------------------------------------------------------- 1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation 2 | # Copyright (C) 2020 Unbabel 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published 6 | # by the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 
16 | # 17 | import torch 18 | 19 | from kiwi import constants as const 20 | from kiwi.metrics import LogMetric 21 | from kiwi.models.model import Model 22 | from kiwi.systems.decoders.quetch_deprecated import QUETCH 23 | from kiwi.systems.encoders.quetch import QUETCHEncoder 24 | 25 | 26 | @Model.register_subclass 27 | class SimpleModel(Model): 28 | 29 | target = const.TARGET_TAGS 30 | fieldset = QUETCH.fieldset 31 | metrics_ordering = QUETCH.metrics_ordering 32 | 33 | @staticmethod 34 | def default_features_embedder_class(): 35 | return QUETCHEncoder 36 | 37 | def _build_output_layer(self, vocabs, features_dim, config): 38 | self.output_layer = torch.nn.Sequential( 39 | torch.nn.Linear(features_dim, 1), torch.nn.Sigmoid() 40 | ) 41 | 42 | def forward(self, batch, *args, **kwargs): 43 | field_embeddings = { 44 | field: self.field_embedders[field](batch[field]) 45 | for field in (const.SOURCE, const.TARGET) 46 | } 47 | feature_embeddings = self.encoder(field_embeddings, batch) 48 | output = self.output_layer(feature_embeddings).squeeze() 49 | return {SimpleModel.target: output} 50 | 51 | def loss(self, model_out, batch): 52 | prediction = model_out[SimpleModel.target] 53 | target = getattr(batch, SimpleModel.target).float() 54 | return {const.LOSS: ((prediction - target) ** 2).mean()} 55 | 56 | def predict(self, batch, *args, **kwargs): 57 | predictions = self.forward(batch) 58 | return {SimpleModel.target: predictions[SimpleModel.target].tolist()} 59 | 60 | def metrics(self): 61 | return (LogMetric(log_targets=[(const.LOSS, const.LOSS)]),) 62 | -------------------------------------------------------------------------------- /tests/test_bert.py: -------------------------------------------------------------------------------- 1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation 2 | # Copyright (C) 2020 Unbabel 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public 
# License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
#
import shutil

import pytest
import yaml
from transformers import BertConfig, BertModel
from transformers.tokenization_bert import VOCAB_FILES_NAMES

from conftest import check_computation
from kiwi import constants as const

# System configuration for a small Bert QE system exercised by these tests.
bert_yaml = """
class_name: Bert

num_data_workers: 0
batch_size:
    train: 32
    valid: 32

model:
    encoder:
        model_name: None
        interleave_input: false
        freeze: false
        use_mismatch_features: false
        use_predictor_features: false

    decoder:
        hidden_size: 16
        dropout: 0.0

    outputs:
        word_level:
            target: true
            gaps: true
            source: true
            class_weights:
                target_tags:
                    BAD: 3.0
                gap_tags:
                    BAD: 5.0
                source_tags:
                    BAD: 3.0
        sentence_level:
            hter: true
            use_distribution: true
            binary: false
        sentence_loss_weight: 1

    tlm_outputs:
        fine_tune: false

optimizer:
    class_name: noam
    learning_rate: 0.00001
    warmup_steps: 1000
    training_steps: 12000

data_processing:
    share_input_fields_encoders: true

"""


@pytest.fixture
def bert_config_dict():
    """Parse the YAML system configuration into a plain dict."""
    return yaml.unsafe_load(bert_yaml)


@pytest.fixture(scope='session')
def bert_model_dir(model_dir):
    """Directory holding the toy Bert vocabulary used by the tests."""
    return model_dir.joinpath('bert/')

@pytest.fixture(scope='function') 91 | def bert_model(): 92 | config = BertConfig( 93 | vocab_size=107, 94 | hidden_size=32, 95 | num_hidden_layers=5, 96 | num_attention_heads=4, 97 | intermediate_size=37, 98 | hidden_act='gelu', 99 | hidden_dropout_prob=0.1, 100 | attention_probs_dropout_prob=0.1, 101 | max_position_embeddings=512, 102 | type_vocab_size=16, 103 | is_decoder=False, 104 | initializer_range=0.02, 105 | ) 106 | return BertModel(config=config) 107 | 108 | 109 | def test_computation_target( 110 | tmp_path, 111 | bert_model, 112 | bert_model_dir, 113 | bert_config_dict, 114 | train_config, 115 | data_config, 116 | big_atol, 117 | ): 118 | train_config['run']['output_dir'] = tmp_path 119 | train_config['data'] = data_config 120 | train_config['system'] = bert_config_dict 121 | 122 | shutil.copy2(bert_model_dir / VOCAB_FILES_NAMES['vocab_file'], tmp_path) 123 | bert_model.save_pretrained(tmp_path) 124 | train_config['system']['model']['encoder']['model_name'] = str(tmp_path) 125 | 126 | check_computation( 127 | train_config, 128 | tmp_path, 129 | output_name=const.TARGET_TAGS, 130 | expected_avg_probs=0.550805, 131 | atol=big_atol, 132 | ) 133 | 134 | 135 | if __name__ == '__main__': # pragma: no cover 136 | pytest.main([__file__]) # pragma: no cover 137 | -------------------------------------------------------------------------------- /tests/test_cli.py: -------------------------------------------------------------------------------- 1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation 2 | # Copyright (C) 2020 Unbabel 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published 6 | # by the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 
import pytest

from kiwi.lib.utils import arguments_to_configuration, file_to_configuration


@pytest.fixture
def config_path(tmp_path):
    """Path where each test writes its temporary YAML configuration."""
    return tmp_path / 'config.yaml'


@pytest.fixture
def config():
    """A minimal Bert system configuration in YAML form."""
    return """
class_name: Bert

num_data_workers: 0
batch_size:
    train: 32
    valid: 32

model:
    encoder:
        model_name: None
        interleave_input: false
        freeze: false
        use_mismatch_features: false
        use_predictor_features: false
        encode_source: false

"""


def test_file_reading_to_configuration(config, config_path):
    """Files are correctly handed to hydra for composition."""
    config_path.write_text(config)
    config_dict = file_to_configuration(config_path)
    assert isinstance(config_dict, dict)
    assert 'num_data_workers' in config_dict


def test_hydra_state_handling(config, config_path):
    """Hydra global state is handled correctly between runs.

    If it were not reset, kiwi would refuse to compose the same
    configuration twice within one process.
    """
    config_path.write_text(config)
    file_to_configuration(config_path)
    # The second composition must not raise because of stale hydra state.
    file_to_configuration(config_path)


def test_arguments_to_configuration(config, config_path):
    """Configuration handling and command-line overwrites work."""
    config_path.write_text(config)
    config_dict = arguments_to_configuration(
        {'CONFIG_FILE': config_path, 'OVERWRITES': ['class_name=XLMR']}
    )

    assert 'num_data_workers' in config_dict
    assert 'model' in config_dict
    assert 'freeze' in config_dict['model']['encoder']
    assert config_dict['class_name'] == 'XLMR'
import numpy as np
import pytest

from kiwi.metrics.functions import (
    confusion_matrix,
    f1_product,
    fscore,
    matthews_correlation_coefficient,
)


@pytest.fixture
def labels():
    """Gold and predicted binary tag sequences plus the class count.

    Returns ``(y_gold, y_hat, n_class)``. The per-sentence sequences have
    different lengths (ragged), so the outer containers are object arrays:
    since NumPy 1.24, calling ``np.array`` on ragged nested sequences raises
    ``ValueError`` unless ``dtype=object`` is passed explicitly.
    """
    n_class = 2
    y_gold = np.array(
        [
            np.array([1, 1, 0, 1]),
            np.array([1, 1, 0, 1, 1, 1, 1, 0]),
            np.array([1, 1, 1, 0, 1, 1, 0, 0, 0, 1]),
        ],
        dtype=object,
    )
    y_hat = np.array(
        [
            np.array([1, 1, 0, 0]),
            np.array([1, 1, 0, 1, 1, 1, 1, 0]),
            np.array([1, 1, 1, 0, 1, 1, 1, 0, 0, 0]),
        ],
        dtype=object,
    )
    return y_gold, y_hat, n_class


def test_fscore(labels, atol):
    """Micro/macro F1-product agree with kiwi's confusion-matrix helpers."""
    y_gold, y_hat, n_class = labels
    cnfm = confusion_matrix(y_hat, y_gold, n_class)
    f1_orig_prod_micro = f1_product(y_hat, y_gold)
    f1_prod_macro = 0
    f1_orig_prod_macro = 0
    tp, tn, fp, fn = 0, 0, 0, 0
    for ys_hat, ys_gold in zip(y_hat, y_gold):
        # Per-sentence confusion counts, accumulated into corpus totals.
        ctp = np.sum((ys_hat == 1) & (ys_gold == 1))
        ctn = np.sum((ys_hat == 0) & (ys_gold == 0))
        cfp = np.sum((ys_hat == 1) & (ys_gold == 0))
        cfn = np.sum((ys_hat == 0) & (ys_gold == 1))
        tp += ctp
        tn += ctn
        fp += cfp
        fn += cfn
        f_ok = fscore(ctp, cfp, cfn)
        f_bad = fscore(ctn, cfp, cfn)
        f1_prod_macro += f_ok * f_bad
        f1_orig_prod_macro += f1_product(ys_hat, ys_gold)

    # Accumulated counts must match kiwi's confusion-matrix layout.
    assert tn == cnfm[0, 0]
    assert fp == cnfm[0, 1]
    assert fn == cnfm[1, 0]
    assert tp == cnfm[1, 1]

    f_ok = fscore(tp, fp, fn)
    f_bad = fscore(tn, fp, fn)
    f1_prod_micro = f_ok * f_bad
    f1_prod_macro = f1_prod_macro / y_gold.shape[0]
    f1_orig_prod_macro = f1_orig_prod_macro / y_gold.shape[0]

    np.testing.assert_allclose(f1_prod_micro, f1_orig_prod_micro, atol=atol)
    np.testing.assert_allclose(f1_prod_macro, f1_orig_prod_macro, atol=atol)


def test_matthews(labels, atol):
    """MCC matches the reference value and is symmetric under label flips."""
    y_gold, y_hat, n_class = labels
    matthews = matthews_correlation_coefficient(y_hat, y_gold)
    np.testing.assert_allclose(matthews, 0.70082555, atol=atol)
    # Flipping one side negates the coefficient; flipping both preserves it.
    n1 = matthews_correlation_coefficient(y_hat, 1 - y_gold)
    n2 = matthews_correlation_coefficient(1 - y_hat, y_gold)
    nn = matthews_correlation_coefficient(1 - y_hat, 1 - y_gold)
    assert n1 == n2
    assert nn == matthews
    assert n1 == -matthews


if __name__ == '__main__':  # pragma: no cover
    pytest.main([__file__])  # pragma: no cover
import numpy as np
import pytest

from conftest import check_computation
from kiwi import constants as const
from kiwi.lib import train
from kiwi.lib.utils import save_config_to_file


@pytest.fixture
def output_source_config(nuqe_config_dict):
    """NuQE configuration with source-tag word-level output enabled."""
    nuqe_config_dict['model']['outputs']['word_level']['source'] = True
    return nuqe_config_dict


@pytest.fixture
def output_gaps_config(nuqe_config_dict):
    """NuQE configuration with gap-tag word-level output enabled."""
    nuqe_config_dict['model']['outputs']['word_level']['gaps'] = True
    return nuqe_config_dict


@pytest.fixture
def output_targetgaps_config(nuqe_config_dict):
    """NuQE configuration with both target-tag and gap-tag outputs enabled."""
    word_level = nuqe_config_dict['model']['outputs']['word_level']
    word_level['target'] = True
    word_level['gaps'] = True
    return nuqe_config_dict


def test_computation_target(
    tmp_path, output_target_config, train_config, data_config, big_atol
):
    """Target-tag probabilities land near the expected average."""
    train_config['data'] = data_config
    train_config['system'] = output_target_config
    check_computation(
        train_config,
        tmp_path,
        output_name=const.TARGET_TAGS,
        expected_avg_probs=0.498354,
        atol=big_atol,
    )


def test_computation_gaps(
    tmp_path, output_gaps_config, train_config, data_config, atol
):
    """Gap-tag probabilities land near the expected average."""
    train_config['data'] = data_config
    train_config['system'] = output_gaps_config
    check_computation(
        train_config,
        tmp_path,
        output_name=const.GAP_TAGS,
        expected_avg_probs=0.316064,
        atol=atol,
    )


def test_computation_source(
    tmp_path, output_source_config, train_config, data_config, big_atol
):
    """Source-tag probabilities land near the expected average."""
    train_config['data'] = data_config
    train_config['system'] = output_source_config
    check_computation(
        train_config,
        tmp_path,
        output_name=const.SOURCE_TAGS,
        expected_avg_probs=0.486522,
        atol=big_atol,
    )


def test_computation_targetgaps(
    tmp_path, output_targetgaps_config, train_config, data_config, big_atol
):
    """With target and gap outputs combined, target-tag averages still hold."""
    train_config['data'] = data_config
    train_config['system'] = output_targetgaps_config
    check_computation(
        train_config,
        tmp_path,
        output_name=const.TARGET_TAGS,
        expected_avg_probs=0.507699,
        atol=big_atol,
    )


def test_api(tmp_path, output_target_config, train_config, data_config, big_atol):
    """End-to-end check of the public API: train from a file, then predict."""
    from kiwi import train_from_file, load_system

    train_config['data'] = data_config
    train_config['system'] = output_target_config

    config_file = tmp_path / 'config.yaml'
    save_config_to_file(train.Configuration(**train_config), config_file)

    train_run_info = train_from_file(config_file)
    runner = load_system(train_run_info.best_model_path)

    test_inputs = data_config['test']['input']
    source = open(test_inputs['source']).readlines()
    target = open(test_inputs['target']).readlines()
    alignments = open(test_inputs['alignments']).readlines()

    predictions = runner.predict(
        source=source,
        target=target,
        alignments=alignments,
        batch_size=train_config['system']['batch_size'],
    )

    bad_probs = predictions.target_tags_BAD_probabilities
    avg_of_avgs = np.mean(list(map(np.mean, bad_probs)))
    max_prob = max(map(max, bad_probs))
    min_prob = min(map(min, bad_probs))
    np.testing.assert_allclose(avg_of_avgs, 0.498287, atol=big_atol)
    # Probabilities must be well-ordered and stay inside [0, 1].
    assert 0 <= min_prob <= avg_of_avgs <= max_prob <= 1

    assert len(predictions.target_tags_labels) == len(target)


if __name__ == '__main__':  # pragma: no cover
    pytest.main([__file__])  # pragma: no cover
import pytest

from kiwi import load_system
from kiwi.data.datasets.wmt_qe_dataset import WMTQEDataset
from kiwi.lib import predict


def test_predicting_and_evaluating(tmp_path, data_config, model_dir, atol):
    """Predict on the validation split and loosely check reported metrics."""
    checkpoint = model_dir / 'nuqe.ckpt'

    configuration = predict.Configuration(
        run=dict(seed=42, output_dir=tmp_path, predict_on_data_partition='valid'),
        data=WMTQEDataset.Config(**data_config),
        system=dict(load=checkpoint),
        use_gpu=False,
    )

    predictions, metrics = predict.run(configuration, tmp_path)

    expected_outputs = {
        'target_tags',
        'target_tags_labels',
        'gap_tags',
        'gap_tags_labels',
        'sentence_scores',
        'sentence_scores_extras',
        'targetgaps_tags',
        'targetgaps_tags_labels',
    }
    assert set(predictions.keys()) == expected_outputs

    # Training is short and noisy, so scores are only pinned within 0.1.
    assert (
        abs(metrics.word_scores['targetgaps_tags']['F1_Mult'][0] - 0.1937725747540829)
        < 0.1
    )
    assert (
        abs(metrics.word_scores['target_tags']['F1_Mult'][0] - 0.053539393196874105)
        < 0.1
    )
    assert abs(metrics.sentence_scores['scoring'][0][1] - -0.13380020964645983) < 0.1


def test_runner(tmp_path, data_config, model_dir, atol):
    """Prediction lengths line up with input token counts."""
    runner = load_system(model_dir / 'nuqe.ckpt')

    dataset = WMTQEDataset.build(WMTQEDataset.Config(**data_config), valid=True)

    predictions = runner.predict(
        source=dataset['source'],
        target=dataset['target'],
        alignments=dataset['alignments'],
    )

    # One probability and one label per target token.
    target_lengths = [len(sentence.split()) for sentence in dataset['target']]
    prob_lengths = [len(p) for p in predictions.target_tags_BAD_probabilities]
    label_lengths = [len(p) for p in predictions.target_tags_labels]
    assert target_lengths == prob_lengths == label_lengths

    # Gap tags have one extra slot (the position before the first token).
    gap_prob_lengths = [len(p) - 1 for p in predictions.gap_tags_BAD_probabilities]
    gap_label_lengths = [len(p) - 1 for p in predictions.gap_tags_labels]
    assert target_lengths == gap_prob_lengths == gap_label_lengths

    # One sentence-level HTER prediction per input sentence.
    assert len(dataset['target']) == len(predictions.sentences_hter)


def test_predict_with_empty_sentences(data_config, model_dir):
    """Blank input lines are tolerated and still yield per-line outputs."""
    runner = load_system(model_dir / 'nuqe.ckpt')

    dataset = WMTQEDataset.build(WMTQEDataset.Config(**data_config), valid=True)

    test_data = dict(
        source=dataset['source'],
        target=dataset['target'],
        alignments=dataset['alignments'],
    )

    # Insert blanks near the start, the middle, and before the last line.
    blank_indices = [1, 3, -1]
    for lines in test_data.values():
        for idx in reversed(blank_indices):
            lines.insert(idx, '')

    assert len(test_data['source']) == len(dataset['source']) + len(blank_indices)
    assert len(test_data['target']) == len(dataset['target']) + len(blank_indices)

    predictions = runner.predict(**test_data)

    assert len(predictions.target_tags_labels) == len(test_data['target'])
    assert len(predictions.target_tags_BAD_probabilities) == len(test_data['target'])


def test_predict_with_all_empty_sentences(data_config, model_dir):
    """If every line is blank on both sides, all prediction fields are None."""
    runner = load_system(model_dir / 'nuqe.ckpt')

    predictions = runner.predict(
        source=[''] * 10, target=[''] * 10, alignments=[''] * 10
    )
    assert all(value is None for value in vars(predictions).values())


def test_predict_with_one_side_all_empty_sentence(data_config, model_dir):
    """Blank lines on only one side of the corpus must be rejected."""
    runner = load_system(model_dir / 'nuqe.ckpt')

    with pytest.raises(ValueError, match='Received empty'):
        runner.predict(source=['AB'] * 10, target=[''] * 10, alignments=[''] * 10)

    with pytest.raises(ValueError, match='Received empty'):
        runner.predict(source=[''] * 10, target=['AB'] * 10, alignments=[''] * 10)
import numpy as np
import pytest
import yaml

from kiwi.lib import pretrain

# System configuration for pretraining the Predictor on the toy corpus.
predictor_yaml = """
class_name: Predictor

num_data_workers: 0
batch_size:
    train: 32
    valid: 32

model:
    encoder:
        hidden_size: 400
        rnn_layers: 2
        embeddings:
            source:
                dim: 200
            target:
                dim: 200
        out_embeddings_dim: 200
        share_embeddings: false
        dropout: 0.5
        use_mismatch_features: false

optimizer:
    class_name: adam
    learning_rate: 0.001
    learning_rate_decay: 0.6
    learning_rate_decay_start: 2

data_processing:
    vocab:
        min_frequency: 1
        max_size: 60_000
"""


@pytest.fixture
def predictor_config_dict():
    """Predictor system configuration parsed from the YAML above."""
    return yaml.unsafe_load(predictor_yaml)


def test_pretrain_predictor(
    tmp_path, predictor_config_dict, pretrain_config, data_config, extra_big_atol
):
    """Pretrain the Predictor, then resume from the saved checkpoint."""
    pretrain_config['run']['output_dir'] = tmp_path
    pretrain_config['data'] = data_config
    pretrain_config['system'] = predictor_config_dict

    first_run = pretrain.pretrain_from_configuration(pretrain_config)

    metrics = first_run.best_metrics
    np.testing.assert_allclose(
        metrics['target_PERP'], 838.528486, atol=extra_big_atol
    )
    np.testing.assert_allclose(
        metrics['val_target_PERP'], 501.516467, atol=extra_big_atol
    )

    # Resume from the best checkpoint to exercise loading pickled data.
    pretrain_config['system']['load'] = first_run.best_model_path
    second_run = pretrain.pretrain_from_configuration(pretrain_config)

    metrics = second_run.best_metrics
    np.testing.assert_allclose(
        metrics['target_PERP'], 166.4964834955168, atol=extra_big_atol
    )
    np.testing.assert_allclose(
        metrics['val_target_PERP'], 333.688725, atol=extra_big_atol
    )


if __name__ == '__main__':  # pragma: no cover

    pytest.main([__file__])  # pragma: no cover
import pytest
import torch

from kiwi.utils.io import generate_slug, load_torch_file


def test_load_torch_file(model_dir):
    """An existing checkpoint loads; a missing path raises ValueError."""
    load_torch_file(model_dir / 'nuqe.ckpt')
    # NOTE: the CUDA map_location path (storage.cuda(0)) is not exercised
    # because the test environment has no GPU; it would raise RuntimeError.
    with pytest.raises(ValueError):
        load_torch_file(model_dir / 'nonexistent.torch')


def test_generate_slug():
    """Slugs are lowercased, punctuation-free, and hyphen-separated."""
    assert generate_slug('Some Random Text!') == 'some-random-text'
import pytest
import yaml
from transformers import XLMConfig, XLMModel, XLMTokenizer

from conftest import check_computation
from kiwi import constants as const

# System configuration for the XLM-based QE system under test.
xlm_yaml = """
class_name: XLM

num_data_workers: 0
batch_size:
    train: 32
    valid: 32

model:
    encoder:
        model_name: None
        interleave_input: false
        freeze: false

    decoder:
        hidden_size: 16
        dropout: 0.0

    outputs:
        word_level:
            target: true
            gaps: true
            source: true
            class_weights:
                target_tags:
                    BAD: 3.0
                gap_tags:
                    BAD: 5.0
                source_tags:
                    BAD: 3.0
        sentence_level:
            hter: true
            use_distribution: true
            binary: false
        sentence_loss_weight: 1

    tlm_outputs:
        fine_tune: false

optimizer:
    class_name: adamw
    learning_rate: 0.00001
    warmup_steps: 0.1
    training_steps: 12000

data_processing:
    share_input_fields_encoders: true

"""


@pytest.fixture
def xlm_config_dict():
    """XLM system configuration parsed from the YAML above."""
    return yaml.unsafe_load(xlm_yaml)


@pytest.fixture(scope='session')
def xlm_model_dir(model_dir):
    """Directory reserved for toy XLM artifacts."""
    return model_dir.joinpath('xlm/')


@pytest.fixture(scope='function')
def xlm_model():
    """A tiny, randomly initialized XLM model covering 15 languages."""
    tiny_config = XLMConfig(
        vocab_size=93000,
        emb_dim=32,
        n_layers=5,
        n_heads=4,
        dropout=0.1,
        max_position_embeddings=512,
        lang2id={
            "ar": 0,
            "bg": 1,
            "de": 2,
            "el": 3,
            "en": 4,
            "es": 5,
            "fr": 6,
            "hi": 7,
            "ru": 8,
            "sw": 9,
            "th": 10,
            "tr": 11,
            "ur": 12,
            "vi": 13,
            "zh": 14,
        },
    )
    return XLMModel(config=tiny_config)


@pytest.fixture()
def xlm_tokenizer():
    """The pretrained XLM-15 tokenizer (downloaded once, then cached)."""
    return XLMTokenizer.from_pretrained('xlm-mlm-tlm-xnli15-1024')


def test_computation_target(
    tmp_path,
    xlm_model,
    xlm_tokenizer,
    xlm_model_dir,
    xlm_config_dict,
    train_config,
    data_config,
    big_atol,
):
    """Train the XLM system on toy data and check target-tag probabilities."""
    train_config['run']['output_dir'] = tmp_path
    train_config['data'] = data_config
    train_config['system'] = xlm_config_dict

    # Materialize a loadable pretrained model: save weights and tokenizer.
    xlm_model.save_pretrained(tmp_path)
    xlm_tokenizer.save_pretrained(tmp_path)
    train_config['system']['model']['encoder']['model_name'] = str(tmp_path)

    check_computation(
        train_config,
        tmp_path,
        output_name=const.TARGET_TAGS,
        expected_avg_probs=0.410072,
        atol=big_atol,
    )


if __name__ == '__main__':  # pragma: no cover
    pytest.main([__file__])  # pragma: no cover
import pytest
import yaml
from transformers import XLMRobertaConfig, XLMRobertaModel

from conftest import check_computation
from kiwi import constants as const

# System configuration for the XLM-Roberta-based QE system under test.
# Note: optimizer.training_steps is deliberately left unset to test
# validation of the adamw configuration.
xlmr_yaml = """
class_name: XLMRoberta

num_data_workers: 0
batch_size:
    train: 16
    valid: 16

model:
    encoder:
        model_name: None
        interleave_input: false
        freeze: false
        pooling: mixed

    decoder:
        hidden_size: 16
        dropout: 0.0

    outputs:
        word_level:
            target: true
            gaps: true
            source: true
            class_weights:
                target_tags:
                    BAD: 3.0
                gap_tags:
                    BAD: 5.0
                source_tags:
                    BAD: 3.0
        sentence_level:
            hter: true
            use_distribution: true
            binary: false
        sentence_loss_weight: 1

    tlm_outputs:
        fine_tune: false

optimizer:
    class_name: adamw
    learning_rate: 0.00001
    warmup_steps: 0.1

data_processing:
    share_input_fields_encoders: true

"""


@pytest.fixture
def xlmr_config_dict():
    """XLM-Roberta system configuration parsed from the YAML above."""
    return yaml.unsafe_load(xlmr_yaml)


@pytest.fixture(scope='session')
def xlmr_model_dir(model_dir):
    """Directory reserved for toy XLM-Roberta artifacts."""
    return model_dir.joinpath('xlmr/')


@pytest.fixture(scope='function')
def xlmr_model():
    """A tiny, randomly initialized XLM-Roberta model."""
    tiny_config = XLMRobertaConfig(
        vocab_size=251000,
        hidden_size=32,
        num_hidden_layers=5,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act='gelu',
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=256,
        type_vocab_size=2,
        is_decoder=False,
        initializer_range=0.02,
    )
    return XLMRobertaModel(config=tiny_config)


def test_computation_target(
    tmp_path,
    xlmr_model,
    xlmr_model_dir,
    xlmr_config_dict,
    train_config,
    data_config,
    big_atol,
):
    """Train the XLM-R system on toy data and check target-tag probabilities."""
    train_config['run']['output_dir'] = tmp_path
    train_config['data'] = data_config
    train_config['system'] = xlmr_config_dict

    xlmr_model.save_pretrained(tmp_path)
    train_config['system']['model']['encoder']['model_name'] = str(tmp_path)

    # Using `adamw` without `optimizer.training_steps` must be rejected:
    with pytest.raises(ValueError):
        check_computation(
            train_config,
            tmp_path,
            output_name=const.TARGET_TAGS,
            expected_avg_probs=0.383413,
            atol=big_atol,
        )

    # With training_steps set, training runs to completion:
    train_config['system']['optimizer']['training_steps'] = 10
    check_computation(
        train_config,
        tmp_path,
        output_name=const.TARGET_TAGS,
        expected_avg_probs=0.383413,
        atol=big_atol,
    )


if __name__ == '__main__':  # pragma: no cover
    pytest.main([__file__])  # pragma: no cover
-------------------------------------------------------------------------------- 1 | to add or remove pixels when resizing so the image retains approximately the same appearance at a different size , select Resample Image . 2 | to update all assignments in the current document , choose Update All Assignments from the Assignments panel menu . 3 | in the Options tab , click the Custom button and enter lower values for Error Correction Level and Y / X Ratio . 4 | for example , you could create a document containing a car that moves across the Stage . 5 | in the New From Template dialog box , locate and select a template , and click New . 6 | make sure that you obtained the security settings file from a source that you trust . 7 | makes a rectangular selection ( or a square , when used with the Shift key ) . 8 | drag diagonally from the corner where you want the graph to begin to the opposite corner . 9 | enter a value from -100 % to 100 % to specify the percentage by which to decrease or increase the color or the spot-color tint . 10 | you can enable the Contribute publishing server using this dialog box . 11 | you can add any web page - not just pages in websites or entries in blogs that you 're connected to - to your bookmarks list . 12 | use the Export commands on the File menu to export all or part of an InDesign document to other formats . 13 | if appropriate , click Browse to navigate to the location in which you want the downloads to be placed . 14 | for Point / Pica Size , choose from the following options : 15 | to check the character code currently selected , select Shift JIS , JIS , Kuten , or Unicode , and display the code system . 16 | a digital signature , like a conventional handwritten signature , identifies the person signing a document . 17 | this option converts a complete 360 x 180 degree spherical panorama to a 3D layer . 18 | specifies that CSS styles appear in the Style menu . 
19 | for example , suppose you want to update the content of a formatting table in a monthly magazine . 20 | in fact , most of the features you see in InDesign are provided by plug ‑ ins . 21 | the default is ^ t , which tells InDesign to insert a tab . 22 | select the new keyframe and drag one of the Learning Interaction movie clips from the Library panel to the Stage . 23 | sets the type of currency , such as Euros , Dollars , or Yen . 24 | enter Please indicate your level of satisfaction for the text parameter . 25 | substitutes the standard glyph with the jp78 ‑ variant glyph . 26 | resampling adds pixels to or subtracts pixels from a resized bitmap to match the appearance of the original bitmap as closely as possible . 27 | however , this feature can also be used to add malicious data into a PDF . 28 | to change the text size , click the Decrease Text Size button or the Increase Text Size button . 29 | please remember that existing artwork or images that you may want to include in your project may be protected under copyright law . 30 | when the Identity Setup dialog box appears , enter the appropriate information about yourself , and click Complete . 31 | the Properties toolbar is different in that it doesn 't contain tools and can 't be customized to hide options . 32 | for this reason , Adobe ® Flash ® Player includes a set of security rules and controls to safeguard the user , website owner , and content developer . 33 | does not have a menu bar . 34 | click the Start button and choose Settings > Printers And Faxes . 35 | it can also represent a spatial vector in physics , which has a direction and a magnitude . 36 | switch between the snapshots to find the settings you like best . 37 | then drag to adjust . 38 | choose Always from the Maximize PSD and PSB File Compatibility menu . 39 | signatures that certify an Adobe ® PDF are called certifying signatures . 
40 | select an option under State , and then specify a label or icon option : 41 | to create a single-page document , choose File > Create PDF > From File . 42 | the file contains two parameters , monthNames and dayNames . 43 | contains information about the draft , including the blog title , the blog post title , and the associated tags . 44 | whether the profile is included or not is determined by the Profile Inclusion Policy . 45 | removing certain elements can seriously affect the functionality of the PDF . 46 | after the sound data starts loading , this code calls the snd.play ( ) method and stores the resulting SoundChannel object in the channel variable . 47 | after a web page completes its workflow or you have completed editing a blog , you can publish it to your website or blog from Contribute . 48 | select Window > Other Panels > Strings , and click Import XML . 49 | it performs as expected on paths that are oval , square , rectangular or otherwise irregularly shaped . 50 | press Enter or Return to begin a new paragraph . 
51 | -------------------------------------------------------------------------------- /tests/toy-data/WMT18/word_level/en_de.nmt/dev.src-mt.alignments: -------------------------------------------------------------------------------- 1 | 0-8 1-14 2-15 3-17 4-13 5-9 6-8 6-10 7-19 8-11 9-12 10-13 11-22 12-20 13-23 17-28 18-24 18-29 19-18 19-25 20-0 21-4 22-21 23-30 2 | 0-0 0-1 0-8 1-7 2-2 2-3 3-4 4-4 5-4 6-5 7-6 8-9 9-10 9-11 9-12 10-21 11-19 12-15 14-16 15-18 15-19 16-20 17-13 18-22 3 | 1-3 2-5 2-6 3-4 4-5 5-0 5-8 6-9 7-12 8-10 8-11 8-13 9-14 10-15 10-16 11-15 12-17 13-18 14-19 15-20 15-21 16-21 17-22 18-24 19-25 20-26 21-28 22-29 4 | 0-2 1-2 2-2 3-0 3-1 4-1 5-5 6-3 7-4 8-6 8-7 9-8 10-9 12-10 13-10 14-11 15-12 16-14 5 | 0-2 1-1 2-5 3-4 3-6 3-8 4-5 4-7 5-3 6-3 9-12 10-0 10-11 11-9 12-10 14-12 15-13 15-15 16-16 16-17 16-19 17-18 6 | 0-0 1-2 2-3 2-4 3-5 4-15 5-6 5-9 6-8 6-10 7-10 8-7 9-11 10-12 11-14 14-15 15-16 7 | 0-0 1-1 2-2 3-3 4-4 5-5 6-6 7-7 8-8 9-9 10-14 10-15 11-11 12-12 13-13 14-13 15-16 16-17 8 | 0-0 0-1 1-1 3-2 3-3 5-6 5-7 5-8 8-9 8-10 9-11 11-16 12-17 13-13 14-14 15-15 16-18 9 | 0-0 0-1 1-2 2-3 3-4 4-5 5-6 7-8 8-9 9-11 9-12 10-15 11-13 12-14 14-16 15-17 16-24 16-25 17-21 18-28 19-18 19-19 19-22 20-20 20-23 21-26 23-27 24-28 25-29 10 | 0-0 1-1 2-7 3-2 4-3 5-3 6-4 7-4 8-5 9-6 10-6 11-8 11 | 0-0 1-1 2-5 3-2 3-3 4-4 5-4 7-7 8-8 9-10 10-11 11-12 12-13 13-14 14-15 15-16 16-17 17-20 18-19 18-22 19-18 19-21 20-24 21-23 22-24 23-25 24-26 25-26 26-27 12 | 0-0 0-1 1-2 2-5 2-6 3-3 5-8 6-7 6-8 6-11 7-9 7-10 7-13 8-12 8-14 9-26 11-17 12-16 12-19 13-19 14-18 14-20 15-21 17-25 18-23 19-24 20-27 13 | 0-2 1-2 2-2 2-6 3-0 3-1 3-3 4-5 6-13 7-8 7-9 7-12 8-10 9-11 10-15 11-14 11-16 12-14 13-21 14-17 15-17 15-18 16-21 17-20 18-19 19-22 14 | 0-2 1-4 2-5 4-3 5-6 6-7 8-9 9-10 10-11 11-12 15 | 0-0 0-5 1-6 2-1 3-4 4-4 5-2 6-3 7-7 8-8 8-9 9-12 10-15 11-16 12-15 13-16 14-18 15-19 16-21 17-23 19-25 20-26 21-27 21-28 22-29 23-28 23-30 24-31 16 | 0-0 1-1 2-2 3-3 4-4 5-5 6-6 7-7 8-8 9-9 
10-10 11-11 12-11 13-12 14-13 15-14 16-15 17 | 0-0 0-1 1-2 2-10 3-4 4-3 5-3 6-4 7-5 8-6 9-6 10-7 12-8 13-9 14-9 15-11 18 | 0-1 0-2 1-3 1-4 2-5 3-5 4-11 4-12 5-6 6-8 7-8 7-9 7-10 8-7 9-13 19 | 0-0 1-0 2-1 3-0 4-2 5-3 6-3 7-13 8-4 9-5 10-5 11-6 11-10 12-9 13-8 14-9 14-11 15-10 16-12 17-12 18-14 20 | 3-0 3-1 4-2 5-2 6-6 8-5 9-3 10-4 11-7 12-10 13-8 14-9 15-10 16-10 17-11 21 | 0-0 1-1 2-2 3-3 4-4 5-5 6-6 7-8 7-9 8-7 9-10 10-13 11-11 11-12 12-12 13-14 22 | 0-0 0-1 1-2 2-3 3-4 4-6 5-7 5-8 6-6 7-9 7-10 8-10 9-9 10-11 11-12 12-13 13-12 14-13 15-16 16-14 16-15 16-17 17-18 18-19 19-20 20-21 23 | 0-0 1-1 2-2 3-3 4-4 5-6 6-7 6-8 7-7 8-7 9-8 10-9 11-10 12-11 12-12 13-12 14-13 24 | 0-0 1-1 2-3 3-8 4-5 5-6 6-7 7-2 8-9 8-10 9-11 10-4 11-13 25 | 0-0 1-1 2-1 3-1 4-2 4-3 5-4 6-3 7-4 8-5 9-5 10-6 26 | 0-0 0-1 1-2 1-4 2-3 3-4 4-5 5-6 6-7 7-8 8-9 9-11 10-12 11-13 11-14 11-15 12-16 12-17 12-18 13-18 14-19 15-20 16-20 17-21 18-22 19-25 20-23 20-24 21-25 22-26 23-28 27 | 0-5 2-1 3-0 3-2 4-3 5-6 10-7 11-8 12-9 13-10 14-11 14-12 15-13 28 | 0-0 0-1 0-5 1-4 2-2 3-3 4-3 5-6 6-7 6-8 6-9 7-10 8-12 8-13 8-14 9-13 10-13 11-11 11-15 12-16 13-18 14-19 14-20 14-21 14-22 15-21 16-21 17-19 17-24 18-23 29 | 0-0 0-1 1-0 2-2 2-3 3-4 4-4 5-6 6-5 6-7 7-8 7-9 8-10 9-11 10-16 12-15 13-12 14-13 15-14 16-18 17-18 18-20 18-21 19-19 20-19 21-19 22-22 30 | 0-0 1-1 2-3 2-4 3-5 4-2 5-2 5-6 6-7 6-8 7-9 8-10 8-11 9-12 10-13 11-14 12-15 15-16 16-17 16-18 16-19 17-20 17-21 17-23 18-22 31 | 0-0 1-3 2-1 3-5 4-5 5-6 6-7 6-8 6-9 7-10 8-11 9-11 10-12 10-13 11-12 12-14 14-15 15-17 15-18 16-16 17-19 17-20 18-22 19-21 20-23 32 | 0-2 1-1 2-0 2-2 4-4 5-5 6-6 7-7 8-8 9-3 10-9 11-10 12-11 13-12 14-15 15-13 16-14 18-19 19-18 19-21 20-19 21-20 22-27 25-23 26-28 27-25 28-30 33 | 0-0 0-1 1-1 2-1 4-2 5-2 6-4 34 | 0-0 0-1 0-2 1-2 2-4 3-3 3-5 4-6 5-7 5-8 5-9 6-10 7-11 7-12 7-13 8-13 8-14 9-15 10-16 11-17 35 | 0-0 1-1 2-2 3-3 4-3 5-4 6-5 7-6 8-8 8-9 9-10 10-11 11-17 12-12 13-13 14-14 15-15 16-16 17-18 36 | 0-0 1-2 2-1 2-3 3-4 4-5 4-6 5-11 
6-7 7-8 7-9 8-8 9-10 10-11 11-12 37 | 0-1 0-2 1-0 1-4 2-3 2-5 2-6 3-7 3-8 3-9 4-10 38 | 0-0 0-1 1-15 2-6 3-2 4-7 5-7 6-8 6-9 8-11 9-5 10-3 10-17 11-16 39 | 0-0 1-1 1-2 2-3 3-3 4-4 5-5 6-6 7-8 7-9 8-10 8-13 9-11 9-12 10-12 11-14 40 | 0-0 0-1 1-6 2-7 3-2 4-3 4-4 4-5 6-9 7-12 8-10 8-11 9-13 10-14 11-15 12-15 12-16 14-18 41 | 0-0 1-1 2-2 3-3 4-4 6-9 7-8 8-7 8-9 8-10 8-11 9-13 10-12 11-14 11-15 11-16 12-17 13-17 13-18 14-19 42 | 0-0 1-1 2-2 3-3 4-4 5-5 6-9 7-9 8-11 9-12 43 | 0-1 1-2 2-2 2-3 3-3 4-4 5-5 6-6 7-8 9-9 10-10 11-12 12-13 13-12 14-11 14-13 15-10 16-14 17-15 18-16 19-17 20-18 44 | 0-0 1-1 2-2 3-4 4-3 5-5 6-6 7-8 8-9 9-9 10-10 11-11 11-12 11-13 12-14 13-15 14-16 45 | 0-0 0-1 1-2 2-3 3-4 4-4 4-5 5-6 5-7 5-12 6-8 7-9 8-10 9-10 9-11 10-11 11-13 46 | 0-0 0-1 1-1 1-3 2-4 3-3 4-4 4-5 4-6 5-2 7-7 8-8 9-14 10-9 11-10 12-11 13-12 14-13 15-15 16-22 17-16 18-17 19-17 20-18 21-19 22-20 23-21 24-22 25-23 47 | 0-0 1-1 2-2 3-2 4-5 5-5 6-4 7-6 9-12 10-11 11-8 12-9 13-10 14-13 15-15 16-14 17-25 18-16 20-18 20-21 21-19 22-20 23-22 24-23 25-24 26-26 48 | 0-0 0-1 1-2 1-3 2-4 2-5 2-6 3-7 4-8 5-9 5-10 5-11 6-11 6-12 7-13 8-14 9-15 9-16 9-17 10-18 10-20 11-19 12-21 49 | 0-0 1-1 2-2 3-3 4-4 5-5 6-7 6-8 7-17 8-9 9-10 10-11 11-12 12-13 13-14 14-15 15-16 16-17 17-18 50 | 0-0 0-1 1-2 1-3 3-4 4-5 4-9 5-10 6-6 7-7 8-8 9-11 51 | -------------------------------------------------------------------------------- /tests/toy-data/WMT18/word_level/en_de.nmt/dev.src_tags: -------------------------------------------------------------------------------- 1 | OK OK OK OK OK OK OK OK OK OK BAD OK BAD OK BAD BAD BAD BAD BAD OK OK OK OK OK 2 | OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK 3 | OK OK OK OK OK OK OK OK OK OK OK BAD BAD OK OK OK OK OK OK OK OK OK OK 4 | OK OK OK OK OK OK OK OK OK BAD BAD BAD BAD OK OK OK OK 5 | OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK 6 | OK OK OK OK BAD BAD OK OK OK BAD OK OK OK OK BAD OK 7 | OK OK OK OK OK OK OK BAD OK OK BAD OK BAD BAD BAD OK OK 8 | 
OK BAD BAD BAD BAD BAD OK OK OK BAD BAD BAD OK BAD BAD OK OK 9 | OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK BAD OK BAD OK OK OK OK BAD BAD OK 10 | OK OK OK OK OK OK OK OK OK OK OK OK 11 | OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK BAD BAD BAD OK 12 | OK OK OK OK OK OK OK OK OK OK BAD BAD BAD BAD OK OK OK OK OK OK OK 13 | OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK BAD OK OK OK OK 14 | OK BAD OK OK OK BAD OK OK OK OK OK OK 15 | OK OK OK OK OK OK OK OK OK BAD OK OK OK OK OK BAD OK OK OK OK OK OK BAD OK OK 16 | OK OK OK OK OK OK OK BAD OK OK BAD OK OK OK OK OK OK 17 | BAD OK BAD OK BAD BAD BAD BAD BAD BAD BAD BAD OK OK OK OK 18 | OK OK OK OK OK OK OK OK OK OK 19 | OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK BAD OK OK 20 | OK BAD BAD OK OK OK OK OK OK OK OK BAD OK OK OK OK BAD OK 21 | OK OK OK OK OK OK OK OK OK OK OK OK OK OK 22 | OK OK OK OK OK OK BAD OK OK BAD OK OK OK OK OK OK OK OK OK OK OK 23 | OK OK OK OK BAD OK OK OK BAD OK BAD OK OK BAD BAD 24 | OK BAD BAD BAD BAD BAD BAD OK OK OK OK OK 25 | OK OK OK BAD BAD OK OK OK BAD BAD OK 26 | BAD BAD OK BAD BAD BAD BAD OK BAD BAD BAD BAD BAD OK OK OK OK OK OK OK OK OK OK OK 27 | OK OK OK OK OK OK OK OK OK OK BAD OK OK OK OK OK 28 | OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK 29 | OK OK OK OK BAD OK OK BAD OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK 30 | OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK 31 | OK OK OK BAD BAD BAD BAD OK OK OK OK OK OK OK OK OK OK OK OK OK OK 32 | OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK 33 | BAD OK OK OK OK OK OK 34 | OK OK OK OK OK OK OK OK OK OK OK OK 35 | BAD OK OK BAD BAD BAD OK BAD BAD OK OK OK OK OK OK OK OK OK 36 | OK OK OK BAD OK OK OK OK OK OK OK OK 37 | OK OK BAD BAD OK 38 | OK OK OK OK OK OK OK OK OK OK OK OK 39 | OK OK OK OK OK OK OK OK OK BAD OK OK 40 | OK OK OK OK OK OK OK OK OK OK OK OK BAD OK OK 41 | OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK 42 | OK OK OK OK OK OK BAD 
OK BAD OK 43 | OK OK OK OK OK OK BAD OK BAD BAD OK OK OK BAD OK OK OK BAD OK OK OK 44 | OK OK OK OK OK OK OK OK OK OK OK BAD OK OK OK 45 | OK OK OK OK OK OK OK BAD OK OK OK OK 46 | OK OK OK OK OK OK OK OK OK OK OK BAD OK OK OK OK OK OK OK OK OK OK OK BAD OK OK 47 | OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK BAD OK OK OK OK OK OK OK 48 | OK OK OK OK OK OK OK OK OK OK OK OK OK 49 | BAD BAD BAD OK OK OK OK OK BAD OK BAD OK BAD OK BAD BAD BAD OK 50 | OK OK OK OK OK OK OK OK OK OK 51 | -------------------------------------------------------------------------------- /tests/toy-data/WMT18/word_level/en_de.nmt/train.hter: -------------------------------------------------------------------------------- 1 | 0.153846 2 | 0.230769 3 | 0.176471 4 | 0.000000 5 | 0.000000 6 | 0.083333 7 | 0.272727 8 | 0.000000 9 | 0.000000 10 | 0.142857 11 | 0.272727 12 | 0.600000 13 | 0.062500 14 | 0.636364 15 | 0.416667 16 | 0.041667 17 | 0.117647 18 | 0.000000 19 | 0.000000 20 | 0.111111 21 | 0.000000 22 | 0.000000 23 | 0.500000 24 | 0.041667 25 | 0.000000 26 | 0.052632 27 | 0.187500 28 | 0.705882 29 | 0.086957 30 | 0.117647 31 | 0.000000 32 | 0.000000 33 | 0.090909 34 | 0.090909 35 | 0.000000 36 | 0.080000 37 | 0.038462 38 | 0.000000 39 | 0.111111 40 | 0.038462 41 | 0.000000 42 | 0.000000 43 | 0.125000 44 | 0.150000 45 | 0.066667 46 | 0.047619 47 | 0.230769 48 | 0.333333 49 | 0.083333 50 | 0.000000 51 | 0.250000 52 | 0.052632 53 | 0.040000 54 | 0.000000 55 | 0.035714 56 | 0.066667 57 | 0.100000 58 | 0.733333 59 | 0.684211 60 | 0.812500 61 | 0.208333 62 | 0.090909 63 | 0.000000 64 | 0.294118 65 | 0.062500 66 | 0.285714 67 | 0.064516 68 | 0.000000 69 | 0.000000 70 | 0.285714 71 | 0.333333 72 | 0.739130 73 | 0.045455 74 | 0.250000 75 | 0.866667 76 | 0.045455 77 | 0.333333 78 | 0.080000 79 | 0.058824 80 | 0.083333 81 | 0.416667 82 | 0.000000 83 | 0.114286 84 | 0.000000 85 | 0.391304 86 | 0.000000 87 | 0.047619 88 | 0.043478 89 | 0.000000 90 | 0.000000 91 | 0.071429 92 | 0.000000 93 
| 0.294118 94 | 0.038462 95 | 0.000000 96 | 0.200000 97 | 0.000000 98 | 0.100000 99 | 0.312500 100 | 0.210526 101 | -------------------------------------------------------------------------------- /tests/toy-data/models/bert/vocab.txt: -------------------------------------------------------------------------------- 1 | [PAD] 2 | [UNK] 3 | [CLS] 4 | [SEP] 5 | [MASK] 6 | 7 | 8 | . 9 | , 10 | the 11 | " 12 | sie 13 | die 14 | ##en 15 | to 16 | in 17 | - 18 | der 19 | ##s 20 | und 21 | and 22 | ##n 23 | ) 24 | auf 25 | an 26 | ( 27 | aus 28 | ##e 29 | you 30 | text 31 | a 32 | den 33 | of 34 | ##t 35 | or 36 | c 37 | for 38 | das 39 | mit 40 | im 41 | wird 42 | wa 43 | wenn 44 | zu 45 | option 46 | ze 47 | oder 48 | > 49 | um 50 | kl 51 | ##lick 52 | s 53 | ##er 54 | ein 55 | ##ick 56 | von 57 | ##be 58 | : 59 | eine 60 | ##ien 61 | is 62 | far 63 | fur 64 | ##hlen 65 | des 66 | werden 67 | ##wahl 68 | be 69 | sp 70 | ##te 71 | date 72 | konnen 73 | ##feld 74 | can 75 | select 76 | ##d 77 | ##ge 78 | menu 79 | ##fuge 80 | hinzu 81 | method 82 | ist 83 | ##f 84 | ##far 85 | ##l 86 | color 87 | that 88 | if 89 | from 90 | ##ten 91 | bil 92 | ang 93 | seit 94 | ##ieren 95 | ##zei 96 | ge 97 | links 98 | with 99 | selected 100 | add 101 | g 102 | ##ben 103 | ##ichen 104 | ##lach 105 | mocht 106 | format 107 | ##p 108 | -------------------------------------------------------------------------------- /tests/toy-data/models/nuqe.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unbabel/OpenKiwi/07d7cfed880457d95daf189dd8282e5b02ac2954/tests/toy-data/models/nuqe.ckpt -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | # tox (https://tox.readthedocs.io/) is a tool for running tests 2 | # in multiple virtualenvs. 
This configuration file will run the 3 | # test suite on all supported python versions. To use it, "pip install tox" 4 | # and then run "tox" from this directory. 5 | 6 | [tox] 7 | isolated_build = true 8 | skip_missing_interpreters = true 9 | parallel_show_output = true 10 | envlist = lint,py37,docs 11 | 12 | [testenv] 13 | whitelist_externals = 14 | poetry 15 | skip_install = true 16 | setenv = 17 | PYTHONHASHSEED=0 18 | PYTHONPATH={toxinidir} 19 | ;commands = 20 | ; poetry install -v 21 | ; {[testenv:test]commands} 22 | 23 | [testenv:test] 24 | skip_install = true 25 | usedevelop = true 26 | envdir = {toxworkdir}/py37 27 | commands = 28 | poetry install -v -E search -E mlflow 29 | poetry run coverage erase 30 | poetry run pytest --cov=kiwi --cov-report term --cov-report xml --cov-append {posargs:tests} 31 | 32 | [testenv:lint] 33 | skip_install = true 34 | usedevelop = true 35 | envdir = {toxworkdir}/py37 36 | commands = 37 | poetry install -v 38 | poetry run flake8 {posargs:kiwi} 39 | poetry run black --check {posargs:kiwi} 40 | poetry run isort --check-only --diff --recursive {posargs:kiwi} 41 | 42 | [testenv:py37] 43 | commands = 44 | poetry install -v 45 | {[testenv:test]commands} 46 | 47 | [testenv:docs] 48 | skip_install = true 49 | usedevelop = true 50 | ;envdir = {toxworkdir}/docs 51 | envdir = {toxworkdir}/py37 52 | changedir = {toxinidir}/docs 53 | commands = 54 | poetry install -v 55 | ; poetry run sphinx-apidoc -f -o source {toxinidir}/kiwi 56 | poetry run sphinx-build -b html -d ./.doctrees . 
../public 57 | 58 | 59 | [testenv:gh-pages] 60 | skip_install = true 61 | usedevelop = true 62 | envdir = {toxworkdir}/py37 63 | commands = 64 | poetry install -v 65 | ; poetry run sphinx-apidoc -f -o docs/source {toxinidir}/kiwi 66 | poetry run sphinx-build -b html -d docs/.doctrees docs gh-pages 67 | 68 | # Other packages config 69 | 70 | [flake8] 71 | max_line_length = 88 72 | select = C,E,F,W,B,B950 73 | ignore = W503,E203 74 | 75 | [pytest] 76 | python_files = 77 | test_*.py 78 | *_test.py 79 | tests.py 80 | norecursedirs = 81 | .git 82 | .tox 83 | .env 84 | dist 85 | build 86 | 87 | [coverage:run] 88 | branch = true 89 | parallel = true 90 | omit = 91 | kiwi/__main__.py 92 | 93 | [coverage:report] 94 | exclude_lines = 95 | pragma: no cover 96 | if __name__ == .__main__.: 97 | --------------------------------------------------------------------------------