├── .circleci
└── config.yml
├── .editorconfig
├── .github
└── ISSUE_TEMPLATE
│ ├── bug_report.md
│ └── feature_request.md
├── .gitignore
├── CHANGELOG.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── config
├── bert.yaml
├── data
│ ├── wmt19.qe.en_de.yaml
│ ├── wmt20.qe.en_de.yaml
│ └── wmt20.qe.en_zh.yaml
├── evaluate.yaml
├── nuqe.yaml
├── predict.yaml
├── predictor.yaml
├── predictor_estimator.yaml
├── search.yaml
├── xlm.yaml
└── xlmroberta.yaml
├── docs
├── _static
│ └── img
│ │ ├── openkiwi-logo-horizontal-dark.svg
│ │ ├── openkiwi-logo-horizontal.svg
│ │ ├── openkiwi-logo-icon-dark.svg
│ │ ├── openkiwi-logo-icon.ico
│ │ ├── openkiwi-logo-icon.svg
│ │ ├── openkiwi-logo-vertical-dark.svg
│ │ └── openkiwi-logo-vertical.svg
├── conf.py
├── configuration
│ ├── evaluate.rst
│ ├── index.rst
│ ├── predict.rst
│ ├── search.rst
│ └── train.rst
├── index.rst
├── installation.rst
├── systems
│ └── index.rst
├── test.json
└── usage.rst
├── kiwi
├── __init__.py
├── __main__.py
├── assets
│ ├── __init__.py
│ └── config
│ │ ├── __init__.py
│ │ ├── evaluate.yaml
│ │ ├── predict.yaml
│ │ ├── pretrain.yaml
│ │ ├── search.yaml
│ │ └── train.yaml
├── cli.py
├── constants.py
├── data
│ ├── __init__.py
│ ├── batch.py
│ ├── datasets
│ │ ├── __init__.py
│ │ ├── parallel_dataset.py
│ │ └── wmt_qe_dataset.py
│ ├── encoders
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── field_encoders.py
│ │ ├── parallel_data_encoder.py
│ │ └── wmt_qe_data_encoder.py
│ ├── tokenizers.py
│ └── vocabulary.py
├── lib
│ ├── __init__.py
│ ├── evaluate.py
│ ├── predict.py
│ ├── pretrain.py
│ ├── search.py
│ ├── train.py
│ └── utils.py
├── loggers.py
├── metrics
│ ├── __init__.py
│ ├── functions.py
│ └── metrics.py
├── modules
│ ├── __init__.py
│ ├── common
│ │ ├── __init__.py
│ │ ├── activations.py
│ │ ├── attention.py
│ │ ├── distributions.py
│ │ ├── feedforward.py
│ │ ├── layer_norm.py
│ │ ├── positional_encoding.py
│ │ ├── scalar_mix.py
│ │ └── scorer.py
│ ├── sentence_level_output.py
│ ├── token_embeddings.py
│ └── word_level_output.py
├── runner.py
├── systems
│ ├── __init__.py
│ ├── _meta_module.py
│ ├── bert.py
│ ├── decoders
│ │ ├── __init__.py
│ │ ├── estimator.py
│ │ ├── linear.py
│ │ └── nuqe.py
│ ├── encoders
│ │ ├── __init__.py
│ │ ├── bert.py
│ │ ├── predictor.py
│ │ ├── quetch.py
│ │ ├── xlm.py
│ │ └── xlmroberta.py
│ ├── nuqe.py
│ ├── outputs
│ │ ├── __init__.py
│ │ ├── quality_estimation.py
│ │ └── translation_language_model.py
│ ├── predictor.py
│ ├── predictor_estimator.py
│ ├── qe_system.py
│ ├── tlm_system.py
│ ├── xlm.py
│ └── xlmroberta.py
├── training
│ ├── __init__.py
│ ├── callbacks.py
│ └── optimizers.py
└── utils
│ ├── __init__.py
│ ├── data_structures.py
│ ├── io.py
│ └── tensors.py
├── poetry.lock
├── pyproject.toml
├── scripts
└── merge_target_and_gaps_preds.py
├── tests
├── conftest.py
├── mocks
│ ├── mock_vocab.py
│ └── simple_model.py
├── test_bert.py
├── test_cli.py
├── test_data.py
├── test_metrics.py
├── test_nuqe.py
├── test_predict_and_evaluate.py
├── test_predictor.py
├── test_search.py
├── test_utils.py
├── test_xlm.py
├── test_xlmr.py
└── toy-data
│ ├── WMT18
│ └── word_level
│ │ └── en_de.nmt
│ │ ├── dev.hter
│ │ ├── dev.mt
│ │ ├── dev.pe
│ │ ├── dev.ref
│ │ ├── dev.src
│ │ ├── dev.src-mt.alignments
│ │ ├── dev.src_tags
│ │ ├── dev.tags
│ │ ├── train.hter
│ │ ├── train.mt
│ │ ├── train.pe
│ │ ├── train.ref
│ │ ├── train.src
│ │ ├── train.src-mt.alignments
│ │ ├── train.src_tags
│ │ └── train.tags
│ └── models
│ ├── bert
│ └── vocab.txt
│ └── nuqe.ckpt
└── tox.ini
/.editorconfig:
--------------------------------------------------------------------------------
1 | # http://editorconfig.org
2 |
3 | root = true
4 |
5 | [*]
6 | indent_style = space
7 | indent_size = 4
8 | trim_trailing_whitespace = true
9 | insert_final_newline = true
10 | charset = utf-8
11 | end_of_line = lf
12 | max_line_length = 88
13 |
14 | [*.json]
15 | indent_size = 2
16 |
17 | [*.bat]
18 | indent_style = tab
19 | end_of_line = crlf
20 |
21 | [LICENSE]
22 | insert_final_newline = false
23 |
24 | [Makefile]
25 | indent_style = tab
26 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: bug
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 | Steps to reproduce the behavior:
15 | 1. Go to '...'
16 | 2. Use data '...'
17 | 3. Run '....' with arguments '...'
18 | 4. See error
19 |
20 | **Expected behavior**
21 | A clear and concise description of what you expected to happen.
22 |
23 | **Screenshots**
24 | If applicable, add screenshots to help explain your problem.
25 |
26 | **Environment (please complete the following information):**
27 | - OS: [e.g. Linux]
28 | - OpenKiwi version: [e.g. 0.1.0]
29 | - Python version:
30 |
31 | **Additional context**
32 | Add any other context about the problem here.
33 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Is your feature request related to a problem? Please describe.**
11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12 |
13 | **Describe the solution you'd like**
14 | A clear and concise description of what you want to happen.
15 |
16 | **Describe alternatives you've considered**
17 | A clear and concise description of any alternative solutions or features you've considered.
18 |
19 | **Additional context**
20 | Add any other context or screenshots about the feature request here.
21 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Kiwi related stuff
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | /lib/
20 | /lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | .hypothesis/
50 | .pytest_cache/
51 |
52 | # Translations
53 | *.mo
54 | *.pot
55 |
56 | # Django stuff:
57 | *.log
58 | local_settings.py
59 | db.sqlite3
60 |
61 | # Flask stuff:
62 | instance/
63 | .webassets-cache
64 |
65 | # Scrapy stuff:
66 | .scrapy
67 |
68 | # Sphinx documentation
69 | docs/_build/
70 |
71 | # PyBuilder
72 | target/
73 |
74 | # Jupyter Notebook
75 | .ipynb_checkpoints
76 |
77 | # pyenv
78 | .python-version
79 |
80 | # celery beat schedule file
81 | celerybeat-schedule
82 |
83 | # SageMath parsed files
84 | *.sage.py
85 |
86 | # Environments
87 | .env
88 | .venv
89 | env/
90 | venv/
91 | ENV/
92 | env.bak/
93 | venv.bak/
94 |
95 | # Spyder project settings
96 | .spyderproject
97 | .spyproject
98 |
99 | # Rope project settings
100 | .ropeproject
101 |
102 | # mkdocs documentation
103 | /site
104 |
105 | # mypy
106 | .mypy_cache/
107 |
108 | # vscode
109 | .vscode
110 |
111 | # Emacs Autosave
112 | *~
113 | \#*\#
114 | .\#*
115 |
116 | # Data Folder
117 | /data/
118 |
119 | # Logs Folder
120 | log/
121 |
122 | # torchtext vectors cache
123 | .vector_cache
124 |
125 | # MLFlow directory
126 | mlruns/
127 |
128 | # local run directory
129 | runs/
130 |
131 | # Pycharm project
132 | .idea/
133 |
134 | # Mac ignore
135 | .DS_Store
136 |
137 | # Folder used for save predictions
138 | /predictions/
139 |
140 | # Local shell scripts
141 | run_pipeline*.sh
142 |
143 | # Runners Directory Containing slurm output
144 | runners/
145 |
146 | #vim files
147 | *.swp
148 |
149 | # Generated docs
150 | /docs/html/
151 | /docs/.doctrees
152 |
153 | # Wheel Metadata
154 | pip-wheel-metadata/
155 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 |
2 | # Changelog
3 | All notable changes to this project will be documented in this file.
4 |
5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
7 |
8 | ## [2.1.0](https://github.com/Unbabel/OpenKiwi/compare/2.0.0...2.1.0) - 2020-11-12
9 |
10 | ### Added
11 | - Hyperparameter search pipeline `kiwi search` built on [Optuna](https://optuna.readthedocs.io/)
12 | - Docs for the search pipeline
13 | - The `--example` flag, which prints an example config from `kiwi/assets/config/` to the terminal for each pipeline
14 | - Tests to increase coverage
15 | - Readme link to the new [OpenKiwiTasting](https://github.com/Unbabel/OpenKiwiTasting) demo.
16 |
17 | ### Changed
18 | - Example configs in `conf/` so that they are clean, consistent, and have good defaults
19 | - Moved function `feedforward` from `kiwi.tensors` to `kiwi.modules.common.feedforward` where it makes more sense
20 |
21 | ### Fixed
22 | - The broken relative links in the docs
23 | - Evaluation pipeline by adding missing `quiet` and `verbose` in the evaluate configuration
24 |
25 | ### Deprecated
26 | - Migration of models from a previous OpenKiwi version, by removing the (never fully working) code in `kiwi.utils.migrations` entirely
27 |
28 | ### Removed
29 | - Unused code in `kiwi.training.optimizers`, `kiwi.modules.common.scorer`, `kiwi.modules.common.layer_norm`, `kiwi.modules.sentence_level_output`, `kiwi.metrics.metrics`, `kiwi.modules.common.attention`, `kiwi.modules.token_embeddings`
30 | - _All_ code that was already commented out
31 | - The `systems.encoder.(predictor|bert|xlm|xlmroberta).encode_source` option, which was both _confusing_ and _never used_
32 |
33 | ## [2.0.0](https://github.com/Unbabel/OpenKiwi/compare/0.1.3...2.0.0)
34 |
35 | ### Added
36 | - XLMR, XLM, BERT encoder models
37 | - New pooling methods for xlmr-encoder [mixed, mean, ll_mean]
38 | - `freeze_for_number_of_steps` allows freezing of xlmr-encoder for a specific number of training steps
39 | - `encoder_learning_rate` allows setting a specific learning rate for the encoder (different from the rest of the system)
40 | - Dataloaders now use a RandomBucketSampler which groups sentences of the same size together to minimize padding
41 | - fp16 support
42 | - Support for HuggingFace's transformers models
43 | - Pytorch-Lightning as a training framework
44 | - This changelog
45 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contribution guide
2 |
3 | ## Overview
4 |
5 | OpenKiwi is an Open Source Quality Estimation toolkit aimed at implementing state-of-the-art models in an efficient and unified fashion. While we welcome contributions, guaranteeing their quality and usefulness requires following a few basic guidelines that ease development, collaboration, and readability.
6 |
7 | ## Basic guidelines
8 |
9 | * The project must fully support Python 3.5 and later.
10 | * Code is linted with [flake8](http://flake8.pycqa.org/en/latest/user/error-codes.html), please run `flake8 kiwi` and fix remaining errors before pushing any code.
11 | * Code formatting must stick to the Facebook style: 80 columns and single quotes. For Python 3.6+, the [black](https://github.com/ambv/black) formatter can be used by running `black kiwi`. For Python 3.5, [YAPF](https://github.com/google/yapf) should get most of the job done, although some manual changes might be necessary.
12 | * Imports are sorted with [isort](https://github.com/timothycrosley/isort).
13 | * Filenames must be in lowercase.
14 | * Tests run with [pytest](https://docs.pytest.org/en/latest/), commonly regarded as one of the best unit-testing frameworks available. Pytest implements standard test discovery, which means it will only search for `test_*.py` or `*_test.py` files. We do not enforce a minimum code coverage, but it is preferable to have even very basic tests running for critical pieces of code. Always test functions that take or return tensor arguments, to document the expected sizes (see the example at the end of this guide).
15 | * The `kiwi` folder contains core features. Any script calling these features must be placed into the `scripts` folder.
16 |
17 | ## Contributing
18 |
19 | * Keep track of everything by creating issues and editing them with reference to the code! Explain succinctly the problem you are trying to solve and your solution.
20 | * Contributions to `master` should be made through GitHub pull requests.
21 | * Dependencies are managed using `Poetry`. Although we would rather err on the side of fewer rather than more dependencies, if needed they are managed through the `pyproject.toml` file.
22 | * Work in a clean environment (`virtualenv` is nice).
23 | * Your commit message must start with an infinitive verb (Add, Fix, Remove, ...).
24 | * If your change is based on a paper, please include a clear comment and reference in the code and in the related issue.
25 | * In order to test your local changes, install OpenKiwi following the instructions in the [documentation](https://unbabel.github.io/openkiwi).
26 |
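27 | ## Testing example
28 | 
29 | A minimal sketch of the tensor-size testing convention above (the pooling function and shapes here are hypothetical, for illustration only):
30 | 
31 | ```python
32 | import torch
33 | 
34 | 
35 | def mean_pool(features):
36 |     """Hypothetical example: pool features over the sequence dimension."""
37 |     return features.mean(dim=1)
38 | 
39 | 
40 | def test_mean_pool_shapes():
41 |     batch_size, seq_len, hidden = 4, 10, 8
42 |     features = torch.rand(batch_size, seq_len, hidden)
43 |     output = mean_pool(features)
44 |     # Document the contract: (batch, seq, hidden) -> (batch, hidden)
45 |     assert output.shape == (batch_size, hidden)
46 | ```
47 | 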
--------------------------------------------------------------------------------
/config/bert.yaml:
--------------------------------------------------------------------------------
1 | run:
2 | ###################################################
3 | # Generic configuration options related to
4 | # handling of experiments
5 | experiment_name: BERT WMT20 EN-DE
6 | seed: 42
7 | use_mlflow: false
8 |
9 | trainer:
10 | ###################################################
11 | # Generic options related to the training process
12 | # that apply to all models
13 | deterministic: true
14 | gpus: -1
15 | epochs: 10
16 |
17 | main_metric:
18 | - WMT19_MCC
19 | - PEARSON
20 |
21 | gradient_max_norm: 1.
22 | gradient_accumulation_steps: 1
23 |
24 | # Control the model precision, see
25 | # https://pytorch-lightning.readthedocs.io/en/stable/amp.html
26 | # for more info on the configuration options
27 | # Fast mixed precision by default:
28 | amp_level: O2
29 | precision: 16
30 |
31 | log_interval: 100
32 | checkpoint:
33 | validation_steps: 0.2
34 | early_stop_patience: 10
35 |
36 |
37 | defaults:
38 | ###################################################
39 | # Example of composition of configuration files
40 | # this config is sourced from /config/data/wmt20.qe.en_de.yaml
41 | - data: wmt20.qe.en_de
42 |
43 | system:
44 | ###################################################
45 | # System configs are responsible for all the
46 | # system-specific configurations, from model settings
47 | # to optimizers and specific processing options.
48 |
49 | # All configs must have either `class_name` or `load`
50 | class_name: Bert
51 |
52 | batch_size: 8
53 | num_data_workers: 4
54 |
55 | model:
56 | ################################################
57 | # Modeling options. These can change a lot about
58 | # the architecture of the system, with many configuration
59 | # options adding (or removing) layers.
60 | encoder:
61 | model_name: bert-base-multilingual-cased
62 | use_mlp: false
63 | freeze: false
64 |
65 | decoder:
66 | hidden_size: 768
67 | bottleneck_size: 768
68 | dropout: 0.1
69 |
70 | outputs:
71 | ####################################################
72 | # Output options configure the downstream tasks the
73 | # model will be trained on by adding specific layers
74 | # responsible for transforming decoder features into
75 | # predictions.
76 | word_level:
77 | target: true
78 | gaps: true
79 | source: true
80 | class_weights:
81 | target_tags:
82 | BAD: 3.0
83 | gap_tags:
84 | BAD: 5.0
85 | source_tags:
86 | BAD: 3.0
87 | sentence_level:
88 | hter: true
89 | use_distribution: false
90 | binary: false
91 | n_layers_output: 2
92 | sentence_loss_weight: 1
93 |
94 | tlm_outputs:
95 | fine_tune: false
96 |
97 | optimizer:
98 | class_name: adamw
99 | learning_rate: 1e-5
100 | warmup_steps: 0.1
101 | training_steps: 12000
102 |
103 | data_processing:
104 | share_input_fields_encoders: true
105 |
--------------------------------------------------------------------------------
/config/data/wmt19.qe.en_de.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | train:
3 | input:
4 | source: data/WMT19/en_de.nmt/train.src
5 | target: data/WMT19/en_de.nmt/train.mt
6 | alignments: data/WMT19/en_de.nmt/train.src-mt.alignments
7 | post_edit: data/WMT19/en_de.nmt/train.pe
8 |
9 | output:
10 | source_tags: data/WMT19/en_de.nmt/train.source_tags
11 | target_tags: data/WMT19/en_de.nmt/train.tags
12 | sentence_scores: data/WMT19/en_de.nmt/train.hter
13 |
14 | valid:
15 | input:
16 | source: data/WMT19/en_de.nmt/dev.src
17 | target: data/WMT19/en_de.nmt/dev.mt
18 | alignments: data/WMT19/en_de.nmt/dev.src-mt.alignments
19 | post_edit: data/WMT19/en_de.nmt/dev.pe
20 |
21 | output:
22 | source_tags: data/WMT19/en_de.nmt/dev.source_tags
23 | target_tags: data/WMT19/en_de.nmt/dev.tags
24 | sentence_scores: data/WMT19/en_de.nmt/dev.hter
25 |
26 | test:
27 | input:
28 | source: data/WMT19/en_de.nmt/test.src
29 | target: data/WMT19/en_de.nmt/test.mt
30 | alignments: data/WMT19/en_de.nmt/test.src-mt.alignments
31 | post_edit: data/WMT19/en_de.nmt/test.pe
32 |
--------------------------------------------------------------------------------
/config/data/wmt20.qe.en_de.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | train:
3 | input:
4 | source: data/WMT20/en-de/train/train.src
5 | target: data/WMT20/en-de/train/train.mt
6 | alignments: data/WMT20/en-de/train/train.src-mt.alignments
7 | post_edit: data/WMT20/en-de/train/train.pe
8 |
9 | output:
10 | source_tags: data/WMT20/en-de/train/train.source_tags
11 | target_tags: data/WMT20/en-de/train/train.tags
12 | sentence_scores: data/WMT20/en-de/train/train.hter
13 |
14 | valid:
15 | input:
16 | source: data/WMT20/en-de/dev/dev.src
17 | target: data/WMT20/en-de/dev/dev.mt
18 | alignments: data/WMT20/en-de/dev/dev.src-mt.alignments
19 | post_edit: data/WMT20/en-de/dev/dev.pe
20 |
21 | output:
22 | source_tags: data/WMT20/en-de/dev/dev.source_tags
23 | target_tags: data/WMT20/en-de/dev/dev.tags
24 | sentence_scores: data/WMT20/en-de/dev/dev.hter
25 |
26 | test:
27 | input:
28 | source: data/WMT20/en-de/test-blind/test.src
29 | target: data/WMT20/en-de/test-blind/test.mt
30 | alignments: data/WMT20/en-de/test-blind/test.src-mt.alignments
31 | post_edit: data/WMT20/en-de/test-blind/test-blind.pe
32 |
--------------------------------------------------------------------------------
/config/data/wmt20.qe.en_zh.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | train:
3 | input:
4 | source: data/WMT20/en-zh/train/train.src
5 | target: data/WMT20/en-zh/train/train.mt
6 | alignments: data/WMT20/en-zh/train/train.src-mt.alignments
7 | post_edit: data/WMT20/en-zh/train/train.pe
8 |
9 | output:
10 | source_tags: data/WMT20/en-zh/train/train.source_tags
11 | target_tags: data/WMT20/en-zh/train/train.tags
12 | sentence_scores: data/WMT20/en-zh/train/train.hter
13 |
14 | valid:
15 | input:
16 | source: data/WMT20/en-zh/dev/dev.src
17 | target: data/WMT20/en-zh/dev/dev.mt
18 | alignments: data/WMT20/en-zh/dev/dev.src-mt.alignments
19 | post_edit: data/WMT20/en-zh/dev/dev.pe
20 |
21 | output:
22 | source_tags: data/WMT20/en-zh/dev/dev.source_tags
23 | target_tags: data/WMT20/en-zh/dev/dev.tags
24 | sentence_scores: data/WMT20/en-zh/dev/dev.hter
25 |
26 | test:
27 | input:
28 | source: data/WMT20/en-zh/test-blind/test.src
29 | target: data/WMT20/en-zh/test-blind/test.mt
30 | alignments: data/WMT20/en-zh/test-blind/test.src-mt.alignments
31 | post_edit: data/WMT20/en-zh/test-blind/test-blind.pe
32 |
--------------------------------------------------------------------------------
/config/evaluate.yaml:
--------------------------------------------------------------------------------
1 | gold_files:
2 | source_tags: data/WMT20/sentence_level/en_de.nmt/dev.source_tags
3 | target_tags: data/WMT20/sentence_level/en_de.nmt/dev.tags
4 | sentence_scores: data/WMT20/sentence_level/en_de.nmt/dev.hter
5 |
6 | # Two configuration options:
7 | # 1. (Recommended) Pass the root folders where the predictions live,
8 | # with the standard file names
9 | predicted_dir:
10 | # The evaluation pipeline supports evaluating multiple predictions at the same time
11 | # by passing the folders as a list
12 | - runs/0/4aa891368ff4402fa69a4b081ea2ba62
13 | - runs/0/e9200ada6dc84bfea807b3b02b9c7212
14 |
15 | # 2. Configure each predicted file separately
16 | # predicted_files:
17 | # source_tags:
18 | # - runs/0/4aa891368ff4402fa69a4b081ea2ba62/source_tags
19 | # - runs/0/e9200ada6dc84bfea807b3b02b9c7212/source_tags
20 | # # (Recommended) Pass the predicted `targetgaps_tags` file as `target_tags`;
21 | # # the target and gap tags will be separated and evaluated separately as well as jointly
22 | # target_tags:
23 | # - runs/0/4aa891368ff4402fa69a4b081ea2ba62/targetgaps_tags
24 | # - runs/0/e9200ada6dc84bfea807b3b02b9c7212/targetgaps_tags
25 | # # Alternatively:
26 | # # target_tags:
27 | # # - runs/0/4aa891368ff4402fa69a4b081ea2ba62/target_tags
28 | # # - runs/0/e9200ada6dc84bfea807b3b02b9c7212/target_tags
29 | # # gap_tags:
30 | # # - runs/0/4aa891368ff4402fa69a4b081ea2ba62/gap_tags
31 | # # - runs/0/e9200ada6dc84bfea807b3b02b9c7212/gap_tags
32 | # sentence_scores:
33 | # - runs/0/4aa891368ff4402fa69a4b081ea2ba62/sentence_scores
34 | # - runs/0/e9200ada6dc84bfea807b3b02b9c7212/sentence_scores
35 |
--------------------------------------------------------------------------------
/config/nuqe.yaml:
--------------------------------------------------------------------------------
1 | run:
2 | experiment_name: NuQE WMT20 EN-DE
3 | seed: 42
4 | use_mlflow: false
5 |
6 | trainer:
7 | gpus: -1
8 | epochs: 10
9 |
10 | main_metric:
11 | - WMT19_MCC
12 | - PEARSON
13 |
14 | log_interval: 100
15 | checkpoint:
16 | validation_steps: 0.2
17 | early_stop_patience: 10
18 |
19 | defaults:
20 | - data: wmt20.qe.en_de
21 |
22 | system:
23 | class_name: NuQE
24 |
25 | batch_size: 64
26 | num_data_workers: 4
27 |
28 | model:
29 | encoder:
30 | window_size: 3
31 | embeddings:
32 | source:
33 | dim: 50
34 | dropout: 0.5
35 | freeze: false
36 | target:
37 | dim: 50
38 | dropout: 0.5
39 | freeze: false
40 |
41 | decoder:
42 | source:
43 | hidden_sizes: [400, 200, 100, 50]
44 | dropout: 0.
45 | target:
46 | hidden_sizes: [400, 200, 100, 50]
47 | dropout: 0.
48 |
49 | outputs:
50 | word_level:
51 | target: true
52 | gaps: true
53 | source: false
54 | class_weights:
55 | target_tags:
56 | BAD: 3.0
57 | gap_tags:
58 | BAD: 5.0
59 | source_tags:
60 | BAD: 3.0
61 | sentence_level:
62 | hter: true
63 | use_distribution: true
64 | binary: false
65 |
66 | data_processing:
67 | share_input_fields_encoders: false
68 | vocab:
69 | min_frequency: 2
70 |
71 | optimizer:
72 | class_name: adam
73 | learning_rate: 0.001
74 |
--------------------------------------------------------------------------------
/config/predict.yaml:
--------------------------------------------------------------------------------
1 | use_gpu: true
2 |
3 | run:
4 | output_dir: predictions
5 | predict_on_data_partition: test
6 |
7 | defaults:
8 | - data: wmt20.qe.en_de
9 |
10 | system:
11 | load: best_model.torch
12 |
13 | batch_size: 16
14 | num_data_workers: 0
15 |
16 | model:
17 | outputs:
18 | word_level:
19 | target: true
20 | gaps: true
21 | source: true
22 | sentence_level:
23 | hter: true
24 |
--------------------------------------------------------------------------------
/config/predictor.yaml:
--------------------------------------------------------------------------------
1 | run:
2 | experiment_name: Predictor WMT20 EN-DE
3 | seed: 42
4 | use_mlflow: false
5 |
6 | trainer:
7 | deterministic: true
8 | gpus: -1
9 | epochs: 10
10 |
11 | log_interval: 100
12 | checkpoint:
13 | validation_steps: 0.2
14 | early_stop_patience: 10
15 |
16 | defaults:
17 | - data: wmt20.qe.en_de
18 |
19 | system:
20 | class_name: Predictor
21 |
22 | num_data_workers: 4
23 | batch_size:
24 | train: 32
25 | valid: 32
26 |
27 | model:
28 | encoder:
29 | encode_source: false
30 | hidden_size: 400
31 | rnn_layers: 2
32 | embeddings:
33 | source:
34 | dim: 200
35 | target:
36 | dim: 200
37 | out_embeddings_dim: 200
38 | share_embeddings: false
39 | dropout: 0.5
40 | use_mismatch_features: false
41 |
42 | optimizer:
43 | class_name: adam
44 | learning_rate: 0.001
45 | learning_rate_decay: 0.6
46 | learning_rate_decay_start: 2
47 |
48 | data_processing:
49 | vocab:
50 | min_frequency: 1
51 | max_size: 60_000
52 |
--------------------------------------------------------------------------------
/config/predictor_estimator.yaml:
--------------------------------------------------------------------------------
1 | run:
2 | experiment_name: PredictorEstimator WMT20 EN-DE
3 | seed: 42
4 | use_mlflow: false
5 |
6 | trainer:
7 | gpus: -1
8 | epochs: 10
9 |
10 | main_metric:
11 | - WMT19_MCC
12 | - PEARSON
13 |
14 | log_interval: 100
15 | checkpoint:
16 | validation_steps: 0.2
17 | early_stop_patience: 10
18 |
19 | defaults:
20 | - data: wmt20.qe.en_de
21 |
22 | system:
23 | class_name: PredictorEstimator
24 |
25 | batch_size: 32
26 | num_data_workers: 4
27 |
28 | load_encoder: best_model.torch
29 |
30 | model:
31 | encoder:
32 | encode_source: false
33 | hidden_size: 400
34 | rnn_layers: 2
35 | embeddings:
36 | source:
37 | dim: 200
38 | target:
39 | dim: 200
40 | out_embeddings_dim: 200
41 | share_embeddings: false
42 | dropout: 0.5
43 | use_mismatch_features: false
44 |
45 | decoder:
46 | hidden_size: 125
47 | rnn_layers: 1
48 | use_mlp: true
49 | dropout: 0.0
50 |
51 | outputs:
52 | word_level:
53 | target: true
54 | gaps: false
55 | source: false
56 | class_weights:
57 | target_tags:
58 | BAD: 5.0
59 | gap_tags:
60 | BAD: 5.0
61 | source_tags:
62 | BAD: 3.0
63 | sentence_level:
64 | hter: true
65 | use_distribution: true
66 | binary: false
67 |
68 | tlm_outputs:
69 | fine_tune: true
70 |
71 | optimizer:
72 | class_name: adam
73 | learning_rate: 0.001
74 | learning_rate_decay: 0.6
75 | learning_rate_decay_start: 2
76 |
77 | data_processing:
78 | vocab:
79 | min_frequency: 1
80 | max_size: 60_000
81 |
--------------------------------------------------------------------------------
/config/search.yaml:
--------------------------------------------------------------------------------
1 | base_config: config/bert.yaml
2 |
3 | options:
4 | search_method: random
5 | # Search the model architecture
6 | # You can specify a list of values...
7 | hidden_size:
8 | - 768
9 | - 324
10 | # ...or you can specify a discrete range...
11 | bottleneck_size:
12 | lower: 100
13 | upper: 500
14 | step: 100
15 | search_mlp: true
16 | # Search optimizer
17 | # ...or you can specify a continuous interval.
18 | learning_rate:
19 | lower: 1e-7
20 | upper: 1e-5
21 | distribution: loguniform # recommended for the learning rate
22 | # Search weights for the tag loss
23 | class_weights:
24 | target_tags:
25 | lower: 1
26 | upper: 10
27 | step: 1
28 | gap_tags:
29 | lower: 1
30 | upper: 20
31 | step: 1
32 | source_tags: null
33 | # Search the sentence level objective
34 | search_hter: true
35 | sentence_loss_weight:
36 | lower: 1
37 | upper: 10
38 |
--------------------------------------------------------------------------------
/config/xlm.yaml:
--------------------------------------------------------------------------------
1 | run:
2 | experiment_name: XLM WMT20 EN-DE
3 | seed: 42
4 | use_mlflow: false
5 |
6 | trainer:
7 | deterministic: true
8 | gpus: -1
9 | epochs: 10
10 |
11 | main_metric:
12 | - WMT19_MCC
13 | - PEARSON
14 |
15 | gradient_max_norm: 1.
16 | gradient_accumulation_steps: 1
17 |
18 | amp_level: O2
19 | precision: 16
20 |
21 | log_interval: 100
22 | checkpoint:
23 | validation_steps: 0.2
24 | early_stop_patience: 10
25 |
26 | defaults:
27 | - data: wmt20.qe.en_de
28 |
29 | system:
30 | class_name: XLM
31 |
32 | batch_size: 8
33 | num_data_workers: 4
34 |
35 | model:
36 | encoder:
37 | model_name: xlm-mlm-tlm-xnli15-1024
38 | interleave_input: false
39 | freeze: false
40 | use_mlp: false
41 |
42 | decoder:
43 | hidden_size: 768
44 | bottleneck_size: 768
45 | dropout: 0.1
46 |
47 | outputs:
48 | word_level:
49 | target: true
50 | gaps: true
51 | source: true
52 | class_weights:
53 | target_tags:
54 | BAD: 3.0
55 | gap_tags:
56 | BAD: 5.0
57 | source_tags:
58 | BAD: 3.0
59 | sentence_level:
60 | hter: true
61 | use_distribution: false
62 | binary: false
63 | n_layers_output: 2
64 | sentence_loss_weight: 2
65 |
66 | tlm_outputs:
67 | fine_tune: false
68 |
69 | optimizer:
70 | class_name: adamw
71 | learning_rate: 1e-05
72 | warmup_steps: 0.1
73 | training_steps: 12000
74 |
75 | data_processing:
76 | share_input_fields_encoders: true
77 |
--------------------------------------------------------------------------------
/config/xlmroberta.yaml:
--------------------------------------------------------------------------------
1 | run:
2 | experiment_name: XLM-Roberta WMT20 EN-DE
3 | seed: 42
4 | use_mlflow: false
5 |
6 | trainer:
7 | deterministic: true
8 | gpus: -1
9 | epochs: 10
10 |
11 | main_metric:
12 | - WMT19_MCC
13 | - PEARSON
14 |
15 | gradient_max_norm: 1.
16 | gradient_accumulation_steps: 1
17 |
18 | amp_level: O2
19 | precision: 16
20 |
21 | log_interval: 100
22 | checkpoint:
23 | validation_steps: 0.2
24 | early_stop_patience: 10
25 |
26 | defaults:
27 | - data: wmt20.qe.en_de
28 |
29 | system:
30 | class_name: XLMRoberta
31 |
32 | batch_size: 8
33 | num_data_workers: 4
34 |
35 | model:
36 | encoder:
37 | model_name: xlm-roberta-base
38 | interleave_input: false
39 | freeze: false
40 | use_mlp: false
41 | pooling: mixed
42 | freeze_for_number_of_steps: 1000
43 |
44 | decoder:
45 | hidden_size: 768
46 | bottleneck_size: 768
47 | dropout: 0.1
48 |
49 | outputs:
50 | word_level:
51 | target: true
52 | gaps: true
53 | source: true
54 | class_weights:
55 | target_tags:
56 | BAD: 3.0
57 | gap_tags:
58 | BAD: 5.0
59 | source_tags:
60 | BAD: 3.0
61 | sentence_level:
62 | hter: true
63 | use_distribution: false
64 | binary: false
65 | n_layers_output: 2
66 | sentence_loss_weight: 1
67 |
68 | tlm_outputs:
69 | fine_tune: false
70 |
71 | optimizer:
72 | class_name: adamw
73 | learning_rate: 1e-05
74 | warmup_steps: 0.1
75 | training_steps: 12000
76 |
77 | data_processing:
78 | share_input_fields_encoders: true
79 |
--------------------------------------------------------------------------------
/docs/_static/img/openkiwi-logo-icon-dark.svg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Unbabel/OpenKiwi/07d7cfed880457d95daf189dd8282e5b02ac2954/docs/_static/img/openkiwi-logo-icon-dark.svg
--------------------------------------------------------------------------------
/docs/_static/img/openkiwi-logo-icon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Unbabel/OpenKiwi/07d7cfed880457d95daf189dd8282e5b02ac2954/docs/_static/img/openkiwi-logo-icon.ico
--------------------------------------------------------------------------------
/docs/_static/img/openkiwi-logo-icon.svg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Unbabel/OpenKiwi/07d7cfed880457d95daf189dd8282e5b02ac2954/docs/_static/img/openkiwi-logo-icon.svg
--------------------------------------------------------------------------------
/docs/configuration/evaluate.rst:
--------------------------------------------------------------------------------
1 | Evaluation configuration
2 | ========================
3 |
4 | An example configuration file is provided in ``config/evaluate.yaml``:
5 |
6 | .. literalinclude:: ../../config/evaluate.yaml
7 | :language: yaml
8 |
9 |
10 | Configuration class
11 | -------------------
12 |
13 | Full API reference: :class:`kiwi.lib.evaluate.Configuration`
14 |
15 | .. autoclass:: kiwi.lib.evaluate.Configuration
16 | :noindex:
17 |
--------------------------------------------------------------------------------
/docs/configuration/index.rst:
--------------------------------------------------------------------------------
1 | .. _configuration:
2 |
3 | Configuration
4 | =============
5 |
6 | .. toctree::
7 | :maxdepth: 5
8 | :hidden:
9 |
10 | train
11 | predict
12 | evaluate
13 | search
14 |
15 |
16 | Kiwi can be configured essentially by using dictionaries. To persist configuration
17 | dictionaries, we recommend using YAML. You can find standard configuration files
18 | in ``config/``.
19 |
20 |
21 | CLI overrides
22 | -------------
23 |
24 | To run Kiwi with a configuration file but overriding a value, use the format
25 | ``key=value``, where nested keys can be encoded by using a dot notation.
26 | For example::
27 |
28 | kiwi train config/bert.yaml trainer.gpus=0 system.batch_size=16
29 |
30 |
31 | Configuration Composing
32 | -----------------------
33 |
34 | Kiwi uses Hydra/OmegaConf to compose configuration coming from different places.
35 | This makes it possible to split configuration across multiple files.
36 |
37 | In most files in ``config/``, like ``config/predict.yaml``, you'll notice this:
38 |
39 | .. code-block:: yaml
40 |
41 | defaults:
42 | - data: wmt19.qe.en_de
43 |
44 | This means the file ``config/data/wmt19.qe.en_de.yaml`` will be loaded into the
45 | configuration found in ``config/predict.yaml``. **Notice** that ``wmt19.qe.en_de.yaml``
46 | must use fully qualified key levels, that is, the full nesting of keys.
47 |
48 | A nice use case for this is dynamically changing parts of the configuration.
49 | For example, we can use::
50 |
51 | kiwi train config/bert.yaml data=unbabel.qe.en_pt
52 |
53 | to use a different dataset (where ``config/data/unbabel.qe.en_pt.yaml`` contains the
54 | configuration for the data files).
55 |
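56 | For intuition, a dotted override such as ``trainer.gpus=0`` simply assigns to a
57 | nested key. Below is a minimal sketch of that mapping (not Kiwi's actual
58 | implementation, which relies on Hydra/OmegaConf):
59 | 
60 | .. code-block:: python
61 | 
62 |     import yaml
63 | 
64 |     def apply_override(config: dict, override: str) -> None:
65 |         """Set a nested config key from a ``key.subkey=value`` string."""
66 |         key, _, value = override.partition('=')
67 |         *parents, leaf = key.split('.')
68 |         node = config
69 |         for part in parents:
70 |             node = node.setdefault(part, {})
71 |         # Parse the value the way YAML would (numbers, booleans, strings).
72 |         node[leaf] = yaml.safe_load(value)
73 | 
74 |     config = {'trainer': {'gpus': -1}, 'system': {'batch_size': 8}}
75 |     apply_override(config, 'trainer.gpus=0')
76 |     assert config['trainer']['gpus'] == 0
77 | 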
--------------------------------------------------------------------------------
/docs/configuration/predict.rst:
--------------------------------------------------------------------------------
1 | Prediction configuration
2 | ========================
3 |
4 | An example configuration file is provided in ``config/predict.yaml``:
5 |
6 | .. literalinclude:: ../../config/predict.yaml
7 | :language: yaml
8 |
9 |
10 | Configuration class
11 | -------------------
12 |
13 | Full API reference: :class:`kiwi.lib.predict.Configuration`
14 |
15 | .. autoclass:: kiwi.lib.predict.Configuration
16 | :noindex:
17 |
--------------------------------------------------------------------------------
/docs/configuration/search.rst:
--------------------------------------------------------------------------------
1 | Search configuration
2 | ======================
3 |
4 | There is an example search configuration file available in ``config/`` for the ``BERT`` model:
5 |
6 | .. literalinclude:: ../../config/search.yaml
7 | :language: yaml
8 |
9 |
10 | Configuration class
11 | -------------------
12 |
13 | Full API reference: :class:`kiwi.lib.search.Configuration`
14 |
15 |
16 | .. autosummary:
17 | :toctree: stubs
18 |
19 | .. kiwi.lib.search.Configuration
20 | kiwi.lib.search.SearchOptions
21 | kiwi.lib.search.ClassWeightsConfig
22 | kiwi.lib.search.RangeConfig
23 |
24 |
25 | .. autoclass:: kiwi.lib.search.Configuration
26 | :noindex:
27 |
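28 | The search pipeline is built on `Optuna <https://optuna.readthedocs.io/>`_, and a
29 | ``lower``/``upper`` range in the options roughly corresponds to an Optuna
30 | suggestion. A sketch of that correspondence (not Kiwi's internal code;
31 | ``train_and_eval`` is a hypothetical helper):
32 | 
33 | .. code-block:: python
34 | 
35 |     import optuna
36 | 
37 |     def objective(trial: optuna.Trial) -> float:
38 |         # Continuous interval, log-uniform (recommended for learning rates).
39 |         learning_rate = trial.suggest_float('learning_rate', 1e-7, 1e-5, log=True)
40 |         # Discrete range with a step, like ``bottleneck_size`` in the example config.
41 |         bottleneck_size = trial.suggest_int('bottleneck_size', 100, 500, step=100)
42 |         # Train a system with these values and return its validation metric.
43 |         return train_and_eval(learning_rate, bottleneck_size)
44 | 
45 |     study = optuna.create_study(direction='maximize')
46 |     study.optimize(objective, n_trials=20)
47 | 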
--------------------------------------------------------------------------------
/docs/configuration/train.rst:
--------------------------------------------------------------------------------
1 | Training configuration
2 | ======================
3 |
4 | There are training configuration files available in ``config/`` for all supported models.
5 |
6 | As an example, here is the configuration for the ``Bert`` model:
7 |
8 | .. literalinclude:: ../../config/bert.yaml
9 | :language: yaml
10 |
11 |
12 | Configuration class
13 | -------------------
14 |
15 | Full API reference: :class:`kiwi.lib.train.Configuration`
16 |
17 |
18 | .. autosummary:
19 | :toctree: stubs
20 |
21 | .. kiwi.lib.train.Configuration
22 | kiwi.lib.train.RunConfig
23 | kiwi.lib.train.TrainerConfig
24 | kiwi.data.datasets.wmt_qe_dataset.WMTQEDataset.Config
25 | kiwi.systems.qe_system.QESystem.Config
26 |
27 |
28 | .. autoclass:: kiwi.lib.train.Configuration
29 | :noindex:
30 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | UnbabelKiwi's documentation
2 | ===========================
3 |
4 | .. image:: _static/img/openkiwi-logo-horizontal.svg
5 | :width: 400
6 | :alt: OpenKiwi by Unbabel
7 |
8 | ----
9 |
10 | .. mdinclude:: ../README.md
11 | :start-line: 3
12 |
13 | Documentation
14 | -------------
15 |
16 | .. toctree::
17 | :maxdepth: 5
18 |
19 | installation
20 | usage
21 | configuration/index
22 | systems/index
23 | autoapi/index
24 |
25 | .. code/index
26 |
27 | Indices and tables
28 | ------------------
29 |
30 | * :ref:`genindex`
31 | * :ref:`modindex`
32 | * :ref:`search`
33 |
--------------------------------------------------------------------------------
/docs/installation.rst:
--------------------------------------------------------------------------------
1 | .. _installation:
2 |
3 | Installation
4 | ============
5 |
6 | OpenKiwi can be installed in two ways, depending on how it's going to be used.
7 |
8 | As a library
9 | ------------
10 |
11 | Simply run::
12 |
13 | pip install openkiwi
14 |
15 | You can now::
16 |
17 | import kiwi
18 |
19 | inside your project or run in the command line::
20 |
21 | kiwi
22 |
23 |
24 | As a local package
25 | ------------------
26 |
27 | OpenKiwi's configuration is in ``pyproject.toml`` (as defined by PEP-518).
28 | We use `Poetry <https://python-poetry.org>`_ as the build system
29 | and the dependency manager. All dependencies are specified in that file.
30 |
31 | Install Poetry via the recommended way::
32 |
33 | curl -sSL https://raw.githubusercontent.com/sdispater/poetry/master/get-poetry.py | python
34 |
35 | It's also possible to use pip (but not recommended as it might mess up local dependencies)::
36 |
37 | pip install poetry
38 |
39 | In your virtualenv just run::
40 |
41 | poetry install
42 |
43 | to install all dependencies.
44 |
45 | That's it! Now, running::
46 |
47 | python kiwi -h
48 |
49 | or::
50 |
51 | kiwi -h
52 |
53 | should show you a help message.
54 |
55 |
56 | MLflow integration
57 | ------------------
58 |
59 | **Optionally**, to take advantage of our `MLflow <https://mlflow.org>`_ integration, install Kiwi with::
60 |
61 | pip install openkiwi[mlflow]
62 |
63 |
64 | **Or**::
65 |
66 | poetry install -E mlflow
67 |
68 |
69 | Hyperparameter search with Optuna
70 | ---------------------------------
71 |
72 | **Optionally**, to use the hyperparameter search pipeline with `Optuna <https://optuna.readthedocs.io/>`_,
73 | install Kiwi with::
74 |
75 | pip install openkiwi[search]
76 |
77 |
78 | **Or**::
79 |
80 | poetry install -E search
81 |
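82 | Checking the installation
83 | -------------------------
84 | 
85 | Once installed as a library, the two entry points re-exported by
86 | ``kiwi/__init__.py`` can be called directly. A minimal sketch (the file paths
87 | are placeholders; adjust them to your setup)::
88 | 
89 |     import kiwi
90 | 
91 |     # Train a QE system from a YAML configuration file.
92 |     kiwi.train_from_file('config/nuqe.yaml')
93 | 
94 |     # Load a trained system for prediction.
95 |     runner = kiwi.load_system('best_model.torch')
96 | 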
--------------------------------------------------------------------------------
/docs/systems/index.rst:
--------------------------------------------------------------------------------
1 | Systems
2 | =======
3 |
4 | There are two types of systems: **QE** and **TLM**. The first is the regular one used
5 | for Quality Estimation. The second, which stands for *Translation Language Model*, is
6 | used for pre-training the *encoder* component of a QE system.
7 |
8 |
9 | All systems, whether **QE** or **TLM**, are built on top
10 | of PyTorch Lightning's (PTL) `LightningModule` class. This means they follow a design
11 | philosophy that can be consulted in PTL's documentation_.
12 |
13 | .. _documentation: https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html
14 |
15 | Furthermore, all of our QE systems share a similar architecture. They are composed of
16 |
17 | - Encoder
18 | - Decoder
19 | - Output
20 | - (Optionally) TLM Output
21 |
22 | TLM systems, on the other hand, are composed only of the Encoder + TLM Output; their
23 | goal is to pre-train the encoder so it can be plugged into a QE system.
24 |
25 | The systems are divided into the three blocks that carry the major responsibilities in both
26 | word- and sentence-level QE tasks:
27 |
28 | Encoder: embeds the input and creates features for downstream tasks, e.g. Predictor,
29 | BERT, etc.
30 | 
31 | Decoder: responsible for learning feature transformations better suited to the
32 | downstream task, e.g. MLP, LSTM, etc.
33 | 
34 | Output: simple feedforward layers that take decoder features and transform them into the
35 | predictions required by the downstream task, along the same lines as the
36 | "classification heads" commonly used with transformers.
37 | 
38 | TLM Output: a simple output layer that trains for the specific TLM objective. It can be
39 | useful to continue fine-tuning the predictor during training of the complete QE system.
40 |
41 |
42 | QE --- :mod:`kiwi.systems.qe_system`
43 | ------------------------------------
44 |
45 | All QE systems inherit from :class:`kiwi.systems.qe_system.QESystem`.
46 |
47 | Use ``kiwi train`` to train these systems.
48 |
49 | Currently available are:
50 |
51 | +--------------------------------------------------------------+
52 | | :class:`kiwi.systems.nuqe.NuQE` |
53 | +--------------------------------------------------------------+
54 | | :class:`kiwi.systems.predictor_estimator.PredictorEstimator` |
55 | +--------------------------------------------------------------+
56 | | :class:`kiwi.systems.bert.Bert` |
57 | +--------------------------------------------------------------+
58 | | :class:`kiwi.systems.xlm.XLM` |
59 | +--------------------------------------------------------------+
60 | | :class:`kiwi.systems.xlmroberta.XLMRoberta` |
61 | +--------------------------------------------------------------+
62 |
63 |
64 | TLM --- :mod:`kiwi.systems.tlm_system`
65 | --------------------------------------
66 |
67 | All TLM systems inherit from :class:`kiwi.systems.tlm_system.TLMSystem`.
68 |
69 | Use ``kiwi pretrain`` to train these systems. These systems can then be used as the
70 | encoder part in QE systems by using the `load_encoder` flag.
71 |
72 | Currently available are:
73 |
74 | +-------------------------------------------+
75 | | :class:`kiwi.systems.predictor.Predictor` |
76 | +-------------------------------------------+
77 |
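78 | Schematically, the composition described above can be pictured as follows (an
79 | illustrative sketch only, not the actual class definitions, which build on
80 | PTL's ``LightningModule``)::
81 | 
82 |     import torch.nn as nn
83 | 
84 |     class SketchQESystem(nn.Module):
85 |         """Illustrates the encoder -> decoder -> outputs composition."""
86 | 
87 |         def __init__(self, encoder, decoder, word_output, sentence_output):
88 |             super().__init__()
89 |             self.encoder = encoder                  # e.g. Predictor, BERT
90 |             self.decoder = decoder                  # e.g. MLP, LSTM
91 |             self.word_output = word_output          # word-level predictions
92 |             self.sentence_output = sentence_output  # sentence-level predictions
93 | 
94 |         def forward(self, batch):
95 |             features = self.decoder(self.encoder(batch))
96 |             return {
97 |                 'word_level': self.word_output(features),
98 |                 'sentence_level': self.sentence_output(features),
99 |             }
100 | 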
--------------------------------------------------------------------------------
/kiwi/__init__.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | from kiwi.lib.train import train_from_file # NOQA isort:skip
18 | from kiwi.lib.predict import load_system # NOQA isort:skip
19 |
20 | __version__ = '2.1.0'
21 | __copyright__ = '2019-2020 Unbabel. All rights reserved.'
22 |
--------------------------------------------------------------------------------
/kiwi/__main__.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | import kiwi.cli
18 |
19 |
20 | def main():
21 | return kiwi.cli.cli()
22 |
23 |
24 | if __name__ == '__main__':
25 | main()
26 |
--------------------------------------------------------------------------------
/kiwi/assets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Unbabel/OpenKiwi/07d7cfed880457d95daf189dd8282e5b02ac2954/kiwi/assets/__init__.py
--------------------------------------------------------------------------------
/kiwi/assets/config/__init__.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 |
4 | def file_path(name):
5 |     """Return the path to the bundled example config file ``name``."""
6 |     package_directory = Path(__file__).parent
7 |     return package_directory / name
7 |
--------------------------------------------------------------------------------
/kiwi/assets/config/evaluate.yaml:
--------------------------------------------------------------------------------
1 | gold_files:
2 | source_tags: data/WMT20/sentence_level/en_de.nmt/dev.source_tags
3 | target_tags: data/WMT20/sentence_level/en_de.nmt/dev.tags
4 | sentence_scores: data/WMT20/sentence_level/en_de.nmt/dev.hter
5 |
6 | # Two configuration options:
7 | # 1. (Recommended) Pass the root folders where the predictions live,
8 | # with the standard file names
9 | predicted_dir:
10 | # The evaluation pipeline supports evaluating multiple predictions at the same time
11 | # by passing the folders as a list
12 | - runs/0/4aa891368ff4402fa69a4b081ea2ba62
13 | - runs/0/e9200ada6dc84bfea807b3b02b9c7212
14 |
15 | # 2. Configure each predicted file separately
16 | # predicted_files:
17 | # source_tags:
18 | # - runs/0/4aa891368ff4402fa69a4b081ea2ba62/source_tags
19 | # - runs/0/e9200ada6dc84bfea807b3b02b9c7212/source_tags
20 | # # (Recommended) Pass the predicted `targetgaps_tags` file as `target_tags`;
21 | # # the target and gap tags will be separated and evaluated separately as well as jointly
22 | # target_tags:
23 | # - runs/0/4aa891368ff4402fa69a4b081ea2ba62/targetgaps_tags
24 | # - runs/0/e9200ada6dc84bfea807b3b02b9c7212/targetgaps_tags
25 | # # Alternatively:
26 | # # target_tags:
27 | # # - runs/0/4aa891368ff4402fa69a4b081ea2ba62/target_tags
28 | # # - runs/0/e9200ada6dc84bfea807b3b02b9c7212/target_tags
29 | # # gap_tags:
30 | # # - runs/0/4aa891368ff4402fa69a4b081ea2ba62/gap_tags
31 | # # - runs/0/e9200ada6dc84bfea807b3b02b9c7212/gap_tags
32 | # sentence_scores:
33 | # - runs/0/4aa891368ff4402fa69a4b081ea2ba62/sentence_scores
34 | # - runs/0/e9200ada6dc84bfea807b3b02b9c7212/sentence_scores
35 |
--------------------------------------------------------------------------------
/kiwi/assets/config/predict.yaml:
--------------------------------------------------------------------------------
1 | use_gpu: true
2 |
3 | run:
4 | output_dir: predictions
5 | predict_on_data_partition: test
6 |
7 | defaults:
8 | - data: wmt20.qe.en_de
9 |
10 | system:
11 | load: best_model.torch
12 |
13 | batch_size: 16
14 | num_data_workers: 0
15 |
16 | model:
17 | outputs:
18 | word_level:
19 | target: true
20 | gaps: true
21 | source: true
22 | sentence_level:
23 | hter: true
24 |
--------------------------------------------------------------------------------
/kiwi/assets/config/pretrain.yaml:
--------------------------------------------------------------------------------
1 | run:
2 | experiment_name: Predictor WMT20 EN-DE
3 | seed: 42
4 | use_mlflow: false
5 |
6 | trainer:
7 | deterministic: true
8 | gpus: -1
9 | epochs: 10
10 |
11 | log_interval: 100
12 | checkpoint:
13 | validation_steps: 0.2
14 | early_stop_patience: 10
15 |
16 | defaults:
17 | - data: wmt20.qe.en_de
18 |
19 | system:
20 | class_name: Predictor
21 |
22 | num_data_workers: 4
23 | batch_size:
24 | train: 32
25 | valid: 32
26 |
27 | model:
28 | encoder:
29 | hidden_size: 400
30 | rnn_layers: 2
31 | embeddings:
32 | source:
33 | dim: 200
34 | target:
35 | dim: 200
36 | out_embeddings_dim: 200
37 | share_embeddings: false
38 | dropout: 0.5
39 | use_mismatch_features: false
40 |
41 | optimizer:
42 | class_name: adam
43 | learning_rate: 0.001
44 | learning_rate_decay: 0.6
45 | learning_rate_decay_start: 2
46 |
47 | data_processing:
48 | vocab:
49 | min_frequency: 1
50 | max_size: 60_000
51 |
--------------------------------------------------------------------------------
/kiwi/assets/config/search.yaml:
--------------------------------------------------------------------------------
1 | base_config: config/bert.yaml
2 |
3 | options:
4 | search_method: random
5 | # Search the model architecture
6 | # You can specify a list of values...
7 | hidden_size:
8 | - 768
9 | - 324
10 | # ...or you can specify a discrete range...
11 | bottleneck_size:
12 | lower: 100
13 | upper: 500
14 | step: 100
15 | search_mlp: true
16 | # Search optimizer
17 | # ...or you can specify a continuous interval.
18 | learning_rate:
19 | lower: 1e-7
20 | upper: 1e-5
21 | distribution: loguniform # recommended for the learning rate
22 | # Search weights for the tag loss
23 | class_weights:
24 | target_tags:
25 | lower: 1
26 | upper: 10
27 | step: 1
28 | gap_tags:
29 | lower: 1
30 | upper: 20
31 | step: 1
32 | source_tags: null
33 | # Search the sentence level objective
34 | search_hter: true
35 | sentence_loss_weight:
36 | lower: 1
37 | upper: 10
38 |
--------------------------------------------------------------------------------
/kiwi/assets/config/train.yaml:
--------------------------------------------------------------------------------
1 | run:
2 | ###################################################
3 | # Generic configuration options related to
4 | # handling of experiments
5 | experiment_name: BERT WMT20 EN-DE
6 | seed: 42
7 | use_mlflow: false
8 |
9 | trainer:
10 | ###################################################
11 | # Generic options related to the training process
12 | # that apply to all models
13 | deterministic: true
14 | gpus: -1
15 | epochs: 10
16 |
17 | main_metric:
18 | - WMT19_MCC
19 | - PEARSON
20 |
21 | gradient_max_norm: 1.
22 | gradient_accumulation_steps: 1
23 |
24 | # Control the model precision, see
25 | # https://pytorch-lightning.readthedocs.io/en/stable/amp.html
26 | # for more info on the configuration options
27 | # Fast mixed precision by default:
28 | amp_level: O2
29 | precision: 16
30 |
31 | log_interval: 100
32 | checkpoint:
33 | validation_steps: 0.2
34 | early_stop_patience: 10
35 |
36 |
37 | defaults:
38 | ###################################################
39 | # Example of composition of configuration files
40 | # this config is sourced from /config/data/wmt20.qe.en_de.yaml
41 | - data: wmt20.qe.en_de
42 |
43 | system:
44 | ###################################################
45 | # System configs are responsible for all the
46 | # system-specific configurations, from model settings
47 | # to optimizers and specific processing options.
48 |
49 | # All configs must have either `class_name` or `load`
50 | class_name: Bert
51 |
52 | batch_size: 8
53 | num_data_workers: 4
54 |
55 | model:
56 | ################################################
57 | # Modeling options. These can change a lot about
58 | # the architecture of the system, with many configuration
59 | # options adding (or removing) layers.
60 | encoder:
61 | model_name: bert-base-multilingual-cased
62 | use_mlp: false
63 | freeze: false
64 |
65 | decoder:
66 | hidden_size: 768
67 | bottleneck_size: 768
68 | dropout: 0.1
69 |
70 | outputs:
71 | ####################################################
72 | # Output options configure the downstream tasks the
73 | # model will be trained on by adding specific layers
74 | # responsible for transforming decoder features into
75 | # predictions.
76 | word_level:
77 | target: true
78 | gaps: true
79 | source: true
80 | class_weights:
81 | target_tags:
82 | BAD: 3.0
83 | gap_tags:
84 | BAD: 5.0
85 | source_tags:
86 | BAD: 3.0
87 | sentence_level:
88 | hter: true
89 | use_distribution: false
90 | binary: false
91 | n_layers_output: 2
92 | sentence_loss_weight: 1
93 |
94 | tlm_outputs:
95 | fine_tune: false
96 |
97 | optimizer:
98 | class_name: adamw
99 | learning_rate: 1e-5
100 | warmup_steps: 0.1
101 | training_steps: 12000
102 |
103 | data_processing:
104 | share_input_fields_encoders: true
105 |
--------------------------------------------------------------------------------
/kiwi/cli.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | """
18 | Kiwi runner
19 | ~~~~~~~~~~~
20 |
21 | Quality Estimation toolkit.
22 |
23 | Invoke as ``kiwi PIPELINE``.
24 |
25 | Usage:
26 | kiwi [options] (train|pretrain|predict|evaluate|search) CONFIG_FILE [OVERWRITES ...]
27 | kiwi (train|pretrain|predict|evaluate|search) --example
28 | kiwi (-h | --help | --version)
29 |
30 |
31 | Pipelines:
32 | train Train a QE model
33 | pretrain Pretrain a TLM model to be used as an encoder for a QE model
34 | predict Use a pre-trained model for prediction
35 | evaluate Evaluate a model's predictions using popular metrics
36 | search Search training hyper-parameters for a QE model
37 |
38 | Disabled pipelines:
39 | jackknife Jackknife training data with model
40 |
41 | Arguments:
42 | CONFIG_FILE configuration file to use (e.g., config/nuqe.yaml)
43 | OVERWRITES key=value to overwrite values in CONFIG_FILE; use ``key.subkey``
44 | for nested keys.
45 |
46 | Options:
47 | -v --verbose log debug messages
48 | -q --quiet log only warning and error messages
49 | -h --help show this help message and exit
50 | --version show version and exit
51 | --example print an example configuration file
52 |
53 | """
54 | import sys
55 |
56 | from docopt import docopt
57 |
58 | from kiwi import __version__
59 | from kiwi.assets import config
60 | from kiwi.lib import evaluate, predict, pretrain, search, train
61 | from kiwi.lib.utils import arguments_to_configuration
62 |
63 |
64 | def handle_example(arguments, caller):
65 |
66 | if arguments.get('--example'):
67 | conf_file = f'{caller}.yaml'
68 | print(config.file_path(conf_file).read_text())
69 | print(
70 | f'# Save the above in a file called {conf_file} and then run:\n'
71 | f'# {" ".join(sys.argv[:-1])} {conf_file}'
72 | )
73 | sys.exit(0)
74 |
75 |
76 | def cli():
77 | arguments = docopt(
78 | __doc__, argv=sys.argv[1:], help=True, version=__version__, options_first=False
79 | )
80 |
81 | if arguments['train']:
82 | handle_example(arguments, 'train')
83 | config_dict = arguments_to_configuration(arguments)
84 | train.train_from_configuration(config_dict)
85 | if arguments['predict']:
86 | handle_example(arguments, 'predict')
87 | config_dict = arguments_to_configuration(arguments)
88 | predict.predict_from_configuration(config_dict)
89 | if arguments['pretrain']:
90 | handle_example(arguments, 'pretrain')
91 | config_dict = arguments_to_configuration(arguments)
92 | pretrain.pretrain_from_configuration(config_dict)
93 | if arguments['evaluate']:
94 | handle_example(arguments, 'evaluate')
95 | config_dict = arguments_to_configuration(arguments)
96 | evaluate.evaluate_from_configuration(config_dict)
97 | if arguments['search']:
98 | handle_example(arguments, 'search')
99 | config_dict = arguments_to_configuration(arguments)
100 | search.search_from_configuration(config_dict)
101 | # Meta Pipelines
102 | # if options.pipeline == 'jackknife':
103 | # jackknife.main(extra_args)
104 |
105 |
106 | if __name__ == '__main__': # pragma: no cover
107 | cli() # pragma: no cover
108 |
--------------------------------------------------------------------------------
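The five pipelines dispatched above are also importable as plain Python entry points. A minimal sketch of the programmatic equivalent of ``kiwi train config/nuqe.yaml`` follows; it assumes ``train.train_from_file`` mirrors the ``pretrain_from_file`` helper defined in kiwi/lib/pretrain.py below (command-line overwrites like ``key.subkey=value`` become nested keys of the resulting configuration dictionary):

    from kiwi.lib import train

    # Assumed helper: reads and validates the YAML file, then runs the
    # training pipeline and returns a TrainRunInfo object.
    run_info = train.train_from_file('config/nuqe.yaml')
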
/kiwi/constants.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 |
18 | # lowercased special tokens
19 | UNK = '<unk>'
20 | PAD = '<pad>'
21 | START = '<bos>'
22 | STOP = '<eos>'
23 | UNALIGNED = '<unaligned>'
24 |
25 | # binary labels
26 | OK = 'OK'
27 | BAD = 'BAD'
28 | LABELS = [OK, BAD]
29 |
30 | SOURCE = 'source'
31 | TARGET = 'target'
32 | PE = 'pe'
33 | TARGET_TAGS = 'target_tags'
34 | SOURCE_TAGS = 'source_tags'
35 | GAP_TAGS = 'gap_tags'
36 | TARGETGAPS_TAGS = 'targetgaps_tags'
37 |
38 | SOURCE_LOGITS = 'source_logits'
39 | TARGET_LOGITS = 'target_logits'
40 | TARGET_SENTENCE = 'target_sentence'
41 | PE_LOGITS = 'pe_logits'
42 |
43 | SENTENCE_SCORES = 'sentence_scores'
44 | BINARY = 'binary'
45 |
46 | ALIGNMENTS = 'alignments'
47 | SOURCE_POS = 'source_pos'
48 | TARGET_POS = 'target_pos'
49 |
50 | # Constants for model output names
51 | LOSS = 'loss'
52 |
53 | # Standard Names for saving files
54 | VOCAB = 'vocab'
55 | CONFIG = 'config'
56 | STATE_DICT = 'state_dict'
57 |
--------------------------------------------------------------------------------
/kiwi/data/__init__.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 |
--------------------------------------------------------------------------------
/kiwi/data/batch.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | from dataclasses import dataclass
18 |
19 | import torch
20 | import torchnlp.utils
21 | from torchnlp.encoders.text import BatchedSequences
22 |
23 |
24 | @dataclass
25 | class BatchedSentence:
26 | tensor: torch.Tensor
27 | lengths: torch.Tensor
28 | bounds: torch.Tensor
29 | bounds_lengths: torch.Tensor
30 | strict_masks: torch.Tensor
31 | number_of_tokens: torch.Tensor
32 |
33 | def pin_memory(self):
34 | self.tensor = self.tensor.pin_memory()
35 | self.lengths = self.lengths.pin_memory()
36 | self.bounds = self.bounds.pin_memory()
37 | self.bounds_lengths = self.bounds_lengths.pin_memory()
38 | self.strict_masks = self.strict_masks.pin_memory()
39 | self.number_of_tokens = self.number_of_tokens.pin_memory()
40 | return self
41 |
42 | def to(self, *args, **kwargs):
43 | self.tensor = self.tensor.to(*args, **kwargs)
44 | self.lengths = self.lengths.to(*args, **kwargs)
45 | self.bounds = self.bounds.to(*args, **kwargs)
46 | self.bounds_lengths = self.bounds_lengths.to(*args, **kwargs)
47 | self.strict_masks = self.strict_masks.to(*args, **kwargs)
48 | self.number_of_tokens = self.number_of_tokens.to(*args, **kwargs)
49 | return self
50 |
51 |
52 | class MultiFieldBatch(dict):
53 | def __init__(self, batch: dict):
54 | super().__init__()
55 | self.update(batch)
56 |
57 | def pin_memory(self):
58 | for field, data in self.items():
59 | if isinstance(data, BatchedSequences):
60 | tensor = data.tensor.pin_memory()
61 | lengths = data.lengths.pin_memory()
62 | self[field] = BatchedSequences(tensor=tensor, lengths=lengths)
63 | else:
64 | self[field] = data.pin_memory()
65 | return self
66 |
67 | def to(self, *args, **kwargs):
68 | for field, data in self.items():
69 | self[field] = data.to(*args, **kwargs)
70 | return self
71 |
72 |
73 | def tensors_to(tensors, *args, **kwargs):
74 | if isinstance(tensors, (MultiFieldBatch, BatchedSentence)):
75 | return tensors.to(*args, **kwargs)
76 | elif isinstance(tensors, dict):
77 | return {k: tensors_to(v, *args, **kwargs) for k, v in tensors.items()}
78 | else:
79 | return torchnlp.utils.tensors_to(tensors, *args, **kwargs)
80 |
--------------------------------------------------------------------------------
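A minimal sketch of how these containers compose; the field name and shapes below are illustrative assumptions, not fixed by the classes:

    import torch

    from kiwi.data.batch import BatchedSentence, MultiFieldBatch, tensors_to

    sentence = BatchedSentence(
        tensor=torch.randint(100, (2, 5)),  # token ids: (batch, seq_len)
        lengths=torch.tensor([5, 3]),
        bounds=torch.zeros(2, 5, dtype=torch.long),
        bounds_lengths=torch.tensor([5, 3]),
        strict_masks=torch.ones(2, 5, dtype=torch.bool),
        number_of_tokens=torch.tensor([5, 3]),
    )
    batch = MultiFieldBatch({'target': sentence})
    batch = tensors_to(batch, torch.device('cpu'))  # recursively moves every field
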
/kiwi/data/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 |
--------------------------------------------------------------------------------
/kiwi/data/encoders/__init__.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 |
--------------------------------------------------------------------------------
/kiwi/data/encoders/base.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | class DataEncoders:
18 | def __init__(self):
19 | pass
20 |
21 | @property
22 | def vocabularies(self):
23 | raise NotImplementedError
24 |
--------------------------------------------------------------------------------
/kiwi/data/tokenizers.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | def tokenize(sentence):
18 | """Implement your own tokenize procedure."""
19 | return sentence.strip().split()
20 |
21 |
22 | def detokenize(tokens):
23 | return ' '.join(tokens)
24 |
25 |
26 | def align_tokenize(s):
27 |     """Return a list of integer pairs, one per alignment link in the sentence."""
28 | return [tuple(map(int, x.split('-'))) for x in s.strip().split()]
29 |
30 |
31 | def bert_tokenizer(sentence):
32 | raise NotImplementedError
33 |
--------------------------------------------------------------------------------
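Worked examples for the helpers above; ``align_tokenize`` parses the usual ``i-j`` word-alignment format, one pair per alignment link:

    from kiwi.data.tokenizers import align_tokenize, detokenize, tokenize

    tokens = tokenize('  the quick fox ')  # ['the', 'quick', 'fox']
    detokenize(tokens)                     # 'the quick fox'
    align_tokenize('0-0 1-2 2-1')          # [(0, 0), (1, 2), (2, 1)]
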
/kiwi/lib/__init__.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 |
--------------------------------------------------------------------------------
/kiwi/lib/pretrain.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | import logging
18 |
19 | from kiwi.lib.train import Configuration as TrainConfig
20 | from kiwi.lib.train import TrainRunInfo, run
21 | from kiwi.lib.utils import file_to_configuration
22 | from kiwi.systems.tlm_system import TLMSystem
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 |
27 | class Configuration(TrainConfig):
28 | system: TLMSystem.Config
29 |
30 |
31 | def pretrain_from_file(filename) -> TrainRunInfo:
32 | """Load options from a config file and call the pretraining procedure.
33 |
34 | Arguments:
35 |         filename: path to the configuration file.
36 |
37 | Return:
38 | object with training information.
39 | """
40 | config = file_to_configuration(filename)
41 | return pretrain_from_configuration(config)
42 |
43 |
44 | def pretrain_from_configuration(configuration_dict) -> TrainRunInfo:
45 |     """Run the entire pretraining pipeline using the configuration options received.
46 |
47 | Arguments:
48 | configuration_dict: dictionary with config options.
49 |
50 | Return:
51 | object with training information.
52 | """
53 | config = Configuration(**configuration_dict)
54 | train_info = run(config, TLMSystem)
55 | return train_info
56 |
--------------------------------------------------------------------------------
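Usage of the entry point above is a single call; the filename below is a placeholder:

    from kiwi.lib.pretrain import pretrain_from_file

    # Parses and validates the YAML against Configuration (which narrows the
    # system to TLMSystem.Config), then delegates to kiwi.lib.train.run.
    run_info = pretrain_from_file('pretrain.yaml')
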
/kiwi/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | from kiwi.metrics import metrics
18 |
19 | F1MultMetric = metrics.F1MultMetric
20 | ExpectedErrorMetric = metrics.ExpectedErrorMetric
21 | PerplexityMetric = metrics.PerplexityMetric
22 | CorrectMetric = metrics.CorrectMetric
23 | RMSEMetric = metrics.RMSEMetric
24 | PearsonMetric = metrics.PearsonMetric
25 | SpearmanMetric = metrics.SpearmanMetric
26 | MatthewsMetric = metrics.MatthewsMetric
27 | # ThresholdCalibrationMetric = metrics.ThresholdCalibrationMetric
28 |
--------------------------------------------------------------------------------
/kiwi/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 |
--------------------------------------------------------------------------------
/kiwi/modules/common/__init__.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 |
--------------------------------------------------------------------------------
/kiwi/modules/common/activations.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | import math
18 |
19 | import torch
20 |
21 |
22 | def gelu(x):
23 | """gelu activation function copied from pytorch-pretrained-BERT."""
24 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
25 |
26 |
27 | def swish(x):
28 |     """swish activation function copied from pytorch-pretrained-BERT."""
29 | return x * torch.sigmoid(x)
30 |
--------------------------------------------------------------------------------
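A quick sanity check against the built-ins; on reasonably recent PyTorch, ``F.gelu`` defaults to the same erf formulation and ``F.silu`` is exactly swish:

    import torch
    from torch.nn import functional as F

    from kiwi.modules.common.activations import gelu, swish

    x = torch.linspace(-3.0, 3.0, steps=7)
    assert torch.allclose(gelu(x), F.gelu(x), atol=1e-6)
    assert torch.allclose(swish(x), F.silu(x), atol=1e-6)
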
/kiwi/modules/common/attention.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | import torch
18 | from torch import nn
19 | from torch.nn import functional as F
20 |
21 | from kiwi.utils.tensors import unsqueeze_as
22 |
23 |
24 | class Attention(nn.Module):
25 | """Generic Attention Implementation.
26 |
27 | 1. Use `query` and `keys` to compute scores (energies)
28 | 2. Apply softmax to get attention probabilities
29 |     3. Perform a dot product between `values` and probabilities (outputs)
30 |
31 | Arguments:
32 | scorer (kiwi.modules.common.Scorer): a scorer object
33 | dropout (float): dropout rate after softmax (default: 0.)
34 | """
35 |
36 | def __init__(self, scorer, dropout=0):
37 | super().__init__()
38 | self.scorer = scorer
39 | self.dropout = nn.Dropout(p=dropout)
40 | self.NEG_INF = -1e9 # for masking attention scores before softmax
41 |
42 | def forward(self, query, keys, values=None, mask=None):
43 | """Compute the attention between query, keys and values.
44 |
45 | Arguments:
46 | query (torch.Tensor): set of query vectors with shape of
47 | (batch_size, ..., target_len, hidden_size)
48 | keys (torch.Tensor): set of keys vectors with shape of
49 | (batch_size, ..., source_len, hidden_size)
50 | values (torch.Tensor, optional): set of values vectors with
51 | shape of: (batch_size, ..., source_len, hidden_size).
52 | If None, keys are treated as values. Default: None
53 |             mask (torch.ByteTensor, optional): tensor marking valid key
54 |                 positions. If None, all positions are considered valid.
55 |                 Shape of (batch_size, source_len)
56 |
57 | Return:
58 | torch.Tensor: combination of values and attention probabilities.
59 | Shape of (batch_size, ..., target_len, hidden_size)
60 | torch.Tensor: attention probabilities between query and keys.
61 | Shape of (batch_size, ..., target_len, source_len)
62 | """
63 | if values is None:
64 | values = keys
65 |
66 | # get scores (aka energies)
67 | scores = self.scorer(query, keys)
68 |
69 |         # mask out scores with -inf before softmax
70 | if mask is not None:
71 |             # broadcast along the keys' timestep dim as many times as needed
72 | mask = unsqueeze_as(mask, scores, dim=-2)
73 | scores = scores.masked_fill(mask == 0, self.NEG_INF)
74 |
75 | # apply softmax to get probs
76 | p_attn = F.softmax(scores, dim=-1)
77 |
78 | # apply dropout - used in Transformer (default: 0)
79 | p_attn = self.dropout(p_attn)
80 |
81 | # dot product between p_attn and values
82 | # o_attn = torch.matmul(p_attn, values)
83 | o_attn = torch.einsum('b...ts,b...sm->b...tm', [p_attn, values])
84 | return o_attn, p_attn
85 |
--------------------------------------------------------------------------------
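A minimal end-to-end sketch pairing this module with the MLPScorer from kiwi/modules/common/scorer.py further below; all sizes are arbitrary:

    import torch

    from kiwi.modules.common.attention import Attention
    from kiwi.modules.common.scorer import MLPScorer

    batch_size, source_len, target_len, hidden = 2, 6, 4, 8
    query = torch.randn(batch_size, target_len, hidden)
    keys = torch.randn(batch_size, source_len, hidden)

    mask = torch.ones(batch_size, source_len, dtype=torch.long)
    mask[1, 4:] = 0  # pretend the second source sentence has only 4 valid tokens

    attention = Attention(MLPScorer(query_size=hidden, key_size=hidden))
    outputs, probabilities = attention(query, keys, mask=mask)
    # outputs:       (2, 4, 8), a weighted combination of the keys/values
    # probabilities: (2, 4, 6), each row sums to 1 over source positions
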
/kiwi/modules/common/distributions.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | from torch.distributions import (
18 | Normal,
19 | TransformedDistribution,
20 | constraints,
21 | identity_transform,
22 | )
23 |
24 |
25 | class TruncatedNormal(TransformedDistribution):
26 | arg_constraints = {
27 | 'loc': constraints.real,
28 | 'scale': constraints.positive,
29 | 'lower_bound': constraints.real,
30 | 'upper_bound': constraints.real,
31 | }
32 | support = constraints.real
33 | has_rsample = True
34 |
35 | def __init__(
36 | self, loc, scale, lower_bound=0.0, upper_bound=1.0, validate_args=None
37 | ):
38 | base_dist = Normal(loc, scale)
39 | super(TruncatedNormal, self).__init__(
40 | base_dist, identity_transform, validate_args=validate_args
41 | )
42 | self.lower_bound = lower_bound
43 | self.upper_bound = upper_bound
44 |
45 | def partition_function(self):
46 | return (
47 | self.base_dist.cdf(self.upper_bound) - self.base_dist.cdf(self.lower_bound)
48 | ).detach() + 1e-12
49 |
50 | @property
51 | def scale(self):
52 | return self.base_dist.scale
53 |
54 | @property
55 | def mean(self):
56 | r"""
57 | :math:`pdf = f(x; \mu, \sigma, a, b) = \frac{\phi(\xi)}{\sigma Z}`
58 |
59 | :math:`\xi=\frac{x-\mu}{\sigma}`
60 |
61 | :math:`\alpha=\frac{a-\mu}{\sigma}`
62 |
63 | :math:`\beta=\frac{b-\mu}{\sigma}`
64 |
65 | :math:`Z=\Phi(\beta)-\Phi(\alpha)`
66 |
67 | Return:
68 | :math:`\mu + \frac{\phi(\alpha)-\phi(\beta)}{Z}\sigma`
69 |
70 | """
71 | mean = self.base_dist.mean + (
72 | (
73 | self.base_dist.scale ** 2
74 | * (
75 | self.base_dist.log_prob(self.lower_bound).exp()
76 | - self.base_dist.log_prob(self.upper_bound).exp()
77 | )
78 | )
79 | / self.partition_function()
80 | )
81 | return mean
82 |
83 | @property
84 | def variance(self):
85 | pdf_a = self.base_dist.log_prob(self.lower_bound).exp()
86 | pdf_b = self.base_dist.log_prob(self.upper_bound).exp()
87 | alpha = (
88 | self.lower_bound - self.base_dist.mean
89 | ) * self.base_dist.scale.reciprocal()
90 | beta = (
91 | self.upper_bound - self.base_dist.mean
92 | ) * self.base_dist.scale.reciprocal()
93 | z = self.partition_function()
94 | term1 = (alpha * pdf_a - beta * pdf_b) / z
95 | term2 = (pdf_a - pdf_b) / z
96 | return self.base_dist.scale ** 2 * (1 + term1 - term2 ** 2)
97 |
98 | def log_prob(self, value):
99 | log_value = self.base_dist.log_prob(value)
100 | log_prob = log_value - self.partition_function().log()
101 | return log_prob
102 |
103 | def cdf(self, value):
104 | if value <= self.lower_bound:
105 | return 0.0
106 | if value >= self.upper_bound:
107 | return 1.0
108 | unnormalized_cdf = self.base_dist.cdf(value) - self.base_dist.cdf(
109 | self.lower_bound
110 | )
111 | return unnormalized_cdf / self.partition_function()
112 |
113 | def icdf(self, value):
114 | if self._validate_args:
115 | self._validate_sample(value)
116 | return self.base_dist.icdf(
117 |             self.base_dist.cdf(self.lower_bound) + value * self.partition_function()
118 | )
119 |
--------------------------------------------------------------------------------
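A small usage sketch; this distribution backs SentenceScoreDistribution in kiwi/modules/sentence_level_output.py, where sentence scores live in [0, 1]:

    import torch

    from kiwi.modules.common.distributions import TruncatedNormal

    # A Normal(0.8, 0.5) truncated to the default bounds [0, 1]
    dist = TruncatedNormal(torch.tensor(0.8), torch.tensor(0.5))

    dist.mean                         # pulled from 0.8 towards the interval centre
    dist.log_prob(torch.tensor(0.9))  # base log-density renormalized by log Z
    dist.cdf(torch.tensor(2.0))       # 1.0, above the upper truncation bound
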
/kiwi/modules/common/feedforward.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | from collections import OrderedDict
18 |
19 | from torch import nn
20 |
21 |
22 | def feedforward(
23 | in_dim,
24 | n_layers,
25 | shrink=2,
26 | out_dim=None,
27 | activation=nn.Tanh,
28 | final_activation=False,
29 | dropout=0.0,
30 | ):
31 |     """Build a stack of Linear/activation/dropout blocks, dividing the width by ``shrink`` at each hidden layer and projecting to ``out_dim`` at the end."""
32 | dim = in_dim
33 | module_dict = OrderedDict()
34 | for layer_i in range(n_layers - 1):
35 | next_dim = dim // shrink
36 | module_dict['linear_{}'.format(layer_i)] = nn.Linear(dim, next_dim)
37 | module_dict['activation_{}'.format(layer_i)] = activation()
38 | module_dict['dropout_{}'.format(layer_i)] = nn.Dropout(dropout)
39 | dim = next_dim
40 | next_dim = out_dim or (dim // 2)
41 | module_dict['linear_{}'.format(n_layers - 1)] = nn.Linear(dim, next_dim)
42 | if final_activation:
43 | module_dict['activation_{}'.format(n_layers - 1)] = activation()
44 | return nn.Sequential(module_dict)
45 |
--------------------------------------------------------------------------------
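For concreteness, with the default ``shrink=2`` every hidden layer halves the width; a sketch of what one call produces:

    from torch import nn

    from kiwi.modules.common.feedforward import feedforward

    ff = feedforward(in_dim=512, n_layers=3, out_dim=1, dropout=0.1)
    # linear_0: Linear(512 -> 256) + Tanh + Dropout
    # linear_1: Linear(256 -> 128) + Tanh + Dropout
    # linear_2: Linear(128 -> 1), no activation unless final_activation=True
    assert isinstance(ff, nn.Sequential)
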
/kiwi/modules/common/layer_norm.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | import torch
18 | from torch import nn
19 |
20 |
21 | class TFLayerNorm(nn.Module):
22 | """Construct a layer normalization module with epsilon inside the
23 | square root (tensorflow style).
24 |
25 | This is equivalent to HuggingFace's BertLayerNorm module.
26 | """
27 |
28 | def __init__(self, hidden_size, eps=1e-6):
29 | super().__init__()
30 | self.gamma = nn.Parameter(torch.ones(hidden_size))
31 | self.beta = nn.Parameter(torch.zeros(hidden_size))
32 | self.eps = eps
33 |
34 | def forward(self, x):
35 | u = x.mean(-1, keepdim=True)
36 | s = (x - u).pow(2).mean(-1, keepdim=True)
37 | x = (x - u) / torch.sqrt(s + self.eps)
38 | return self.gamma * x + self.beta
39 |
--------------------------------------------------------------------------------
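PyTorch's own nn.LayerNorm also keeps epsilon inside the square root, so at initialization (gamma=1, beta=0) the two modules agree; a quick check, assuming a recent PyTorch:

    import torch
    from torch import nn

    from kiwi.modules.common.layer_norm import TFLayerNorm

    x = torch.randn(2, 4, 8)
    assert torch.allclose(
        TFLayerNorm(8, eps=1e-6)(x), nn.LayerNorm(8, eps=1e-6)(x), atol=1e-5
    )
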
/kiwi/modules/common/positional_encoding.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | import math
18 |
19 | import torch
20 | from torch import nn
21 |
22 |
23 | class PositionalEncoding(nn.Module):
24 | """Absolute positional encoding mechanism.
25 |
26 | Arguments:
27 | max_seq_len: hypothetical maximum sequence length (usually 1000).
28 | hidden_size: embeddings size.
29 | """
30 |
31 | def __init__(self, max_seq_len: int, hidden_size: int):
32 | super().__init__()
33 |
34 | position = torch.arange(0.0, max_seq_len).unsqueeze(1)
35 | neg_log_term = -math.log(10000.0) / hidden_size
36 | div_term = torch.exp(torch.arange(0.0, hidden_size, 2) * neg_log_term)
37 |
38 | pe = torch.zeros(max_seq_len, hidden_size, requires_grad=False)
39 | pe[:, 0::2] = torch.sin(position * div_term)
40 |
41 | # handle cases when hidden size is odd (cos will have one less than sin)
42 | pe_cos = torch.cos(position * div_term)
43 | if hidden_size % 2 == 1:
44 | pe_cos = pe_cos[:, :-1]
45 | pe[:, 1::2] = pe_cos
46 |
47 | pe = pe.unsqueeze(0) # add batch dimension
48 | self.register_buffer('pe', pe)
49 | self.hidden_size = hidden_size
50 |
51 | def forward(self, emb):
52 | # self.pe = self.pe.to(emb.device)
53 | assert emb.shape[1] <= self.pe.shape[1]
54 | return emb + self.pe[:, : emb.shape[1]]
55 |
56 |
57 | if __name__ == '__main__':
58 | from matplotlib import pyplot as plt
59 | import numpy as np
60 |
61 | batch_size = 8
62 | vocab_size = 1000
63 | emb_size = 20
64 | seq_len = 100
65 | max_seq_len = 5000
66 | d_i, d_j = 4, 10
67 |
68 | x_emb = torch.randint(vocab_size, size=(batch_size, seq_len)).long()
69 | x_rand = torch.randn(batch_size, seq_len, emb_size)
70 | x_zero = torch.zeros(batch_size, seq_len, emb_size)
71 |
72 | embed = nn.Embedding(vocab_size, emb_size)
73 | torch.nn.init.xavier_normal_(embed.weight)
74 | pe = PositionalEncoding(max_seq_len, emb_size)
75 |
76 | x_rand = pe(x_rand)
77 | x_emb = pe(embed(x_emb)).data
78 | x_zero = pe(x_zero)
79 |
80 | plt.figure(figsize=(15, 5))
81 | plt.title('Random input')
82 | plt.plot(np.arange(seq_len), x_rand[0, :, d_i:d_j].numpy())
83 | plt.legend(['dim %d' % d for d in range(d_i, d_j)])
84 | plt.show()
85 |
86 | plt.figure(figsize=(15, 5))
87 | plt.title('Embedding input')
88 | plt.plot(np.arange(seq_len), x_emb[0, :, d_i:d_j].numpy())
89 | plt.legend(['dim %d' % d for d in range(d_i, d_j)])
90 | plt.show()
91 |
92 | plt.figure(figsize=(15, 5))
93 | plt.title('Zero input')
94 | plt.plot(np.arange(seq_len), x_zero[0, :, d_i:d_j].numpy())
95 | plt.legend(['dim %d' % d for d in range(d_i, d_j)])
96 | plt.show()
97 |
--------------------------------------------------------------------------------
/kiwi/modules/common/scorer.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | from math import sqrt
18 |
19 | import torch
20 | from torch import nn
21 |
22 | from kiwi.utils.tensors import make_mergeable_tensors
23 |
24 |
25 | class Scorer(nn.Module):
26 | """Score function for attention module.
27 |
28 | Arguments:
29 | scaled: whether to scale scores by `sqrt(hidden_size)` as proposed by the
30 | "Attention is All You Need" paper.
31 | """
32 |
33 | def __init__(self, scaled: bool = True):
34 | super().__init__()
35 | self.scaled = scaled
36 |
37 | def scale(self, hidden_size: int) -> float:
38 | """Denominator for scaling the scores.
39 |
40 | Arguments:
41 | hidden_size: max hidden size between query and keys.
42 |
43 | Return:
44 | sqrt(hidden_size) if `scaled` is True, 1 otherwise.
45 | """
46 | if self.scaled:
47 | return sqrt(hidden_size)
48 | return 1
49 |
50 | def forward(
51 | self, query: torch.FloatTensor, keys: torch.FloatTensor
52 | ) -> torch.FloatTensor:
53 | """Compute scores for each key of size n given the queries of size m.
54 |
55 | The three dots (...) represent any other dimensions, such as the
56 | number of heads (useful if you use a multi head attention).
57 |
58 | Arguments:
59 | query: query matrix ``(bs, ..., target_len, m)``.
60 | keys: keys matrix ``(bs, ..., source_len, n)``.
61 |
62 | Return:
63 | matrix representing scores between source words and target words
64 | ``(bs, ..., target_len, source_len)``
65 | """
66 | raise NotImplementedError
67 |
68 |
69 | class MLPScorer(Scorer):
70 |     """Multi-layer perceptron scorer with a variable number of layers and neurons."""
71 |
72 | def __init__(
73 | self, query_size, key_size, layer_sizes=None, activation=nn.Tanh, **kwargs
74 | ):
75 | super().__init__(**kwargs)
76 | if layer_sizes is None:
77 | layer_sizes = [(query_size + key_size) // 2]
78 | input_size = query_size + key_size # concatenate query and keys
79 | output_size = 1 # produce a scalar for each alignment
80 | layer_sizes = [input_size] + layer_sizes + [output_size]
81 | sizes = zip(layer_sizes[:-1], layer_sizes[1:])
82 | layers = []
83 | for n_in, n_out in sizes:
84 | layers.append(nn.Sequential(nn.Linear(n_in, n_out), activation()))
85 | self.layers = nn.ModuleList(layers)
86 |
87 | def forward(self, query, keys):
88 | x_query, x_keys = make_mergeable_tensors(query, keys)
89 | x = torch.cat((x_query, x_keys), dim=-1)
90 | for layer in self.layers:
91 | x = layer(x)
92 | return x.squeeze(-1) # remove last dimension
93 |
--------------------------------------------------------------------------------
/kiwi/modules/sentence_level_output.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | from torch import nn as nn
18 |
19 | from kiwi.modules.common.distributions import TruncatedNormal
20 | from kiwi.modules.common.feedforward import feedforward
21 | from kiwi.utils.tensors import sequence_mask
22 |
23 |
24 | class SentenceFromLogits(nn.Module):
25 | def __init__(self):
26 | super().__init__()
27 | self.linear_out = nn.Linear(2, 1)
28 | nn.init.xavier_uniform_(self.linear_out.weight)
29 | nn.init.constant_(self.linear_out.bias, 0.0)
30 |
31 | self._loss_fn = nn.BCEWithLogitsLoss(reduction='sum')
32 |
33 | def forward(self, inputs, lengths):
34 | mask = sequence_mask(lengths, max_len=inputs.size(1)).float()
35 | average = (inputs * mask[..., None]).sum(1) / lengths[:, None].float()
36 | h = self.linear_out(average)
37 | return h
38 |
39 | def loss_fn(self, predicted, target):
40 | target = target[..., None]
41 | return self._loss_fn(predicted, target)
42 |
43 |
44 | class SentenceScoreRegression(nn.Module):
45 | def __init__(
46 | self,
47 | input_size,
48 | dropout=0.0,
49 | activation=nn.Tanh,
50 | final_activation=False,
51 | num_layers=3,
52 | ):
53 | super().__init__()
54 |
55 | self.sentence_pred = feedforward(
56 | input_size,
57 | n_layers=num_layers,
58 | out_dim=1,
59 | dropout=dropout,
60 | activation=activation,
61 | final_activation=final_activation,
62 | )
63 |
64 | self.loss_fn = nn.MSELoss(reduction='sum')
65 |
66 | for p in self.parameters():
67 | if len(p.shape) > 1:
68 | nn.init.xavier_uniform_(p)
69 |
70 | def forward(self, features, batch_inputs):
71 | sentence_scores = self.sentence_pred(features).squeeze()
72 | return sentence_scores
73 |
74 |
75 | class SentenceScoreDistribution(nn.Module):
76 | def __init__(self, input_size):
77 | super().__init__()
78 |
79 | self.sentence_pred = feedforward(input_size, n_layers=3, out_dim=1)
80 | # Predict truncated Gaussian distribution
81 | self.sentence_sigma = feedforward(
82 | input_size,
83 | n_layers=3,
84 | out_dim=1,
85 | activation=nn.Sigmoid,
86 | final_activation=True,
87 | )
88 |
89 | self.loss_fn = self._loss_fn
90 |
91 | for p in self.parameters():
92 | if len(p.shape) > 1:
93 | nn.init.xavier_uniform_(p)
94 |
95 | @staticmethod
96 | def _loss_fn(predicted, target):
97 | _, (mu, sigma) = predicted
98 | # Compute log-likelihood of x given mu, sigma
99 | dist = TruncatedNormal(mu, sigma + 1e-12, 0.0, 1.0)
100 | nll = -dist.log_prob(target)
101 | return nll.sum()
102 |
103 | def forward(self, features, batch_inputs):
104 | mu = self.sentence_pred(features).squeeze()
105 | sigma = self.sentence_sigma(features).squeeze()
106 | # Compute the mean of the truncated Gaussian as sentence score prediction
107 | dist = TruncatedNormal(
108 | mu.clone().detach(), sigma.clone().detach() + 1e-12, 0.0, 1.0
109 | )
110 | sentence_scores = dist.mean
111 | return sentence_scores.squeeze(), (mu, sigma)
112 |
113 |
114 | class BinarySentenceScore(nn.Module):
115 | def __init__(self, input_size):
116 | super().__init__()
117 | self.sentence_pred = feedforward(
118 | input_size, n_layers=3, out_dim=2, activation=nn.Tanh
119 | )
120 |
121 | self.loss_fn = nn.CrossEntropyLoss(reduction='sum')
122 |
123 | for p in self.parameters():
124 | if len(p.shape) > 1:
125 | nn.init.xavier_uniform_(p)
126 |
127 | def forward(self, features, batch_inputs):
128 | return self.sentence_pred(features).squeeze()
129 |
--------------------------------------------------------------------------------
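A minimal sketch of the regression head in isolation; in the real systems the pooled features come from the decoder:

    import torch

    from kiwi.modules.sentence_level_output import SentenceScoreRegression

    head = SentenceScoreRegression(input_size=128)
    features = torch.randn(4, 128)              # one pooled vector per sentence
    scores = head(features, batch_inputs=None)  # shape: (4,)
    loss = head.loss_fn(scores, torch.rand(4))  # summed MSE
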
/kiwi/modules/token_embeddings.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | import math
18 |
19 | from torch import nn
20 |
21 | from kiwi.data.batch import BatchedSentence
22 | from kiwi.modules.common.layer_norm import TFLayerNorm
23 | from kiwi.modules.common.positional_encoding import PositionalEncoding
24 | from kiwi.utils.io import BaseConfig
25 |
26 |
27 | class TokenEmbeddings(nn.Module):
28 | class Config(BaseConfig):
29 | dim: int = 50
30 | freeze: bool = False
31 | dropout: float = 0.0
32 | use_position_embeddings: bool = False
33 | max_position_embeddings: int = 4000
34 | sparse_embeddings: bool = False
35 | scale_embeddings: bool = False
36 | input_layer_norm: bool = False
37 |
38 | def __init__(self, num_embeddings: int, pad_idx: int, config: Config, vectors=None):
39 | """A model for embedding a single type of tokens."""
40 | super().__init__()
41 | self.pad_idx = pad_idx
42 |
43 | if vectors is not None:
44 | assert num_embeddings == vectors.size(0)
45 |
46 | self.embedding = nn.Embedding(
47 | num_embeddings=num_embeddings,
48 | embedding_dim=config.dim,
49 | padding_idx=pad_idx,
50 | sparse=config.sparse_embeddings,
51 | _weight=vectors,
52 | )
53 | else:
54 | self.embedding = nn.Embedding(
55 | num_embeddings=num_embeddings,
56 | embedding_dim=config.dim,
57 | padding_idx=pad_idx,
58 | sparse=config.sparse_embeddings,
59 | )
60 | nn.init.xavier_uniform_(self.embedding.weight)
61 |
62 | self._size = config.dim
63 | self._pe = config.max_position_embeddings
64 |
65 | if config.freeze:
66 | self.embedding.weight.requires_grad = False
67 |
68 | self.dropout = nn.Dropout(config.dropout)
69 |
70 | self.embeddings_scale_factor = 1
71 | if config.scale_embeddings:
72 | self.embeddings_scale_factor = math.sqrt(self._size)
73 |
74 | self.positional_encoding = None
75 | if config.use_position_embeddings:
76 | self.positional_encoding = PositionalEncoding(self._pe, self._size)
77 |
78 | self.layer_norm = None
79 | if config.input_layer_norm:
80 | self.layer_norm = TFLayerNorm(self._size)
81 |
82 | @property
83 | def num_embeddings(self):
84 | return self.embedding.num_embeddings
85 |
86 | def size(self):
87 | return self._size
88 |
89 | def forward(self, batch_input, *args):
90 | assert isinstance(batch_input, BatchedSentence)
91 | ids = batch_input.tensor
92 |
93 | embeddings = self.embedding(ids)
94 | embeddings = self.embeddings_scale_factor * embeddings
95 |
96 | if self.positional_encoding is not None:
97 | embeddings = self.positional_encoding(embeddings)
98 |
99 | if self.layer_norm is not None:
100 | embeddings = self.layer_norm(embeddings)
101 |
102 | embeddings = self.dropout(embeddings)
103 |
104 | return embeddings
105 |
--------------------------------------------------------------------------------
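A construction sketch; only ``tensor`` is read in ``forward``, so the other BatchedSentence fields are filled with placeholders:

    import torch

    from kiwi.data.batch import BatchedSentence
    from kiwi.modules.token_embeddings import TokenEmbeddings

    config = TokenEmbeddings.Config(dim=32, use_position_embeddings=True)
    embedder = TokenEmbeddings(num_embeddings=1000, pad_idx=0, config=config)

    ids = torch.randint(1, 1000, (2, 7))
    batch = BatchedSentence(
        tensor=ids,
        lengths=torch.tensor([7, 5]),
        bounds=torch.zeros(2, 7, dtype=torch.long),       # placeholder
        bounds_lengths=torch.tensor([7, 5]),              # placeholder
        strict_masks=torch.ones(2, 7, dtype=torch.bool),  # placeholder
        number_of_tokens=torch.tensor([7, 5]),            # placeholder
    )
    embeddings = embedder(batch)  # shape: (2, 7, 32)
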
/kiwi/modules/word_level_output.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | import torch
18 | from torch import nn
19 |
20 |
21 | class WordLevelOutput(nn.Module):
22 | def __init__(
23 | self,
24 | input_size,
25 | output_size,
26 | pad_idx,
27 | class_weights=None,
28 | remove_first=False,
29 | remove_last=False,
30 | ):
31 | super().__init__()
32 |
33 | self.pad_idx = pad_idx
34 |
35 | # Explicit check to avoid using 0 as False
36 | self.start_pos = None if remove_first is False or remove_first is None else 1
37 | self.stop_pos = None if remove_last is False or remove_last is None else -1
38 |
39 | self.linear = nn.Linear(input_size, output_size)
40 |
41 | self.loss_fn = nn.CrossEntropyLoss(
42 | reduction='sum', ignore_index=pad_idx, weight=class_weights
43 | )
44 |
45 | nn.init.xavier_uniform_(self.linear.weight)
46 | nn.init.constant_(self.linear.bias, 0.0)
47 |
48 | def forward(self, features_tensor, batch_inputs=None):
49 | logits = self.linear(features_tensor)
50 | logits = logits[:, self.start_pos : self.stop_pos]
51 | return logits
52 |
53 |
54 | class GapTagsOutput(WordLevelOutput):
55 | def __init__(
56 | self,
57 | input_size,
58 | output_size,
59 | pad_idx,
60 | class_weights=None,
61 | remove_first=False,
62 | remove_last=False,
63 | ):
64 | super().__init__(
65 | input_size=2 * input_size,
66 | output_size=output_size,
67 | pad_idx=pad_idx,
68 | class_weights=class_weights,
69 | remove_first=False,
70 | remove_last=False,
71 | )
72 | self.add_pad_start = 1 if remove_first is False or remove_first is None else 0
73 | self.add_pad_stop = 1 if remove_last is False or remove_last is None else 0
74 |
75 | def forward(self, features_tensor, batch_inputs=None):
76 | h_gaps = features_tensor
77 |         if self.add_pad_start or self.add_pad_stop:
78 |             # Pad dim=1 (the token dimension). F.pad pairs run from the
79 |             # last dimension backwards, so all trailing dims get (0, 0).
80 |             h_gaps = nn.functional.pad(
81 |                 h_gaps,
82 |                 pad=[0, 0] * (len(h_gaps.shape) - 2)
83 |                 + [self.add_pad_start, self.add_pad_stop],
84 |                 value=0,
85 |             )
86 | h_gaps = torch.cat((h_gaps[:, :-1], h_gaps[:, 1:]), dim=-1)
87 | logits = super().forward(h_gaps, batch_inputs)
88 | return logits
89 |
--------------------------------------------------------------------------------
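A shape sketch for the gap-tags head: n target tokens yield n + 1 gaps, because the padded sequence is paired with its shifted self before classification:

    import torch

    from kiwi.modules.word_level_output import GapTagsOutput

    head = GapTagsOutput(input_size=64, output_size=2, pad_idx=0)
    features = torch.randn(2, 5, 64)  # one feature vector per target token
    logits = head(features)           # shape: (2, 6, 2), one gap per boundary
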
/kiwi/systems/__init__.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | """Use systems in a plugin way.
18 |
19 | Solution based on https://julienharbulot.com/python-dynamical-import.html.
20 | """
21 | from importlib import import_module
22 | from pathlib import Path
23 | from pkgutil import iter_modules
24 |
25 | # iterate through the modules in the current package
26 | package_dir = Path(__file__).resolve().parent
27 | for (_, module_name, _) in iter_modules([str(package_dir)]):
28 | # import the module and iterate through its attributes
29 | module = import_module(f"{__name__}.{module_name}")
30 | # for attribute_name in dir(module):
31 | # attribute = getattr(module, attribute_name)
32 | #
33 | # try:
34 | # if isclass(attribute) and issubclass(attribute, QESystem):
35 | # # Add the class to this package's variables
36 | # globals()[attribute_name] = attribute
37 | # except TypeError:
38 | # pass
39 |
--------------------------------------------------------------------------------
/kiwi/systems/_meta_module.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | import json
18 | import logging
19 | from abc import ABCMeta, abstractmethod
20 | from collections import OrderedDict
21 |
22 | import torch
23 | import torch.nn as nn
24 |
25 | import kiwi
26 | from kiwi import constants as const
27 | from kiwi.utils.io import BaseConfig, load_torch_file
28 |
29 | logger = logging.getLogger(__name__)
30 |
31 |
32 | class Serializable(metaclass=ABCMeta):
33 | subclasses = {}
34 |
35 | @classmethod
36 | def register_subclass(cls, subclass):
37 | cls.subclasses[subclass.__name__] = subclass
38 | return subclass
39 |
40 | @classmethod
41 | def retrieve_subclass(cls, subclass_name):
42 | subclass = cls.subclasses.get(subclass_name)
43 | if subclass is None:
44 | raise KeyError(
45 | f'{subclass_name} is not a registered subclass of {cls.__name__}'
46 | )
47 | return subclass
48 |
49 | @classmethod
50 | def load(cls, path):
51 | model_dict = load_torch_file(path)
52 | return cls.from_dict(model_dict)
53 |
54 | def save(self, path):
55 | torch.save(self.to_dict(), path)
56 |
57 | @classmethod
58 | @abstractmethod
59 | def from_dict(cls, *args, **kwargs):
60 | pass
61 |
62 |     # Implemented as an instance method by concrete subclasses.
63 |     @abstractmethod
64 |     def to_dict(self, include_state=True):
65 |         pass
66 |
67 |
68 | class MetaModule(nn.Module, Serializable, metaclass=ABCMeta):
69 | class Config(BaseConfig, metaclass=ABCMeta):
70 | pass
71 |
72 | def __init__(self, config: Config):
73 | """Base module used for several model layers and modules.
74 |
75 | Arguments:
76 | config: a ``MetaModule.Config`` object.
77 | """
78 | super().__init__()
79 |
80 | self.config = config
81 |
82 | @classmethod
83 | def from_dict(cls, module_dict, **kwargs):
84 | module_cls = cls.retrieve_subclass(module_dict['class_name'])
85 | config = module_cls.Config(**module_dict[const.CONFIG])
86 | module = module_cls(config=config, **kwargs)
87 |
88 | state_dict = module_dict.get(const.STATE_DICT)
89 | if state_dict:
90 | not_loaded_keys = module.load_state_dict(state_dict)
91 |             logger.debug(f'Loaded module; extraneous keys: {not_loaded_keys}')
92 |
93 | return module
94 |
95 | def to_dict(self, include_state=True):
96 | module_dict = OrderedDict(
97 | {
98 | '__version__': kiwi.__version__,
99 | 'class_name': self.__class__.__name__,
100 | const.CONFIG: json.loads(self.config.json()),
101 | const.STATE_DICT: self.state_dict() if include_state else None,
102 | }
103 | )
104 | return module_dict
105 |
--------------------------------------------------------------------------------
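The registry round-trip in miniature; TinyModule is a made-up class for illustration:

    from kiwi.systems._meta_module import MetaModule
    from kiwi.utils.io import BaseConfig


    @MetaModule.register_subclass
    class TinyModule(MetaModule):
        class Config(BaseConfig):
            hidden_size: int = 8


    module_dict = TinyModule(config=TinyModule.Config()).to_dict()
    restored = MetaModule.from_dict(module_dict)  # looks up 'TinyModule' by name
    assert isinstance(restored, TinyModule)
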
/kiwi/systems/bert.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | import logging
18 | from typing import Any, Dict
19 |
20 | from kiwi.data.datasets.wmt_qe_dataset import WMTQEDataset
21 | from kiwi.data.encoders.wmt_qe_data_encoder import WMTQEDataEncoder
22 | from kiwi.systems.decoders.linear import LinearDecoder
23 | from kiwi.systems.encoders.bert import BertEncoder
24 | from kiwi.systems.outputs.quality_estimation import QEOutputs
25 | from kiwi.systems.outputs.translation_language_model import TLMOutputs
26 | from kiwi.systems.qe_system import QESystem
27 | from kiwi.utils.io import BaseConfig
28 |
29 | logger = logging.getLogger(__name__)
30 |
31 |
32 | class ModelConfig(BaseConfig):
33 | encoder: BertEncoder.Config = BertEncoder.Config()
34 | decoder: LinearDecoder.Config = LinearDecoder.Config()
35 | outputs: QEOutputs.Config = QEOutputs.Config()
36 | tlm_outputs: TLMOutputs.Config = TLMOutputs.Config()
37 |
38 |
39 | @QESystem.register_subclass
40 | class Bert(QESystem):
41 | """BERT-based Predictor-Estimator model."""
42 |
43 | class Config(QESystem.Config):
44 | model: ModelConfig
45 |
46 | def __init__(
47 | self,
48 | config,
49 | data_config: WMTQEDataset.Config = None,
50 | module_dict: Dict[str, Any] = None,
51 | ):
52 | super().__init__(config, data_config=data_config)
53 |
54 | if module_dict:
55 | # Load modules and weights
56 | self._load_dict(module_dict)
57 | elif self.config.load_encoder:
58 | self._load_encoder(self.config.load_encoder)
59 | else:
60 | # Initialize data processing
61 | self.data_encoders = WMTQEDataEncoder(
62 | config=self.config.data_processing,
63 | field_encoders=BertEncoder.input_data_encoders(
64 | self.config.model.encoder
65 | ),
66 | )
67 |
68 | # Add possibly missing fields, like outputs
69 | if self.config.load_vocabs:
70 | self.data_encoders.load_vocabularies(self.config.load_vocabs)
71 | if self.train_dataset:
72 | self.data_encoders.fit_vocabularies(self.train_dataset)
73 |
74 | # Input to features
75 | if not self.encoder:
76 | self.encoder = BertEncoder(
77 | vocabs=self.data_encoders.vocabularies, config=self.config.model.encoder
78 | )
79 |
80 | # Features to output
81 | if not self.decoder:
82 | self.decoder = LinearDecoder(
83 | inputs_dims=self.encoder.size(), config=self.config.model.decoder
84 | )
85 |
86 | # Output layers
87 | if not self.outputs:
88 | self.outputs = QEOutputs(
89 | inputs_dims=self.decoder.size(),
90 | vocabs=self.data_encoders.vocabularies,
91 | config=self.config.model.outputs,
92 | )
93 |
94 | if not self.tlm_outputs:
95 | self.tlm_outputs = TLMOutputs(
96 | inputs_dims=self.encoder.size(),
97 | vocabs=self.data_encoders.vocabularies,
98 | config=self.config.model.tlm_outputs,
99 | )
100 |
--------------------------------------------------------------------------------
/kiwi/systems/decoders/__init__.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 |
--------------------------------------------------------------------------------
/kiwi/systems/decoders/linear.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | from collections import OrderedDict
18 | from typing import Dict
19 |
20 | import torch
21 | from pydantic import confloat
22 | from torch import nn
23 |
24 | from kiwi import constants as const
25 | from kiwi.data.batch import MultiFieldBatch
26 | from kiwi.systems._meta_module import MetaModule
27 | from kiwi.utils.io import BaseConfig
28 |
29 |
30 | @MetaModule.register_subclass
31 | class LinearDecoder(MetaModule):
32 | class Config(BaseConfig):
33 | hidden_size: int = 250
34 | 'Size of hidden layer'
35 |
36 | dropout: confloat(ge=0.0, le=1.0) = 0.0
37 |
38 | bottleneck_size: int = 100
39 |
40 | def __init__(self, inputs_dims, config):
41 | super().__init__(config=config)
42 |
43 | self.features_dims = inputs_dims
44 |
45 | # Build Model #
46 | self.linear_outs = nn.ModuleDict()
47 | self._sizes = {}
48 | if const.TARGET in self.features_dims:
49 | self.linear_outs[const.TARGET] = nn.Sequential(
50 | nn.Linear(self.features_dims[const.TARGET], self.config.hidden_size),
51 | nn.Tanh(),
52 | )
53 | self._sizes[const.TARGET] = self.config.hidden_size
54 | if const.SOURCE in self.features_dims:
55 | self.linear_outs[const.SOURCE] = nn.Sequential(
56 | nn.Linear(self.features_dims[const.SOURCE], self.config.hidden_size),
57 | nn.Tanh(),
58 | )
59 | self._sizes[const.SOURCE] = self.config.hidden_size
60 |
61 | self.dropout = nn.Dropout(self.config.dropout)
62 |
63 | for p in self.parameters():
64 | if len(p.shape) > 1:
65 | nn.init.xavier_uniform_(p)
66 |
67 | if const.TARGET_SENTENCE in self.features_dims:
68 | linear_layers = [
69 | nn.Linear(
70 | self.features_dims[const.TARGET_SENTENCE],
71 | self.config.bottleneck_size,
72 | ),
73 | nn.Tanh(),
74 | nn.Dropout(self.config.dropout),
75 | ]
76 | linear_layers.extend(
77 | [
78 | nn.Linear(self.config.bottleneck_size, self.config.hidden_size),
79 | nn.Tanh(),
80 | nn.Dropout(self.config.dropout),
81 | ]
82 | )
83 | self.linear_outs[const.TARGET_SENTENCE] = nn.Sequential(*linear_layers)
84 | self._sizes[const.TARGET_SENTENCE] = self.config.hidden_size
85 |
86 | def size(self, field=None):
87 | if field:
88 | return self._sizes[field]
89 | return self._sizes
90 |
91 | def forward(self, features: Dict[str, torch.Tensor], batch_inputs: MultiFieldBatch):
92 | output_features = OrderedDict()
93 |
94 | if const.TARGET in features:
95 | features_tensor = self.dropout(features[const.TARGET])
96 | output_features[const.TARGET] = self.linear_outs[const.TARGET](
97 | features_tensor
98 | )
99 | if const.SOURCE in features:
100 | features_tensor = self.dropout(features[const.SOURCE])
101 | output_features[const.SOURCE] = self.linear_outs[const.SOURCE](
102 | features_tensor
103 | )
104 | if const.TARGET_SENTENCE in features:
105 | features_tensor = self.linear_outs[const.TARGET_SENTENCE](
106 | features[const.TARGET_SENTENCE]
107 | )
108 | output_features[const.TARGET_SENTENCE] = features_tensor
109 |
110 | return output_features
111 |
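112 | # Shape sketch (dimensions are made up): each feature field is projected to
113 | # `hidden_size` through Linear + Tanh, which is also what `size()` reports.
114 | #
115 | #     decoder = LinearDecoder(
116 | #         inputs_dims={const.TARGET: 768},
117 | #         config=LinearDecoder.Config(hidden_size=250),
118 | #     )
119 | #     decoder.size(const.TARGET)  # -> 250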
--------------------------------------------------------------------------------
/kiwi/systems/encoders/__init__.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 |
--------------------------------------------------------------------------------
/kiwi/systems/nuqe.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | import logging
18 | from typing import Any, Dict
19 |
20 | from kiwi.data.datasets.wmt_qe_dataset import WMTQEDataset
21 | from kiwi.data.encoders.wmt_qe_data_encoder import WMTQEDataEncoder
22 | from kiwi.systems.decoders.nuqe import NuQEDecoder
23 | from kiwi.systems.encoders.quetch import QUETCHEncoder
24 | from kiwi.systems.outputs.quality_estimation import QEOutputs
25 | from kiwi.systems.qe_system import QESystem
26 | from kiwi.utils.io import BaseConfig
27 |
28 | logger = logging.getLogger(__name__)
29 |
30 |
31 | class ModelConfig(BaseConfig):
32 | encoder: QUETCHEncoder.Config
33 | decoder: NuQEDecoder.Config
34 | outputs: QEOutputs.Config
35 |
36 |
37 | @QESystem.register_subclass
38 | class NuQE(QESystem):
39 | """Neural Quality Estimation (NuQE) model for word level quality estimation."""
40 |
41 | class Config(QESystem.Config):
42 | model: ModelConfig
43 |
44 | def __init__(
45 | self,
46 | config,
47 | data_config: WMTQEDataset.Config = None,
48 | module_dict: Dict[str, Any] = None,
49 | ):
50 | super().__init__(config, data_config=data_config)
51 |
52 | if module_dict:
53 | # Load modules and weights
54 | self._load_dict(module_dict)
55 | elif self.config.load_encoder:
56 | logger.warning(
57 | f'NuQE does not support loading the encoder; ignoring option '
58 | f'`load_encoder={self.config.load_encoder}`'
59 | )
60 | else:
61 | # Initialize data processing
62 | self.data_encoders = WMTQEDataEncoder(
63 | config=self.config.data_processing,
64 | field_encoders=QUETCHEncoder.input_data_encoders(
65 | self.config.model.encoder
66 | ),
67 | )
68 |
69 | # Add possibly missing fields, like outputs
70 | if self.config.load_vocabs:
71 | self.data_encoders.load_vocabularies(self.config.load_vocabs)
72 | if self.train_dataset:
73 | self.data_encoders.fit_vocabularies(self.train_dataset)
74 |
75 | # Input to features
76 | if not self.encoder:
77 | self.encoder = QUETCHEncoder(
78 | vocabs=self.data_encoders.vocabularies, config=self.config.model.encoder
79 | )
80 |
81 | # Features to output
82 | if not self.decoder:
83 | self.decoder = NuQEDecoder(
84 | inputs_dims=self.encoder.size(), config=self.config.model.decoder
85 | )
86 |
87 | # Output layers
88 | if not self.outputs:
89 | self.outputs = QEOutputs(
90 | inputs_dims=self.decoder.size(),
91 | vocabs=self.data_encoders.vocabularies,
92 | config=self.config.model.outputs,
93 | )
94 | self.tlm_outputs = None
95 |
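96 | # Pipeline: QUETCHEncoder (feature embeddings) -> NuQEDecoder -> QEOutputs.
97 | # Unlike the pretrained systems, NuQE has no translation-LM head, hence
98 | # `tlm_outputs = None` above.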
--------------------------------------------------------------------------------
/kiwi/systems/outputs/__init__.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 |
--------------------------------------------------------------------------------
/kiwi/systems/predictor.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | import logging
18 | from typing import Any, Dict
19 |
20 | from kiwi.data.datasets.wmt_qe_dataset import WMTQEDataset
21 | from kiwi.data.encoders.parallel_data_encoder import ParallelDataEncoder
22 | from kiwi.systems.encoders.predictor import PredictorEncoder
23 | from kiwi.systems.outputs.translation_language_model import TLMOutputs
24 | from kiwi.systems.tlm_system import TLMSystem
25 | from kiwi.utils.io import BaseConfig
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 |
30 | class ModelConfig(BaseConfig):
31 | encoder: PredictorEncoder.Config = PredictorEncoder.Config()
32 | tlm_outputs: TLMOutputs.Config = TLMOutputs.Config()
33 |
34 |
35 | @TLMSystem.register_subclass
36 | class Predictor(TLMSystem):
37 | """Predictor TLM, used for the Predictor-Estimator QE model (proposed in 2017)."""
38 |
39 | class Config(TLMSystem.Config):
40 | model: ModelConfig
41 |
42 | def __init__(
43 | self,
44 | config: Config,
45 | data_config: WMTQEDataset.Config = None,
46 | module_dict: Dict[str, Any] = None,
47 | ):
48 | super().__init__(config, data_config=data_config)
49 |
50 | if module_dict:
51 | # Load modules and weights
52 | self._load_dict(module_dict)
53 | else:
54 | # Initialize data processing
55 | self.data_encoders = ParallelDataEncoder(
56 | config=self.config.data_processing,
57 | field_encoders=PredictorEncoder.input_data_encoders(
58 | self.config.model.encoder
59 | ),
60 | )
61 |
62 | # Add possibly missing fields, like outputs
63 | if self.config.load_vocabs:
64 | self.data_encoders.load_vocabularies(self.config.load_vocabs)
65 | if self.train_dataset:
66 | self.data_encoders.fit_vocabularies(self.train_dataset)
67 |
68 | # Input to features
69 | if not self.encoder:
70 | self.encoder = PredictorEncoder(
71 | vocabs=self.data_encoders.vocabularies,
72 | config=self.config.model.encoder,
73 | pretraining=True,
74 | )
75 |
76 | # Output
77 | if not self.tlm_outputs:
78 | self.tlm_outputs = TLMOutputs(
79 | inputs_dims=self.encoder.size(),
80 | vocabs=self.data_encoders.vocabularies,
81 | config=self.config.model.tlm_outputs,
82 | pretraining=True,
83 | )
84 |
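85 | # Note: the Predictor is pretrained on parallel data (hence ParallelDataEncoder
86 | # and `pretraining=True`); its encoder can afterwards be reused by
87 | # PredictorEstimator through the `load_encoder` option.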
--------------------------------------------------------------------------------
/kiwi/systems/predictor_estimator.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | import logging
18 | from typing import Any, Dict
19 |
20 | from kiwi.data.datasets.wmt_qe_dataset import WMTQEDataset
21 | from kiwi.data.encoders.wmt_qe_data_encoder import WMTQEDataEncoder
22 | from kiwi.systems.decoders.estimator import EstimatorDecoder
23 | from kiwi.systems.encoders.predictor import PredictorEncoder
24 | from kiwi.systems.outputs.quality_estimation import QEOutputs
25 | from kiwi.systems.outputs.translation_language_model import TLMOutputs
26 | from kiwi.systems.qe_system import QESystem
27 | from kiwi.utils.io import BaseConfig
28 |
29 | logger = logging.getLogger(__name__)
30 |
31 |
32 | class ModelConfig(BaseConfig):
33 | encoder: PredictorEncoder.Config = PredictorEncoder.Config()
34 | decoder: EstimatorDecoder.Config = EstimatorDecoder.Config()
35 | outputs: QEOutputs.Config = QEOutputs.Config()
36 | tlm_outputs: TLMOutputs.Config = TLMOutputs.Config()
37 |
38 |
39 | @QESystem.register_subclass
40 | class PredictorEstimator(QESystem):
41 | """Predictor-Estimator QE model (proposed in 2017)."""
42 |
43 | class Config(QESystem.Config):
44 | model: ModelConfig
45 |
46 | def __init__(
47 | self,
48 | config,
49 | data_config: WMTQEDataset.Config = None,
50 | module_dict: Dict[str, Any] = None,
51 | ):
52 | super().__init__(config, data_config=data_config)
53 |
54 | if module_dict:
55 | # Load modules and weights
56 | self._load_dict(module_dict)
57 | elif self.config.load_encoder:
58 | self._load_encoder(self.config.load_encoder)
59 | else:
60 | # Initialize data processing
61 | self.data_encoders = WMTQEDataEncoder(
62 | config=self.config.data_processing,
63 | field_encoders=PredictorEncoder.input_data_encoders(
64 | self.config.model.encoder
65 | ),
66 | )
67 |
68 | # Add possibly missing fields, like outputs
69 | if self.config.load_vocabs:
70 | self.data_encoders.load_vocabularies(self.config.load_vocabs)
71 | if self.train_dataset:
72 | self.data_encoders.fit_vocabularies(self.train_dataset)
73 |
74 | # Input to features
75 | if not self.encoder:
76 | self.encoder = PredictorEncoder(
77 | vocabs=self.data_encoders.vocabularies, config=self.config.model.encoder
78 | )
79 |
80 | # Features to output
81 | if not self.decoder:
82 | self.decoder = EstimatorDecoder(
83 | inputs_dims=self.encoder.size(), config=self.config.model.decoder
84 | )
85 |
86 | # Output layers
87 | if not self.outputs:
88 | self.outputs = QEOutputs(
89 | inputs_dims=self.decoder.size(),
90 | vocabs=self.data_encoders.vocabularies,
91 | config=self.config.model.outputs,
92 | )
93 |
94 | if not self.tlm_outputs:
95 | self.tlm_outputs = TLMOutputs(
96 | inputs_dims=self.encoder.size(),
97 | vocabs=self.data_encoders.vocabularies,
98 | config=self.config.model.tlm_outputs,
99 | )
100 |
--------------------------------------------------------------------------------
/kiwi/systems/xlm.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | import logging
18 | from typing import Any, Dict
19 |
20 | from kiwi.data.datasets.wmt_qe_dataset import WMTQEDataset
21 | from kiwi.data.encoders.wmt_qe_data_encoder import WMTQEDataEncoder
22 | from kiwi.systems.decoders.linear import LinearDecoder
23 | from kiwi.systems.encoders.xlm import XLMEncoder
24 | from kiwi.systems.outputs.quality_estimation import QEOutputs
25 | from kiwi.systems.outputs.translation_language_model import TLMOutputs
26 | from kiwi.systems.qe_system import QESystem
27 | from kiwi.utils.io import BaseConfig
28 |
29 | logger = logging.getLogger(__name__)
30 |
31 |
32 | class ModelConfig(BaseConfig):
33 | encoder: XLMEncoder.Config = XLMEncoder.Config()
34 | decoder: LinearDecoder.Config = LinearDecoder.Config()
35 | outputs: QEOutputs.Config = QEOutputs.Config()
36 | tlm_outputs: TLMOutputs.Config = TLMOutputs.Config()
37 |
38 |
39 | @QESystem.register_subclass
40 | class XLM(QESystem):
41 | """XLM-based model for word level quality estimation."""
42 |
43 | class Config(QESystem.Config):
44 | model: ModelConfig
45 |
46 | def __init__(
47 | self,
48 | config,
49 | data_config: WMTQEDataset.Config = None,
50 | module_dict: Dict[str, Any] = None,
51 | ):
52 | super().__init__(config, data_config=data_config)
53 |
54 | if module_dict:
55 | # Load modules and weights
56 | self._load_dict(module_dict)
57 | elif self.config.load_encoder:
58 | self._load_encoder(self.config.load_encoder)
59 | else:
60 | # Initialize data processing
61 | self.data_encoders = WMTQEDataEncoder(
62 | config=self.config.data_processing,
63 | field_encoders=XLMEncoder.input_data_encoders(
64 | self.config.model.encoder
65 | ),
66 | )
67 |
68 | # Add possibly missing fields, like outputs
69 | if self.config.load_vocabs:
70 | self.data_encoders.load_vocabularies(self.config.load_vocabs)
71 | if self.train_dataset:
72 | self.data_encoders.fit_vocabularies(self.train_dataset)
73 |
74 | # Input to features
75 | if not self.encoder:
76 | self.encoder = XLMEncoder(
77 | vocabs=self.data_encoders.vocabularies, config=self.config.model.encoder
78 | )
79 |
80 | # Features to output
81 | if not self.decoder:
82 | self.decoder = LinearDecoder(
83 | inputs_dims=self.encoder.size(), config=self.config.model.decoder
84 | )
85 |
86 | # Output layers
87 | if not self.outputs:
88 | self.outputs = QEOutputs(
89 | inputs_dims=self.decoder.size(),
90 | vocabs=self.data_encoders.vocabularies,
91 | config=self.config.model.outputs,
92 | )
93 |
94 | if not self.tlm_outputs:
95 | self.tlm_outputs = TLMOutputs(
96 | inputs_dims=self.encoder.size(),
97 | vocabs=self.data_encoders.vocabularies,
98 | config=self.config.model.tlm_outputs,
99 | )
100 |
--------------------------------------------------------------------------------
/kiwi/systems/xlmroberta.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | import logging
18 | from typing import Any, Dict
19 |
20 | from kiwi.data.datasets.wmt_qe_dataset import WMTQEDataset
21 | from kiwi.data.encoders.wmt_qe_data_encoder import WMTQEDataEncoder
22 | from kiwi.systems.decoders.linear import LinearDecoder
23 | from kiwi.systems.encoders.xlmroberta import XLMRobertaEncoder
24 | from kiwi.systems.outputs.quality_estimation import QEOutputs
25 | from kiwi.systems.outputs.translation_language_model import TLMOutputs
26 | from kiwi.systems.qe_system import QESystem
27 | from kiwi.utils.io import BaseConfig
28 |
29 | logger = logging.getLogger(__name__)
30 |
31 |
32 | class ModelConfig(BaseConfig):
33 | encoder: XLMRobertaEncoder.Config = XLMRobertaEncoder.Config()
34 | decoder: LinearDecoder.Config = LinearDecoder.Config()
35 | outputs: QEOutputs.Config = QEOutputs.Config()
36 | tlm_outputs: TLMOutputs.Config = TLMOutputs.Config()
37 |
38 |
39 | @QESystem.register_subclass
40 | class XLMRoberta(QESystem):
41 |     """XLMRoberta-based Predictor-Estimator model."""
42 |
43 | class Config(QESystem.Config):
44 | model: ModelConfig
45 |
46 | def __init__(
47 | self,
48 | config,
49 | data_config: WMTQEDataset.Config = None,
50 | module_dict: Dict[str, Any] = None,
51 | ):
52 | super().__init__(config, data_config=data_config)
53 |
54 | if module_dict:
55 | # Load modules and weights
56 | self._load_dict(module_dict)
57 | elif self.config.load_encoder:
58 | self._load_encoder(self.config.load_encoder)
59 | else:
60 | # Initialize data processing
61 | self.data_encoders = WMTQEDataEncoder(
62 | config=self.config.data_processing,
63 | field_encoders=XLMRobertaEncoder.input_data_encoders(
64 | self.config.model.encoder
65 | ),
66 | )
67 |
68 | # Add possibly missing fields, like outputs
69 | if self.config.load_vocabs:
70 | self.data_encoders.load_vocabularies(self.config.load_vocabs)
71 | if self.train_dataset:
72 | self.data_encoders.fit_vocabularies(self.train_dataset)
73 |
74 | # Input to features
75 | if not self.encoder:
76 | self.encoder = XLMRobertaEncoder(
77 | vocabs=self.data_encoders.vocabularies, config=self.config.model.encoder
78 | )
79 |
80 | # Features to output
81 | if not self.decoder:
82 | self.decoder = LinearDecoder(
83 | inputs_dims=self.encoder.size(), config=self.config.model.decoder
84 | )
85 |
86 | # Output layers
87 | if not self.outputs:
88 | self.outputs = QEOutputs(
89 | inputs_dims=self.decoder.size(),
90 | vocabs=self.data_encoders.vocabularies,
91 | config=self.config.model.outputs,
92 | )
93 |
94 | if not self.tlm_outputs:
95 | self.tlm_outputs = TLMOutputs(
96 | inputs_dims=self.encoder.size(),
97 | vocabs=self.data_encoders.vocabularies,
98 | config=self.config.model.tlm_outputs,
99 | )
100 |
--------------------------------------------------------------------------------
/kiwi/training/__init__.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 |
--------------------------------------------------------------------------------
/kiwi/training/callbacks.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | import logging
18 | import textwrap
19 |
20 | import numpy as np
21 | from pytorch_lightning import Callback
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 |
26 | class BestMetricsInfo(Callback):
27 | """Class for logging current training metrics along with the best so far."""
28 |
29 | def __init__(
30 | self,
31 | monitor: str = 'val_loss',
32 | min_delta: float = 0.0,
33 | verbose: bool = True,
34 | mode: str = 'auto',
35 | ):
36 | super().__init__()
37 |
38 | self.monitor = monitor
39 | self.min_delta = min_delta
40 | self.verbose = verbose
41 |
42 | mode_dict = {
43 | 'min': np.less,
44 | 'max': np.greater,
45 | 'auto': np.greater if 'acc' in self.monitor else np.less,
46 | }
47 |
48 | if mode not in mode_dict:
49 | logger.info(
50 | f'BestMetricsInfo mode {mode} is unknown, fallback to auto mode.'
51 | )
52 | mode = 'auto'
53 |
54 | self.monitor_op = mode_dict[mode]
55 | self.min_delta *= 1 if self.monitor_op == np.greater else -1
56 |
57 | self.best = np.Inf if self.monitor_op == np.less else -np.Inf
58 | self.best_epoch = -1
59 | self.best_metrics = {}
60 |
61 | def on_train_begin(self, trainer, pl_module):
62 | # Allow instances to be re-used
63 | self.best = np.Inf if self.monitor_op == np.less else -np.Inf
64 | self.best_epoch = -1
65 | self.best_metrics = {}
66 |
67 | def on_train_end(self, trainer, pl_module):
68 |         if self.best_epoch >= 0 and self.verbose:
69 | metrics_message = textwrap.fill(
70 | ', '.join(
71 | [
72 | '{}: {:0.4f}'.format(k, v)
73 | for k, v in self.best_metrics.items()
74 | if k.startswith('val_')
75 | ]
76 | ),
77 | width=80,
78 | initial_indent='\t',
79 | subsequent_indent='\t',
80 | )
81 | best_path = trainer.checkpoint_callback.best_model_path
82 | if not best_path:
83 | best_path = (
84 | "model was not saved; check flags in Trainer if this is not "
85 | "expected"
86 | )
87 | logger.info(
88 | f'Epoch {self.best_epoch} had the best validation metric:\n'
89 | f'{metrics_message} \n'
90 | f'\t({best_path})\n'
91 | )
92 |
93 | def on_validation_end(self, trainer, pl_module):
94 | metrics = trainer.callback_metrics
95 |
96 | current = metrics.get(self.monitor)
97 |         if current is not None and self.monitor_op(current - self.min_delta, self.best):
98 | self.best = current
99 | self.best_epoch = trainer.current_epoch
100 | self.best_metrics = metrics.copy() # Copy or it gets overwritten
101 | if self.verbose > 0:
102 | logger.info('Best validation so far.')
103 | else:
104 | metrics_message = textwrap.fill(
105 | ', '.join(
106 | [
107 | f'{k}: {v:0.4f}'
108 | for k, v in self.best_metrics.items()
109 | if k.startswith('val_')
110 | ]
111 | ),
112 | width=80,
113 | initial_indent='\t',
114 | subsequent_indent='\t',
115 | )
116 | logger.info(
117 | f'Best validation so far was in epoch {self.best_epoch}:\n'
118 | f'{metrics_message} \n'
119 | )
120 |
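121 | # Usage sketch (pytorch-lightning 0.8.x API; 'val_loss' is just the default
122 | # monitored key, any entry of `trainer.callback_metrics` works):
123 | #
124 | #     from pytorch_lightning import Trainer
125 | #     trainer = Trainer(callbacks=[BestMetricsInfo(monitor='val_loss', mode='min')])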
--------------------------------------------------------------------------------
/kiwi/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 |
--------------------------------------------------------------------------------
/kiwi/utils/data_structures.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | from collections import OrderedDict
18 |
19 | from kiwi import constants as const
20 |
21 |
22 | class DefaultFrozenDict(OrderedDict):
23 | def __init__(self, mapping=None, default_key=const.UNK):
24 | if mapping is None:
25 | super().__init__()
26 | else:
27 | super().__init__(mapping)
28 | self._default_key = default_key
29 |
30 | def __getitem__(self, k):
31 | default_id = self.get(self._default_key)
32 | item = self.get(k, default_id)
33 | if item is None:
34 | raise KeyError(
35 | f"'{k}' (and default '{self._default_key}' not found either)"
36 | )
37 | return item
38 |
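39 | # Minimal usage sketch (ids are made up): unknown keys fall back to the id
40 | # stored under `default_key`, mimicking UNK handling in a vocabulary.
41 | #
42 | #     vocab = DefaultFrozenDict({const.UNK: 0, 'kiwi': 7})
43 | #     vocab['kiwi']    # -> 7
44 | #     vocab['banana']  # -> 0 (falls back to const.UNK's id)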
--------------------------------------------------------------------------------
/kiwi/utils/io.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | import logging
18 | from pathlib import Path
19 |
20 | import torch
21 | from pydantic import BaseModel, Extra
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 |
26 | class BaseConfig(BaseModel):
27 | """Base class for all pydantic configs. Used to configure base behaviour of configs.
28 | """
29 |
30 | class Config:
31 |         # Throw an error whenever an extra key is provided, effectively
32 |         # making parsing strict
33 | extra = Extra.forbid
34 |
35 |
36 | def default_map_location(storage, loc):
37 | return storage
38 |
39 |
40 | def load_torch_file(file_path, map_location=None):
41 | file_path = Path(file_path)
42 | if not file_path.exists():
43 | raise ValueError(f'Torch file not found: {file_path}')
44 |
45 | try:
46 | if map_location is None:
47 | map_location = default_map_location
48 | file_dict = torch.load(file_path, map_location=map_location)
49 | except ModuleNotFoundError as e:
50 | # Caused, e.g., by moving the Vocabulary or DefaultFrozenDict classes
51 | logger.info(
52 | 'Trying to load a slightly outdated file and encountered an issue when '
53 | 'unpickling; trying to work around it.'
54 | )
55 | if e.name == 'kiwi.data.utils':
56 | import sys
57 | from kiwi.utils import data_structures
58 |
59 | sys.modules['kiwi.data.utils'] = data_structures
60 | file_dict = torch.load(file_path, map_location=map_location)
61 | del sys.modules['kiwi.data.utils']
62 | elif e.name == 'torchtext':
63 | import sys
64 | from kiwi.data import vocabulary
65 |
66 | vocabulary.Vocab = vocabulary.Vocabulary
67 | sys.modules['torchtext'] = ''
68 | sys.modules['torchtext.vocab'] = vocabulary
69 | file_dict = torch.load(file_path, map_location=map_location)
70 | del sys.modules['torchtext.vocab']
71 | del sys.modules['torchtext']
72 | else:
73 | raise e
74 |
75 | return file_dict
76 |
77 |
78 | def save_file(file_path, data, token_sep=' ', example_sep='\n'):
79 | if data and isinstance(data[0], list):
80 | data = [token_sep.join(map(str, sentence)) for sentence in data]
81 | else:
82 | data = map(str, data)
83 | example_str = example_sep.join(data) + '\n'
84 | Path(file_path).write_text(example_str)
85 |
86 |
87 | def save_predicted_probabilities(directory, predictions, prefix=''):
88 | directory = Path(directory)
89 | directory.mkdir(parents=True, exist_ok=True)
90 | for key, preds in predictions.items():
91 | if prefix:
92 | key = f'{prefix}.{key}'
93 | output_path = Path(directory, key)
94 | logger.info(f'Saving {key} predictions to {output_path}')
95 | save_file(output_path, preds, token_sep=' ', example_sep='\n')
96 |
97 |
98 | def read_file(path):
99 | """Read a file into a list of lists of words."""
100 | with Path(path).open('r', encoding='utf8') as f:
101 |         return [line.strip().split() for line in f]
102 |
103 |
104 | def target_gaps_to_target(batch):
105 | """Extract target tags from wmt18 format file."""
106 | return batch[1::2]
107 |
108 |
109 | def target_gaps_to_gaps(batch):
110 | """Extract gap tags from wmt18 format file."""
111 | return batch[::2]
112 |
113 |
114 | def generate_slug(text, delimiter="-"):
115 | """Convert text to a normalized "slug" without whitespace.
116 |
117 | Borrowed from the nice https://humanfriendly.readthedocs.io, by Peter Odding.
118 |
119 | Arguments:
120 | text: the original text, for example ``Some Random Text!``.
121 | delimiter: the delimiter to use for separating words
122 | (defaults to the ``-`` character).
123 |
124 | Return:
125 | the slug text, for example ``some-random-text``.
126 |
127 | Raise:
128 | :exc:`~exceptions.ValueError` when the provided text is nonempty but results
129 | in an empty slug.
130 | """
131 | import re
132 |
133 | slug = text.lower()
134 | escaped = delimiter.replace("\\", "\\\\")
135 | slug = re.sub("[^a-z0-9]+", escaped, slug)
136 | slug = slug.strip(delimiter)
137 | if text and not slug:
138 | msg = "The provided text %r results in an empty slug!"
139 |         raise ValueError(msg % text)
140 | return slug
141 |
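142 | # The WMT18 word-level format interleaves gap and target tags as
143 | # [gap, tgt, gap, tgt, ..., gap]; the two slicing helpers above split them
144 | # apart (illustrative values):
145 | #
146 | #     tags = ['G0', 'T0', 'G1', 'T1', 'G2']
147 | #     target_gaps_to_target(tags)  # -> ['T0', 'T1']
148 | #     target_gaps_to_gaps(tags)    # -> ['G0', 'G1', 'G2']
149 | #
150 | # And, per its docstring:
151 | #
152 | #     generate_slug('Some Random Text!')  # -> 'some-random-text'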
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | # Configuration file as per PEP 518
2 | # https://www.python.org/dev/peps/pep-0518/
3 |
4 | [tool.poetry]
5 | name = "openkiwi"
6 | version = "2.1.0"
7 | description = "Machine Translation Quality Estimation Toolkit"
8 | authors = ["AI Research, Unbabel "]
9 | license = "AGPL-3.0"
10 | readme = 'README.md'
11 | homepage = 'https://github.com/Unbabel/OpenKiwi'
12 | repository = 'https://github.com/Unbabel/OpenKiwi'
13 | documentation = 'https://unbabel.github.io/OpenKiwi'
14 | keywords = ['OpenKiwi', 'Quality Estimation', 'Machine Translation', 'Unbabel']
15 | classifiers = [
16 | 'Development Status :: 4 - Beta',
17 | 'Environment :: Console',
18 | 'Intended Audience :: Science/Research',
19 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
20 | ]
21 | packages = [
22 | {include = "kiwi"},
23 | ]
24 | include = ['pyproject.toml', 'CHANGELOG.md', 'LICENSE', 'CONTRIBUTING.md']
25 |
26 |
27 | [tool.poetry.scripts]
28 | kiwi = 'kiwi.__main__:main'
29 |
30 |
31 | [tool.poetry.dependencies]
32 | python = "^3.7"
33 | torch = ">=1.4.0, <1.7.0"
34 | tqdm = "^4.29"
35 | numpy = "^1.18"
36 | more-itertools = "^8.0.0"
37 | scipy = "^1.2"
38 | pyyaml = "^5.1.2"
39 | pytorch-nlp = "^0.5.0"
40 | transformers = "^3.0.2"
41 | pydantic = "^1.5"
42 | docopt = "^0.6.2"
43 | omegaconf = "^1.4.1"
44 | typing-extensions = "^3.7.4"
45 | hydra-core = "^0.11.3"
46 | pytorch-lightning = "^0.8.4"
47 | mlflow = {version = "^1.11.0", optional = true, extras = ["mlflow"]}
48 | optuna = {version = "^2.2.0", optional = true, extras = ["search"]}
49 | plotly = {version = "^4.11.0", optional = true, extras = ["search"]}
50 | sklearn = {version = "^0.0", optional = true, extras = ["search"]}
51 |
52 | [tool.poetry.dev-dependencies]
53 | tox = "^3.7"
54 | pytest = "^4.1"
55 | flake8 = "^3.8"
56 | isort = "^4.3"
57 | black = {version = "^19.10-beta.0",allow-prereleases = true}
58 | pytest-cov = "^2.8.1"
59 | pytest-sugar = "^0.9.3"
60 | sphinx = "^3.0"
61 | recommonmark = "^0.6.0"
62 | m2r = "^0.2.1"
63 | sphinx-autodoc-typehints = "^1.10.3"
64 | sphinx-autoapi = "^1.3.0"
65 | sphinx-paramlinks = "^0.4.1"
66 | pydata-sphinx-theme = "^0.2.2"
67 |
68 | [tool.poetry.extras]
69 | mlflow = ["mlflow"]
70 | search = ["optuna", "plotly", "sklearn"]
71 |
72 | [tool.black]
73 | skip-string-normalization = true # Don't switch to double quotes
74 | exclude = '''
75 | /(
76 | \.git
77 | | \.tox
78 | | \.venv
79 | | build
80 | | dist
81 | )/
82 | '''
83 |
84 | [tool.isort]
85 | multi_line_output = 3
86 | include_trailing_comma = true
87 | force_grid_wrap = 0
88 | use_parentheses = true
89 | line_length = 88
90 |
91 | [build-system]
92 | requires = ["poetry>=1.1.0"]
93 | build-backend = "poetry.masonry.api"
94 |
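95 | # The extras above map to optional feature sets; assuming a PyPI install:
96 | #   pip install openkiwi[mlflow]  # MLflow experiment tracking
97 | #   pip install openkiwi[search]  # optuna/plotly/sklearn for hyperparameter search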
--------------------------------------------------------------------------------
/scripts/merge_target_and_gaps_preds.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 |
18 | import argparse
19 | from pathlib import Path
20 |
21 |
22 | def parse_args():
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument('--target-pred', help='Target predictions', type=str)
25 | parser.add_argument('--gaps-pred', help='Gaps predictions', type=str)
26 | parser.add_argument('--output', help='Path to output', type=str)
27 | return parser.parse_args()
28 |
29 |
30 | def main(args):
31 | output_file_path = Path(args.output)
32 | if output_file_path.exists() and output_file_path.is_dir():
33 | output_file_path = Path(output_file_path, 'predicted.prob')
34 | print('Output is a directory, saving to: {}'.format(output_file_path))
35 |     elif not output_file_path.parent.exists():
36 |         output_file_path.parent.mkdir(parents=True)
37 |     # Open the (possibly adjusted) output path for writing
38 |     f = output_file_path.open('w', encoding='utf8')
39 |
40 | with open(args.target_pred) as f_target, open(args.gaps_pred) as f_gaps:
41 | for line_target, line_gaps in zip(f_target, f_gaps):
42 | try:
43 | # labels are probs
44 | pred_target = list(map(float, line_target.split()))
45 | pred_gaps = list(map(float, line_gaps.split()))
46 | except ValueError:
47 |                 # labels are OK and BAD tags
48 | pred_target = line_target.split()
49 | pred_gaps = line_gaps.split()
50 | new_preds = []
51 | for i in range(len(pred_gaps)):
52 | new_preds.append(str(pred_gaps[i]))
53 | if i < len(pred_target):
54 | new_preds.append(str(pred_target[i]))
55 | f.write(' '.join(new_preds) + '\n')
56 | f.close()
57 |
58 |
59 | if __name__ == '__main__':
60 | args = parse_args()
61 | main(args)
62 |
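63 | # Usage sketch (paths are hypothetical):
64 | #
65 | #     python scripts/merge_target_and_gaps_preds.py \
66 | #         --target-pred predictions/target_tags \
67 | #         --gaps-pred predictions/gap_tags \
68 | #         --output predictions/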
--------------------------------------------------------------------------------
/tests/mocks/mock_vocab.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | class MockVocab:
18 | vectors = None
19 |
20 | def __init__(self, dictionary):
21 | self.stoi = dictionary
22 | self.itos = {idx: token for token, idx in self.stoi.items()}
23 | self.length = max(self.itos.keys()) + 1
24 |
25 | def token_to_id(self, token):
26 | if token in self.stoi:
27 | return self.stoi[token]
28 | else:
29 |             raise KeyError(token)
30 |
31 | def __len__(self):
32 | return self.length
33 |
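34 | # Sketch (tokens are made up):
35 | #
36 | #     vocab = MockVocab({'the': 0, 'kiwi': 1})
37 | #     vocab.token_to_id('kiwi')  # -> 1
38 | #     len(vocab)                 # -> 2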
--------------------------------------------------------------------------------
/tests/mocks/simple_model.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | import torch
18 |
19 | from kiwi import constants as const
20 | from kiwi.metrics import LogMetric
21 | from kiwi.models.model import Model
22 | from kiwi.systems.decoders.quetch_deprecated import QUETCH
23 | from kiwi.systems.encoders.quetch import QUETCHEncoder
24 |
25 |
26 | @Model.register_subclass
27 | class SimpleModel(Model):
28 |
29 | target = const.TARGET_TAGS
30 | fieldset = QUETCH.fieldset
31 | metrics_ordering = QUETCH.metrics_ordering
32 |
33 | @staticmethod
34 | def default_features_embedder_class():
35 | return QUETCHEncoder
36 |
37 | def _build_output_layer(self, vocabs, features_dim, config):
38 | self.output_layer = torch.nn.Sequential(
39 | torch.nn.Linear(features_dim, 1), torch.nn.Sigmoid()
40 | )
41 |
42 | def forward(self, batch, *args, **kwargs):
43 | field_embeddings = {
44 | field: self.field_embedders[field](batch[field])
45 | for field in (const.SOURCE, const.TARGET)
46 | }
47 | feature_embeddings = self.encoder(field_embeddings, batch)
48 | output = self.output_layer(feature_embeddings).squeeze()
49 | return {SimpleModel.target: output}
50 |
51 | def loss(self, model_out, batch):
52 | prediction = model_out[SimpleModel.target]
53 | target = getattr(batch, SimpleModel.target).float()
54 | return {const.LOSS: ((prediction - target) ** 2).mean()}
55 |
56 | def predict(self, batch, *args, **kwargs):
57 | predictions = self.forward(batch)
58 | return {SimpleModel.target: predictions[SimpleModel.target].tolist()}
59 |
60 | def metrics(self):
61 | return (LogMetric(log_targets=[(const.LOSS, const.LOSS)]),)
62 |
--------------------------------------------------------------------------------
/tests/test_bert.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | import shutil
18 |
19 | import pytest
20 | import yaml
21 | from transformers import BertConfig, BertModel
22 | from transformers.tokenization_bert import VOCAB_FILES_NAMES
23 |
24 | from conftest import check_computation
25 | from kiwi import constants as const
26 |
27 | bert_yaml = """
28 | class_name: Bert
29 |
30 | num_data_workers: 0
31 | batch_size:
32 | train: 32
33 | valid: 32
34 |
35 | model:
36 | encoder:
37 | model_name: None
38 | interleave_input: false
39 | freeze: false
40 | use_mismatch_features: false
41 | use_predictor_features: false
42 |
43 | decoder:
44 | hidden_size: 16
45 | dropout: 0.0
46 |
47 | outputs:
48 | word_level:
49 | target: true
50 | gaps: true
51 | source: true
52 | class_weights:
53 | target_tags:
54 | BAD: 3.0
55 | gap_tags:
56 | BAD: 5.0
57 | source_tags:
58 | BAD: 3.0
59 | sentence_level:
60 | hter: true
61 | use_distribution: true
62 | binary: false
63 | sentence_loss_weight: 1
64 |
65 | tlm_outputs:
66 | fine_tune: false
67 |
68 | optimizer:
69 | class_name: noam
70 | learning_rate: 0.00001
71 | warmup_steps: 1000
72 | training_steps: 12000
73 |
74 | data_processing:
75 | share_input_fields_encoders: true
76 |
77 | """
78 |
79 |
80 | @pytest.fixture
81 | def bert_config_dict():
82 | return yaml.unsafe_load(bert_yaml)
83 |
84 |
85 | @pytest.fixture(scope='session')
86 | def bert_model_dir(model_dir):
87 | return model_dir.joinpath('bert/')
88 |
89 |
90 | @pytest.fixture(scope='function')
91 | def bert_model():
92 | config = BertConfig(
93 | vocab_size=107,
94 | hidden_size=32,
95 | num_hidden_layers=5,
96 | num_attention_heads=4,
97 | intermediate_size=37,
98 | hidden_act='gelu',
99 | hidden_dropout_prob=0.1,
100 | attention_probs_dropout_prob=0.1,
101 | max_position_embeddings=512,
102 | type_vocab_size=16,
103 | is_decoder=False,
104 | initializer_range=0.02,
105 | )
106 | return BertModel(config=config)
107 |
108 |
109 | def test_computation_target(
110 | tmp_path,
111 | bert_model,
112 | bert_model_dir,
113 | bert_config_dict,
114 | train_config,
115 | data_config,
116 | big_atol,
117 | ):
118 | train_config['run']['output_dir'] = tmp_path
119 | train_config['data'] = data_config
120 | train_config['system'] = bert_config_dict
121 |
122 | shutil.copy2(bert_model_dir / VOCAB_FILES_NAMES['vocab_file'], tmp_path)
123 | bert_model.save_pretrained(tmp_path)
124 | train_config['system']['model']['encoder']['model_name'] = str(tmp_path)
125 |
126 | check_computation(
127 | train_config,
128 | tmp_path,
129 | output_name=const.TARGET_TAGS,
130 | expected_avg_probs=0.550805,
131 | atol=big_atol,
132 | )
133 |
134 |
135 | if __name__ == '__main__': # pragma: no cover
136 | pytest.main([__file__]) # pragma: no cover
137 |
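138 | # Note: `bert_model` is a tiny randomly initialized BERT (5 layers, hidden
139 | # size 32), so `expected_avg_probs` is only checked up to `big_atol`.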
--------------------------------------------------------------------------------
/tests/test_cli.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | import pytest
18 |
19 | from kiwi.lib.utils import arguments_to_configuration, file_to_configuration
20 |
21 |
22 | @pytest.fixture
23 | def config_path(tmp_path):
24 | return tmp_path / 'config.yaml'
25 |
26 |
27 | @pytest.fixture
28 | def config():
29 | return """
30 | class_name: Bert
31 |
32 | num_data_workers: 0
33 | batch_size:
34 | train: 32
35 | valid: 32
36 |
37 | model:
38 | encoder:
39 | model_name: None
40 | interleave_input: false
41 | freeze: false
42 | use_mismatch_features: false
43 | use_predictor_features: false
44 | encode_source: false
45 |
46 | """
47 |
48 |
49 | def test_file_reading_to_configuration(config, config_path):
50 | """Tests if files are being correctly handed to hydra for
51 | composition.
52 | """
53 | config_path.write_text(config)
54 | config_dict = file_to_configuration(config_path)
55 | assert isinstance(config_dict, dict)
56 | assert 'num_data_workers' in config_dict
57 |
58 |
59 | def test_hydra_state_handling(config, config_path):
60 |     """Tests if hydra global state is being handled correctly.
61 |     If not, kiwi will not allow a config to be run twice."""
62 |
63 | config_path.write_text(config)
64 |     file_to_configuration(config_path)
65 |     file_to_configuration(config_path)
66 |
67 |
68 | def test_arguments_to_configuration(config, config_path):
69 | """Tests if configuration handling and overwrites are working"""
70 |
71 | config_path.write_text(config)
72 | config_dict = arguments_to_configuration(
73 | {'CONFIG_FILE': config_path, 'OVERWRITES': ['class_name=XLMR']}
74 | )
75 |
76 | assert 'num_data_workers' in config_dict
77 | assert 'model' in config_dict
78 | assert 'freeze' in config_dict['model']['encoder']
79 | assert config_dict['class_name'] == 'XLMR'
80 |
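81 | # The OVERWRITES entries use hydra's dotted-key syntax; 'class_name=XLMR'
82 | # above replaces the top-level class_name parsed from the config file.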
--------------------------------------------------------------------------------
/tests/test_metrics.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
31 |
32 | import numpy as np
33 | import pytest
34 |
35 | from kiwi.metrics.functions import (
36 | confusion_matrix,
37 | f1_product,
38 | fscore,
39 | matthews_correlation_coefficient,
40 | )
41 |
42 |
43 | @pytest.fixture
44 | def labels():
45 | n_class = 2
46 | y_gold = np.array(
47 | [
48 | np.array([1, 1, 0, 1]),
49 | np.array([1, 1, 0, 1, 1, 1, 1, 0]),
50 | np.array([1, 1, 1, 0, 1, 1, 0, 0, 0, 1]),
51 | ]
52 | )
53 | y_hat = np.array(
54 | [
55 | np.array([1, 1, 0, 0]),
56 | np.array([1, 1, 0, 1, 1, 1, 1, 0]),
57 | np.array([1, 1, 1, 0, 1, 1, 1, 0, 0, 0]),
58 | ]
59 | )
60 | return y_gold, y_hat, n_class
61 |
62 |
63 | def test_fscore(labels, atol):
64 | y_gold, y_hat, n_class = labels
65 | cnfm = confusion_matrix(y_hat, y_gold, n_class)
66 | f1_orig_prod_micro = f1_product(y_hat, y_gold)
67 | f1_prod_macro = 0
68 | f1_orig_prod_macro = 0
69 | tp, tn, fp, fn = 0, 0, 0, 0
70 | for ys_hat, ys_gold in zip(y_hat, y_gold):
71 | ctp = np.sum((ys_hat == 1) & (ys_gold == 1))
72 | ctn = np.sum((ys_hat == 0) & (ys_gold == 0))
73 | cfp = np.sum((ys_hat == 1) & (ys_gold == 0))
74 | cfn = np.sum((ys_hat == 0) & (ys_gold == 1))
75 | tp += ctp
76 | tn += ctn
77 | fp += cfp
78 | fn += cfn
79 | f_ok = fscore(ctp, cfp, cfn)
80 | f_bad = fscore(ctn, cfp, cfn)
81 | f1_prod_macro += f_ok * f_bad
82 | f1_orig_prod_macro += f1_product(ys_hat, ys_gold)
83 |
84 | assert tn == cnfm[0, 0]
85 | assert fp == cnfm[0, 1]
86 | assert fn == cnfm[1, 0]
87 | assert tp == cnfm[1, 1]
88 |
89 | f_ok = fscore(tp, fp, fn)
90 | f_bad = fscore(tn, fp, fn)
91 | f1_prod_micro = f_ok * f_bad
92 | f1_prod_macro = f1_prod_macro / y_gold.shape[0]
93 | f1_orig_prod_macro = f1_orig_prod_macro / y_gold.shape[0]
94 |
95 | np.testing.assert_allclose(f1_prod_micro, f1_orig_prod_micro, atol=atol)
96 | np.testing.assert_allclose(f1_prod_macro, f1_orig_prod_macro, atol=atol)
97 |
98 |
99 | def test_matthews(labels, atol):
100 | y_gold, y_hat, n_class = labels
101 | matthews = matthews_correlation_coefficient(y_hat, y_gold)
102 | np.testing.assert_allclose(matthews, 0.70082555, atol=atol)
103 | n1 = matthews_correlation_coefficient(y_hat, 1 - y_gold)
104 | n2 = matthews_correlation_coefficient(1 - y_hat, y_gold)
105 | nn = matthews_correlation_coefficient(1 - y_hat, 1 - y_gold)
106 | assert n1 == n2
107 | assert nn == matthews
108 | assert n1 == -matthews
109 |
110 |
111 | if __name__ == '__main__': # pragma: no cover
112 | pytest.main([__file__]) # pragma: no cover
113 |
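114 | # Reference (standard definitions, assumed to match kiwi.metrics.functions):
115 | #     fscore(tp, fp, fn) = 2*tp / (2*tp + fp + fn)
116 | # and the F1-product is fscore(OK) * fscore(BAD), compared above at micro
117 | # and macro level.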
--------------------------------------------------------------------------------
/tests/test_nuqe.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | import numpy as np
18 | import pytest
19 |
20 | from conftest import check_computation
21 | from kiwi import constants as const
22 | from kiwi.lib import train
23 | from kiwi.lib.utils import save_config_to_file
24 |
25 |
26 | @pytest.fixture
27 | def output_source_config(nuqe_config_dict):
28 | nuqe_config_dict['model']['outputs']['word_level']['source'] = True
29 | return nuqe_config_dict
30 |
31 |
32 | @pytest.fixture
33 | def output_gaps_config(nuqe_config_dict):
34 | nuqe_config_dict['model']['outputs']['word_level']['gaps'] = True
35 | return nuqe_config_dict
36 |
37 |
38 | @pytest.fixture
39 | def output_targetgaps_config(nuqe_config_dict):
40 | nuqe_config_dict['model']['outputs']['word_level']['target'] = True
41 | nuqe_config_dict['model']['outputs']['word_level']['gaps'] = True
42 | return nuqe_config_dict
43 |
44 |
45 | def test_computation_target(
46 | tmp_path, output_target_config, train_config, data_config, big_atol
47 | ):
48 | train_config['data'] = data_config
49 | train_config['system'] = output_target_config
50 | check_computation(
51 | train_config,
52 | tmp_path,
53 | output_name=const.TARGET_TAGS,
54 | expected_avg_probs=0.498354,
55 | atol=big_atol,
56 | )
57 |
58 |
59 | def test_computation_gaps(
60 | tmp_path, output_gaps_config, train_config, data_config, atol
61 | ):
62 | train_config['data'] = data_config
63 | train_config['system'] = output_gaps_config
64 | check_computation(
65 | train_config,
66 | tmp_path,
67 | output_name=const.GAP_TAGS,
68 | expected_avg_probs=0.316064,
69 | atol=atol,
70 | )
71 |
72 |
73 | def test_computation_source(
74 | tmp_path, output_source_config, train_config, data_config, big_atol
75 | ):
76 | train_config['data'] = data_config
77 | train_config['system'] = output_source_config
78 | check_computation(
79 | train_config,
80 | tmp_path,
81 | output_name=const.SOURCE_TAGS,
82 | expected_avg_probs=0.486522,
83 | atol=big_atol,
84 | )
85 |
86 |
87 | def test_computation_targetgaps(
88 | tmp_path, output_targetgaps_config, train_config, data_config, big_atol
89 | ):
90 | train_config['data'] = data_config
91 | train_config['system'] = output_targetgaps_config
92 | check_computation(
93 | train_config,
94 | tmp_path,
95 | output_name=const.TARGET_TAGS,
96 | expected_avg_probs=0.507699,
97 | atol=big_atol,
98 | )
99 |
100 |
101 | def test_api(tmp_path, output_target_config, train_config, data_config, big_atol):
102 | from kiwi import train_from_file, load_system
103 |
104 | train_config['data'] = data_config
105 | train_config['system'] = output_target_config
106 |
107 | config_file = tmp_path / 'config.yaml'
108 | save_config_to_file(train.Configuration(**train_config), config_file)
109 |
110 | train_run_info = train_from_file(config_file)
111 |
112 | runner = load_system(train_run_info.best_model_path)
113 |
114 | source = open(data_config['test']['input']['source']).readlines()
115 | target = open(data_config['test']['input']['target']).readlines()
116 | alignments = open(data_config['test']['input']['alignments']).readlines()
117 |
118 | predictions = runner.predict(
119 | source=source,
120 | target=target,
121 | alignments=alignments,
122 | batch_size=train_config['system']['batch_size'],
123 | )
124 |
125 | target_tags_probabilities = predictions.target_tags_BAD_probabilities
126 | avg_of_avgs = np.mean(list(map(np.mean, target_tags_probabilities)))
127 | max_prob = max(map(max, target_tags_probabilities))
128 | min_prob = min(map(min, target_tags_probabilities))
129 | np.testing.assert_allclose(avg_of_avgs, 0.498287, atol=big_atol)
130 | assert 0 <= min_prob <= avg_of_avgs <= max_prob <= 1
131 |
132 | assert len(predictions.target_tags_labels) == len(target)
133 |
134 |
135 | if __name__ == '__main__': # pragma: no cover
136 | pytest.main([__file__]) # pragma: no cover
137 |
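138 | # Distilled from `test_api` above, the programmatic flow in a nutshell (the
139 | # config path is a hypothetical placeholder):
140 | #
141 | #     from kiwi import train_from_file, load_system
142 | #     run_info = train_from_file('config.yaml')
143 | #     runner = load_system(run_info.best_model_path)
144 | #     predictions = runner.predict(source=..., target=..., alignments=...)
145 | #     bad_probs = predictions.target_tags_BAD_probabilities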
--------------------------------------------------------------------------------
/tests/test_predict_and_evaluate.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | import pytest
18 |
19 | from kiwi import load_system
20 | from kiwi.data.datasets.wmt_qe_dataset import WMTQEDataset
21 | from kiwi.lib import predict
22 |
23 |
24 | def test_predicting_and_evaluating(tmp_path, data_config, model_dir, atol):
25 | load_model = model_dir / 'nuqe.ckpt'
26 |
27 | predicting_config = predict.Configuration(
28 | run=dict(seed=42, output_dir=tmp_path, predict_on_data_partition='valid'),
29 | data=WMTQEDataset.Config(**data_config),
30 | system=dict(load=load_model),
31 | use_gpu=False,
32 | )
33 |
34 | predictions, metrics = predict.run(predicting_config, tmp_path)
35 | assert set(predictions.keys()) == {
36 | 'target_tags',
37 | 'target_tags_labels',
38 | 'gap_tags',
39 | 'gap_tags_labels',
40 | 'sentence_scores',
41 | 'sentence_scores_extras',
42 | 'targetgaps_tags',
43 | 'targetgaps_tags_labels',
44 | }
45 | assert (
46 | abs(metrics.word_scores['targetgaps_tags']['F1_Mult'][0] - 0.1937725747540829)
47 | < 0.1
48 | )
49 | assert (
50 | abs(metrics.word_scores['target_tags']['F1_Mult'][0] - 0.053539393196874105)
51 | < 0.1
52 | )
53 | assert abs(metrics.sentence_scores['scoring'][0][1] - -0.13380020964645983) < 0.1
54 |
55 |
56 | def test_runner(tmp_path, data_config, model_dir, atol):
57 | load_model = model_dir / 'nuqe.ckpt'
58 | runner = load_system(load_model)
59 |
60 | data_config = WMTQEDataset.Config(**data_config)
61 | dataset = WMTQEDataset.build(data_config, valid=True)
62 |
63 | predictions = runner.predict(
64 | source=dataset['source'],
65 | target=dataset['target'],
66 | alignments=dataset['alignments'],
67 | )
68 |
69 | target_lengths = [len(s.split()) for s in dataset['target']]
70 | predicted_lengths = [len(s) for s in predictions.target_tags_BAD_probabilities]
71 | predicted_labels_lengths = [len(s) for s in predictions.target_tags_labels]
72 |
73 | assert target_lengths == predicted_lengths == predicted_labels_lengths
74 |
75 | predicted_gap_lengths = [len(s) - 1 for s in predictions.gap_tags_BAD_probabilities]
76 | predicted_gap_labels_len = [len(s) - 1 for s in predictions.gap_tags_labels]
77 |
78 | assert target_lengths == predicted_gap_lengths == predicted_gap_labels_len
79 |
80 | assert len(dataset['target']) == len(predictions.sentences_hter)
81 |
82 |
83 | def test_predict_with_empty_sentences(data_config, model_dir):
84 | load_model = model_dir / 'nuqe.ckpt'
85 | runner = load_system(load_model)
86 |
87 | data_config = WMTQEDataset.Config(**data_config)
88 | dataset = WMTQEDataset.build(data_config, valid=True)
89 |
90 | test_data = dict(
91 | source=dataset['source'],
92 | target=dataset['target'],
93 | alignments=dataset['alignments'],
94 | )
95 |
96 | blank_indices = [1, 3, -1]
97 |
98 | for field in test_data:
99 | for idx in reversed(blank_indices):
100 | test_data[field].insert(idx, '')
101 |
102 | assert len(test_data['source']) == len(dataset['source']) + len(blank_indices)
103 | assert len(test_data['target']) == len(dataset['target']) + len(blank_indices)
104 |
105 | predictions = runner.predict(**test_data)
106 |
107 | assert len(predictions.target_tags_labels) == len(test_data['target'])
108 | assert len(predictions.target_tags_BAD_probabilities) == len(test_data['target'])
109 |
110 |
111 | def test_predict_with_all_empty_sentences(data_config, model_dir):
112 | load_model = model_dir / 'nuqe.ckpt'
113 | runner = load_system(load_model)
114 |
115 | # All empty
116 |     test_data = dict(source=[''] * 10, target=[''] * 10, alignments=[''] * 10)
117 | predictions = runner.predict(**test_data)
118 | assert all(prediction is None for prediction in vars(predictions).values())
119 |
120 |
121 | def test_predict_with_one_side_all_empty_sentence(data_config, model_dir):
122 | load_model = model_dir / 'nuqe.ckpt'
123 | runner = load_system(load_model)
124 |
125 | # One side all empty
126 |     test_data = dict(source=['AB'] * 10, target=[''] * 10, alignments=[''] * 10)
127 | with pytest.raises(ValueError, match='Received empty'):
128 | runner.predict(**test_data)
129 |
130 | # Other side
131 |     test_data = dict(source=[''] * 10, target=['AB'] * 10, alignments=[''] * 10)
132 | with pytest.raises(ValueError, match='Received empty'):
133 | runner.predict(**test_data)
134 |
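135 | # Summary of the blank-input contract exercised above: scattered empty lines
136 | # are tolerated and outputs keep the input length, an entirely empty batch
137 | # yields a predictions object whose fields are all None, and a batch that is
138 | # blank on only one side raises a ValueError matching 'Received empty'.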
--------------------------------------------------------------------------------
/tests/test_predictor.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | import numpy as np
18 | import pytest
19 | import yaml
20 |
21 | from kiwi.lib import pretrain
22 |
23 | predictor_yaml = """
24 | class_name: Predictor
25 |
26 | num_data_workers: 0
27 | batch_size:
28 | train: 32
29 | valid: 32
30 |
31 | model:
32 | encoder:
33 | hidden_size: 400
34 | rnn_layers: 2
35 | embeddings:
36 | source:
37 | dim: 200
38 | target:
39 | dim: 200
40 | out_embeddings_dim: 200
41 | share_embeddings: false
42 | dropout: 0.5
43 | use_mismatch_features: false
44 |
45 | optimizer:
46 | class_name: adam
47 | learning_rate: 0.001
48 | learning_rate_decay: 0.6
49 | learning_rate_decay_start: 2
50 |
51 | data_processing:
52 | vocab:
53 | min_frequency: 1
54 | max_size: 60_000
55 | """
56 |
57 |
58 | @pytest.fixture
59 | def predictor_config_dict():
60 | return yaml.unsafe_load(predictor_yaml)
61 |
62 |
63 | def test_pretrain_predictor(
64 | tmp_path, predictor_config_dict, pretrain_config, data_config, extra_big_atol
65 | ):
66 | pretrain_config['run']['output_dir'] = tmp_path
67 | pretrain_config['data'] = data_config
68 | pretrain_config['system'] = predictor_config_dict
69 |
70 | train_info = pretrain.pretrain_from_configuration(pretrain_config)
71 |
72 | stats = train_info.best_metrics
73 | np.testing.assert_allclose(stats['target_PERP'], 838.528486, atol=extra_big_atol)
74 | np.testing.assert_allclose(
75 | stats['val_target_PERP'], 501.516467, atol=extra_big_atol
76 | )
77 |
78 | # Testing predictor with pickled data
79 | pretrain_config['system']['load'] = train_info.best_model_path
80 |
81 | train_info_from_loaded = pretrain.pretrain_from_configuration(pretrain_config)
82 |
83 | stats = train_info_from_loaded.best_metrics
84 | np.testing.assert_allclose(
85 | stats['target_PERP'], 166.4964834955168, atol=extra_big_atol
86 | )
87 | np.testing.assert_allclose(
88 | stats['val_target_PERP'], 333.688725, atol=extra_big_atol
89 | )
90 |
91 |
92 | if __name__ == '__main__': # pragma: no cover
93 |
94 | pytest.main([__file__]) # pragma: no cover
95 |
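96 | # The flow above in short: `pretrain_from_configuration` consumes a dict with
97 | # `run`, `data`, and `system` sections, and pointing `system.load` at a saved
98 | # checkpoint makes the second run resume from the first model, which the lower
99 | # expected perplexities reflect.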
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | import pytest
18 | import torch
19 |
20 | from kiwi.utils.io import generate_slug, load_torch_file
21 |
22 |
23 | def test_load_torch_file(model_dir):
24 | load_torch_file(model_dir / 'nuqe.ckpt')
25 |     # There's no CUDA available in the test environment, so this check is disabled:
26 | # if not torch.cuda.is_available():
27 | # with pytest.raises(RuntimeError, match='No CUDA GPUs are available'):
28 | # load_torch_file(
29 | # model_dir / 'nuqe.ckpt',
30 | # map_location=lambda storage, loc: storage.cuda(0),
31 | # )
32 | # And this file does not exist:
33 | with pytest.raises(ValueError):
34 | load_torch_file(model_dir / 'nonexistent.torch')
35 |
36 |
37 | def test_generate_slug():
38 | assert generate_slug('Some Random Text!') == 'some-random-text'
39 |
--------------------------------------------------------------------------------
/tests/test_xlm.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | import pytest
18 | import yaml
19 | from transformers import XLMConfig, XLMModel, XLMTokenizer
20 |
21 | from conftest import check_computation
22 | from kiwi import constants as const
23 |
24 | xlm_yaml = """
25 | class_name: XLM
26 |
27 | num_data_workers: 0
28 | batch_size:
29 | train: 32
30 | valid: 32
31 |
32 | model:
33 | encoder:
34 | model_name: None
35 | interleave_input: false
36 | freeze: false
37 |
38 | decoder:
39 | hidden_size: 16
40 | dropout: 0.0
41 |
42 | outputs:
43 | word_level:
44 | target: true
45 | gaps: true
46 | source: true
47 | class_weights:
48 | target_tags:
49 | BAD: 3.0
50 | gap_tags:
51 | BAD: 5.0
52 | source_tags:
53 | BAD: 3.0
54 | sentence_level:
55 | hter: true
56 | use_distribution: true
57 | binary: false
58 | sentence_loss_weight: 1
59 |
60 | tlm_outputs:
61 | fine_tune: false
62 |
63 | optimizer:
64 | class_name: adamw
65 | learning_rate: 0.00001
66 | warmup_steps: 0.1
67 | training_steps: 12000
68 |
69 | data_processing:
70 | share_input_fields_encoders: true
71 |
72 | """
73 |
74 |
75 | @pytest.fixture
76 | def xlm_config_dict():
77 | return yaml.unsafe_load(xlm_yaml)
78 |
79 |
80 | @pytest.fixture(scope='session')
81 | def xlm_model_dir(model_dir):
82 | return model_dir.joinpath('xlm/')
83 |
84 |
85 | @pytest.fixture(scope='function')
86 | def xlm_model():
87 | config = XLMConfig(
88 | vocab_size=93000,
89 | emb_dim=32,
90 | n_layers=5,
91 | n_heads=4,
92 | dropout=0.1,
93 | max_position_embeddings=512,
94 | lang2id={
95 | "ar": 0,
96 | "bg": 1,
97 | "de": 2,
98 | "el": 3,
99 | "en": 4,
100 | "es": 5,
101 | "fr": 6,
102 | "hi": 7,
103 | "ru": 8,
104 | "sw": 9,
105 | "th": 10,
106 | "tr": 11,
107 | "ur": 12,
108 | "vi": 13,
109 | "zh": 14,
110 | },
111 | )
112 | return XLMModel(config=config)
113 |
114 |
115 | @pytest.fixture()
116 | def xlm_tokenizer():
117 | return XLMTokenizer.from_pretrained('xlm-mlm-tlm-xnli15-1024')
118 |
119 |
120 | def test_computation_target(
121 | tmp_path,
122 | xlm_model,
123 | xlm_tokenizer,
124 | xlm_model_dir,
125 | xlm_config_dict,
126 | train_config,
127 | data_config,
128 | big_atol,
129 | ):
130 | train_config['run']['output_dir'] = tmp_path
131 | train_config['data'] = data_config
132 | train_config['system'] = xlm_config_dict
133 |
134 | xlm_model.save_pretrained(tmp_path)
135 | xlm_tokenizer.save_pretrained(tmp_path)
136 | train_config['system']['model']['encoder']['model_name'] = str(tmp_path)
137 |
138 | check_computation(
139 | train_config,
140 | tmp_path,
141 | output_name=const.TARGET_TAGS,
142 | expected_avg_probs=0.410072,
143 | atol=big_atol,
144 | )
145 |
146 |
147 | if __name__ == '__main__': # pragma: no cover
148 | pytest.main([__file__]) # pragma: no cover
149 |
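150 | # Note on the pattern above: the fixtures build a randomly initialized,
151 | # deliberately tiny XLM (5 layers, emb_dim 32) and write it out with
152 | # `save_pretrained`, so `model_name` can point at a local directory and the
153 | # test exercises the full checkpoint-loading path without downloading real
154 | # model weights (only the tokenizer is fetched from the hub).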
--------------------------------------------------------------------------------
/tests/test_xlmr.py:
--------------------------------------------------------------------------------
1 | # OpenKiwi: Open-Source Machine Translation Quality Estimation
2 | # Copyright (C) 2020 Unbabel
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published
6 | # by the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | import pytest
18 | import yaml
19 | from transformers import XLMRobertaConfig, XLMRobertaModel
20 |
21 | from conftest import check_computation
22 | from kiwi import constants as const
23 |
24 | xlmr_yaml = """
25 | class_name: XLMRoberta
26 |
27 | num_data_workers: 0
28 | batch_size:
29 | train: 16
30 | valid: 16
31 |
32 | model:
33 | encoder:
34 | model_name: None
35 | interleave_input: false
36 | freeze: false
37 | pooling: mixed
38 |
39 | decoder:
40 | hidden_size: 16
41 | dropout: 0.0
42 |
43 | outputs:
44 | word_level:
45 | target: true
46 | gaps: true
47 | source: true
48 | class_weights:
49 | target_tags:
50 | BAD: 3.0
51 | gap_tags:
52 | BAD: 5.0
53 | source_tags:
54 | BAD: 3.0
55 | sentence_level:
56 | hter: true
57 | use_distribution: true
58 | binary: false
59 | sentence_loss_weight: 1
60 |
61 | tlm_outputs:
62 | fine_tune: false
63 |
64 | optimizer:
65 | class_name: adamw
66 | learning_rate: 0.00001
67 | warmup_steps: 0.1
68 |
69 | data_processing:
70 | share_input_fields_encoders: true
71 |
72 | """
73 |
74 |
75 | @pytest.fixture
76 | def xlmr_config_dict():
77 | return yaml.unsafe_load(xlmr_yaml)
78 |
79 |
80 | @pytest.fixture(scope='session')
81 | def xlmr_model_dir(model_dir):
82 | return model_dir.joinpath('xlmr/')
83 |
84 |
85 | @pytest.fixture(scope='function')
86 | def xlmr_model():
87 | config = XLMRobertaConfig(
88 | vocab_size=251000,
89 | hidden_size=32,
90 | num_hidden_layers=5,
91 | num_attention_heads=4,
92 | intermediate_size=37,
93 | hidden_act='gelu',
94 | hidden_dropout_prob=0.1,
95 | attention_probs_dropout_prob=0.1,
96 | max_position_embeddings=256,
97 | type_vocab_size=2,
98 | is_decoder=False,
99 | initializer_range=0.02,
100 | )
101 | return XLMRobertaModel(config=config)
102 |
103 |
104 | def test_computation_target(
105 | tmp_path,
106 | xlmr_model,
107 | xlmr_model_dir,
108 | xlmr_config_dict,
109 | train_config,
110 | data_config,
111 | big_atol,
112 | ):
113 | train_config['run']['output_dir'] = tmp_path
114 | train_config['data'] = data_config
115 | train_config['system'] = xlmr_config_dict
116 |
117 | xlmr_model.save_pretrained(tmp_path)
118 | train_config['system']['model']['encoder']['model_name'] = str(tmp_path)
119 |
120 |     # When using the `adamw` optimizer and `optimizer.training_steps` is not set:
121 | with pytest.raises(ValueError):
122 | check_computation(
123 | train_config,
124 | tmp_path,
125 | output_name=const.TARGET_TAGS,
126 | expected_avg_probs=0.383413,
127 | atol=big_atol,
128 | )
129 |
130 | # Now training will run:
131 | train_config['system']['optimizer']['training_steps'] = 10
132 | check_computation(
133 | train_config,
134 | tmp_path,
135 | output_name=const.TARGET_TAGS,
136 | expected_avg_probs=0.383413,
137 | atol=big_atol,
138 | )
139 |
140 |
141 | if __name__ == '__main__': # pragma: no cover
142 | pytest.main([__file__]) # pragma: no cover
143 |
--------------------------------------------------------------------------------
/tests/toy-data/WMT18/word_level/en_de.nmt/dev.hter:
--------------------------------------------------------------------------------
1 | 0.322581
2 | 0.000000
3 | 0.064516
4 | 0.263158
5 | 0.000000
6 | 0.312500
7 | 0.375000
8 | 0.727273
9 | 0.161290
10 | 0.000000
11 | 0.115385
12 | 0.192308
13 | 0.043478
14 | 0.153846
15 | 0.090909
16 | 0.187500
17 | 0.705882
18 | 0.000000
19 | 0.125000
20 | 0.388889
21 | 0.000000
22 | 0.045455
23 | 0.333333
24 | 0.722222
25 | 0.285714
26 | 0.620690
27 | 0.071429
28 | 0.000000
29 | 0.217391
30 | 0.000000
31 | 0.125000
32 | 0.000000
33 | 0.166667
34 | 0.000000
35 | 0.368421
36 | 0.076923
37 | 0.181818
38 | 0.000000
39 | 0.133333
40 | 0.052632
41 | 0.000000
42 | 0.600000
43 | 0.642857
44 | 0.125000
45 | 0.071429
46 | 0.083333
47 | 0.071429
48 | 0.000000
49 | 0.476190
50 | 0.000000
51 |
--------------------------------------------------------------------------------
/tests/toy-data/WMT18/word_level/en_de.nmt/dev.src:
--------------------------------------------------------------------------------
1 | to add or remove pixels when resizing so the image retains approximately the same appearance at a different size , select Resample Image .
2 | to update all assignments in the current document , choose Update All Assignments from the Assignments panel menu .
3 | in the Options tab , click the Custom button and enter lower values for Error Correction Level and Y / X Ratio .
4 | for example , you could create a document containing a car that moves across the Stage .
5 | in the New From Template dialog box , locate and select a template , and click New .
6 | make sure that you obtained the security settings file from a source that you trust .
7 | makes a rectangular selection ( or a square , when used with the Shift key ) .
8 | drag diagonally from the corner where you want the graph to begin to the opposite corner .
9 | enter a value from -100 % to 100 % to specify the percentage by which to decrease or increase the color or the spot-color tint .
10 | you can enable the Contribute publishing server using this dialog box .
11 | you can add any web page - not just pages in websites or entries in blogs that you 're connected to - to your bookmarks list .
12 | use the Export commands on the File menu to export all or part of an InDesign document to other formats .
13 | if appropriate , click Browse to navigate to the location in which you want the downloads to be placed .
14 | for Point / Pica Size , choose from the following options :
15 | to check the character code currently selected , select Shift JIS , JIS , Kuten , or Unicode , and display the code system .
16 | a digital signature , like a conventional handwritten signature , identifies the person signing a document .
17 | this option converts a complete 360 x 180 degree spherical panorama to a 3D layer .
18 | specifies that CSS styles appear in the Style menu .
19 | for example , suppose you want to update the content of a formatting table in a monthly magazine .
20 | in fact , most of the features you see in InDesign are provided by plug ‑ ins .
21 | the default is ^ t , which tells InDesign to insert a tab .
22 | select the new keyframe and drag one of the Learning Interaction movie clips from the Library panel to the Stage .
23 | sets the type of currency , such as Euros , Dollars , or Yen .
24 | enter Please indicate your level of satisfaction for the text parameter .
25 | substitutes the standard glyph with the jp78 ‑ variant glyph .
26 | resampling adds pixels to or subtracts pixels from a resized bitmap to match the appearance of the original bitmap as closely as possible .
27 | however , this feature can also be used to add malicious data into a PDF .
28 | to change the text size , click the Decrease Text Size button or the Increase Text Size button .
29 | please remember that existing artwork or images that you may want to include in your project may be protected under copyright law .
30 | when the Identity Setup dialog box appears , enter the appropriate information about yourself , and click Complete .
31 | the Properties toolbar is different in that it doesn 't contain tools and can 't be customized to hide options .
32 | for this reason , Adobe ® Flash ® Player includes a set of security rules and controls to safeguard the user , website owner , and content developer .
33 | does not have a menu bar .
34 | click the Start button and choose Settings > Printers And Faxes .
35 | it can also represent a spatial vector in physics , which has a direction and a magnitude .
36 | switch between the snapshots to find the settings you like best .
37 | then drag to adjust .
38 | choose Always from the Maximize PSD and PSB File Compatibility menu .
39 | signatures that certify an Adobe ® PDF are called certifying signatures .
40 | select an option under State , and then specify a label or icon option :
41 | to create a single-page document , choose File > Create PDF > From File .
42 | the file contains two parameters , monthNames and dayNames .
43 | contains information about the draft , including the blog title , the blog post title , and the associated tags .
44 | whether the profile is included or not is determined by the Profile Inclusion Policy .
45 | removing certain elements can seriously affect the functionality of the PDF .
46 | after the sound data starts loading , this code calls the snd.play ( ) method and stores the resulting SoundChannel object in the channel variable .
47 | after a web page completes its workflow or you have completed editing a blog , you can publish it to your website or blog from Contribute .
48 | select Window > Other Panels > Strings , and click Import XML .
49 | it performs as expected on paths that are oval , square , rectangular or otherwise irregularly shaped .
50 | press Enter or Return to begin a new paragraph .
51 |
--------------------------------------------------------------------------------
/tests/toy-data/WMT18/word_level/en_de.nmt/dev.src-mt.alignments:
--------------------------------------------------------------------------------
1 | 0-8 1-14 2-15 3-17 4-13 5-9 6-8 6-10 7-19 8-11 9-12 10-13 11-22 12-20 13-23 17-28 18-24 18-29 19-18 19-25 20-0 21-4 22-21 23-30
2 | 0-0 0-1 0-8 1-7 2-2 2-3 3-4 4-4 5-4 6-5 7-6 8-9 9-10 9-11 9-12 10-21 11-19 12-15 14-16 15-18 15-19 16-20 17-13 18-22
3 | 1-3 2-5 2-6 3-4 4-5 5-0 5-8 6-9 7-12 8-10 8-11 8-13 9-14 10-15 10-16 11-15 12-17 13-18 14-19 15-20 15-21 16-21 17-22 18-24 19-25 20-26 21-28 22-29
4 | 0-2 1-2 2-2 3-0 3-1 4-1 5-5 6-3 7-4 8-6 8-7 9-8 10-9 12-10 13-10 14-11 15-12 16-14
5 | 0-2 1-1 2-5 3-4 3-6 3-8 4-5 4-7 5-3 6-3 9-12 10-0 10-11 11-9 12-10 14-12 15-13 15-15 16-16 16-17 16-19 17-18
6 | 0-0 1-2 2-3 2-4 3-5 4-15 5-6 5-9 6-8 6-10 7-10 8-7 9-11 10-12 11-14 14-15 15-16
7 | 0-0 1-1 2-2 3-3 4-4 5-5 6-6 7-7 8-8 9-9 10-14 10-15 11-11 12-12 13-13 14-13 15-16 16-17
8 | 0-0 0-1 1-1 3-2 3-3 5-6 5-7 5-8 8-9 8-10 9-11 11-16 12-17 13-13 14-14 15-15 16-18
9 | 0-0 0-1 1-2 2-3 3-4 4-5 5-6 7-8 8-9 9-11 9-12 10-15 11-13 12-14 14-16 15-17 16-24 16-25 17-21 18-28 19-18 19-19 19-22 20-20 20-23 21-26 23-27 24-28 25-29
10 | 0-0 1-1 2-7 3-2 4-3 5-3 6-4 7-4 8-5 9-6 10-6 11-8
11 | 0-0 1-1 2-5 3-2 3-3 4-4 5-4 7-7 8-8 9-10 10-11 11-12 12-13 13-14 14-15 15-16 16-17 17-20 18-19 18-22 19-18 19-21 20-24 21-23 22-24 23-25 24-26 25-26 26-27
12 | 0-0 0-1 1-2 2-5 2-6 3-3 5-8 6-7 6-8 6-11 7-9 7-10 7-13 8-12 8-14 9-26 11-17 12-16 12-19 13-19 14-18 14-20 15-21 17-25 18-23 19-24 20-27
13 | 0-2 1-2 2-2 2-6 3-0 3-1 3-3 4-5 6-13 7-8 7-9 7-12 8-10 9-11 10-15 11-14 11-16 12-14 13-21 14-17 15-17 15-18 16-21 17-20 18-19 19-22
14 | 0-2 1-4 2-5 4-3 5-6 6-7 8-9 9-10 10-11 11-12
15 | 0-0 0-5 1-6 2-1 3-4 4-4 5-2 6-3 7-7 8-8 8-9 9-12 10-15 11-16 12-15 13-16 14-18 15-19 16-21 17-23 19-25 20-26 21-27 21-28 22-29 23-28 23-30 24-31
16 | 0-0 1-1 2-2 3-3 4-4 5-5 6-6 7-7 8-8 9-9 10-10 11-11 12-11 13-12 14-13 15-14 16-15
17 | 0-0 0-1 1-2 2-10 3-4 4-3 5-3 6-4 7-5 8-6 9-6 10-7 12-8 13-9 14-9 15-11
18 | 0-1 0-2 1-3 1-4 2-5 3-5 4-11 4-12 5-6 6-8 7-8 7-9 7-10 8-7 9-13
19 | 0-0 1-0 2-1 3-0 4-2 5-3 6-3 7-13 8-4 9-5 10-5 11-6 11-10 12-9 13-8 14-9 14-11 15-10 16-12 17-12 18-14
20 | 3-0 3-1 4-2 5-2 6-6 8-5 9-3 10-4 11-7 12-10 13-8 14-9 15-10 16-10 17-11
21 | 0-0 1-1 2-2 3-3 4-4 5-5 6-6 7-8 7-9 8-7 9-10 10-13 11-11 11-12 12-12 13-14
22 | 0-0 0-1 1-2 2-3 3-4 4-6 5-7 5-8 6-6 7-9 7-10 8-10 9-9 10-11 11-12 12-13 13-12 14-13 15-16 16-14 16-15 16-17 17-18 18-19 19-20 20-21
23 | 0-0 1-1 2-2 3-3 4-4 5-6 6-7 6-8 7-7 8-7 9-8 10-9 11-10 12-11 12-12 13-12 14-13
24 | 0-0 1-1 2-3 3-8 4-5 5-6 6-7 7-2 8-9 8-10 9-11 10-4 11-13
25 | 0-0 1-1 2-1 3-1 4-2 4-3 5-4 6-3 7-4 8-5 9-5 10-6
26 | 0-0 0-1 1-2 1-4 2-3 3-4 4-5 5-6 6-7 7-8 8-9 9-11 10-12 11-13 11-14 11-15 12-16 12-17 12-18 13-18 14-19 15-20 16-20 17-21 18-22 19-25 20-23 20-24 21-25 22-26 23-28
27 | 0-5 2-1 3-0 3-2 4-3 5-6 10-7 11-8 12-9 13-10 14-11 14-12 15-13
28 | 0-0 0-1 0-5 1-4 2-2 3-3 4-3 5-6 6-7 6-8 6-9 7-10 8-12 8-13 8-14 9-13 10-13 11-11 11-15 12-16 13-18 14-19 14-20 14-21 14-22 15-21 16-21 17-19 17-24 18-23
29 | 0-0 0-1 1-0 2-2 2-3 3-4 4-4 5-6 6-5 6-7 7-8 7-9 8-10 9-11 10-16 12-15 13-12 14-13 15-14 16-18 17-18 18-20 18-21 19-19 20-19 21-19 22-22
30 | 0-0 1-1 2-3 2-4 3-5 4-2 5-2 5-6 6-7 6-8 7-9 8-10 8-11 9-12 10-13 11-14 12-15 15-16 16-17 16-18 16-19 17-20 17-21 17-23 18-22
31 | 0-0 1-3 2-1 3-5 4-5 5-6 6-7 6-8 6-9 7-10 8-11 9-11 10-12 10-13 11-12 12-14 14-15 15-17 15-18 16-16 17-19 17-20 18-22 19-21 20-23
32 | 0-2 1-1 2-0 2-2 4-4 5-5 6-6 7-7 8-8 9-3 10-9 11-10 12-11 13-12 14-15 15-13 16-14 18-19 19-18 19-21 20-19 21-20 22-27 25-23 26-28 27-25 28-30
33 | 0-0 0-1 1-1 2-1 4-2 5-2 6-4
34 | 0-0 0-1 0-2 1-2 2-4 3-3 3-5 4-6 5-7 5-8 5-9 6-10 7-11 7-12 7-13 8-13 8-14 9-15 10-16 11-17
35 | 0-0 1-1 2-2 3-3 4-3 5-4 6-5 7-6 8-8 8-9 9-10 10-11 11-17 12-12 13-13 14-14 15-15 16-16 17-18
36 | 0-0 1-2 2-1 2-3 3-4 4-5 4-6 5-11 6-7 7-8 7-9 8-8 9-10 10-11 11-12
37 | 0-1 0-2 1-0 1-4 2-3 2-5 2-6 3-7 3-8 3-9 4-10
38 | 0-0 0-1 1-15 2-6 3-2 4-7 5-7 6-8 6-9 8-11 9-5 10-3 10-17 11-16
39 | 0-0 1-1 1-2 2-3 3-3 4-4 5-5 6-6 7-8 7-9 8-10 8-13 9-11 9-12 10-12 11-14
40 | 0-0 0-1 1-6 2-7 3-2 4-3 4-4 4-5 6-9 7-12 8-10 8-11 9-13 10-14 11-15 12-15 12-16 14-18
41 | 0-0 1-1 2-2 3-3 4-4 6-9 7-8 8-7 8-9 8-10 8-11 9-13 10-12 11-14 11-15 11-16 12-17 13-17 13-18 14-19
42 | 0-0 1-1 2-2 3-3 4-4 5-5 6-9 7-9 8-11 9-12
43 | 0-1 1-2 2-2 2-3 3-3 4-4 5-5 6-6 7-8 9-9 10-10 11-12 12-13 13-12 14-11 14-13 15-10 16-14 17-15 18-16 19-17 20-18
44 | 0-0 1-1 2-2 3-4 4-3 5-5 6-6 7-8 8-9 9-9 10-10 11-11 11-12 11-13 12-14 13-15 14-16
45 | 0-0 0-1 1-2 2-3 3-4 4-4 4-5 5-6 5-7 5-12 6-8 7-9 8-10 9-10 9-11 10-11 11-13
46 | 0-0 0-1 1-1 1-3 2-4 3-3 4-4 4-5 4-6 5-2 7-7 8-8 9-14 10-9 11-10 12-11 13-12 14-13 15-15 16-22 17-16 18-17 19-17 20-18 21-19 22-20 23-21 24-22 25-23
47 | 0-0 1-1 2-2 3-2 4-5 5-5 6-4 7-6 9-12 10-11 11-8 12-9 13-10 14-13 15-15 16-14 17-25 18-16 20-18 20-21 21-19 22-20 23-22 24-23 25-24 26-26
48 | 0-0 0-1 1-2 1-3 2-4 2-5 2-6 3-7 4-8 5-9 5-10 5-11 6-11 6-12 7-13 8-14 9-15 9-16 9-17 10-18 10-20 11-19 12-21
49 | 0-0 1-1 2-2 3-3 4-4 5-5 6-7 6-8 7-17 8-9 9-10 10-11 11-12 12-13 13-14 14-15 15-16 16-17 17-18
50 | 0-0 0-1 1-2 1-3 3-4 4-5 4-9 5-10 6-6 7-7 8-8 9-11
51 |
--------------------------------------------------------------------------------
/tests/toy-data/WMT18/word_level/en_de.nmt/dev.src_tags:
--------------------------------------------------------------------------------
1 | OK OK OK OK OK OK OK OK OK OK BAD OK BAD OK BAD BAD BAD BAD BAD OK OK OK OK OK
2 | OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK
3 | OK OK OK OK OK OK OK OK OK OK OK BAD BAD OK OK OK OK OK OK OK OK OK OK
4 | OK OK OK OK OK OK OK OK OK BAD BAD BAD BAD OK OK OK OK
5 | OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK
6 | OK OK OK OK BAD BAD OK OK OK BAD OK OK OK OK BAD OK
7 | OK OK OK OK OK OK OK BAD OK OK BAD OK BAD BAD BAD OK OK
8 | OK BAD BAD BAD BAD BAD OK OK OK BAD BAD BAD OK BAD BAD OK OK
9 | OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK BAD OK BAD OK OK OK OK BAD BAD OK
10 | OK OK OK OK OK OK OK OK OK OK OK OK
11 | OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK BAD BAD BAD OK
12 | OK OK OK OK OK OK OK OK OK OK BAD BAD BAD BAD OK OK OK OK OK OK OK
13 | OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK BAD OK OK OK OK
14 | OK BAD OK OK OK BAD OK OK OK OK OK OK
15 | OK OK OK OK OK OK OK OK OK BAD OK OK OK OK OK BAD OK OK OK OK OK OK BAD OK OK
16 | OK OK OK OK OK OK OK BAD OK OK BAD OK OK OK OK OK OK
17 | BAD OK BAD OK BAD BAD BAD BAD BAD BAD BAD BAD OK OK OK OK
18 | OK OK OK OK OK OK OK OK OK OK
19 | OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK BAD OK OK
20 | OK BAD BAD OK OK OK OK OK OK OK OK BAD OK OK OK OK BAD OK
21 | OK OK OK OK OK OK OK OK OK OK OK OK OK OK
22 | OK OK OK OK OK OK BAD OK OK BAD OK OK OK OK OK OK OK OK OK OK OK
23 | OK OK OK OK BAD OK OK OK BAD OK BAD OK OK BAD BAD
24 | OK BAD BAD BAD BAD BAD BAD OK OK OK OK OK
25 | OK OK OK BAD BAD OK OK OK BAD BAD OK
26 | BAD BAD OK BAD BAD BAD BAD OK BAD BAD BAD BAD BAD OK OK OK OK OK OK OK OK OK OK OK
27 | OK OK OK OK OK OK OK OK OK OK BAD OK OK OK OK OK
28 | OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK
29 | OK OK OK OK BAD OK OK BAD OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK
30 | OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK
31 | OK OK OK BAD BAD BAD BAD OK OK OK OK OK OK OK OK OK OK OK OK OK OK
32 | OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK
33 | BAD OK OK OK OK OK OK
34 | OK OK OK OK OK OK OK OK OK OK OK OK
35 | BAD OK OK BAD BAD BAD OK BAD BAD OK OK OK OK OK OK OK OK OK
36 | OK OK OK BAD OK OK OK OK OK OK OK OK
37 | OK OK BAD BAD OK
38 | OK OK OK OK OK OK OK OK OK OK OK OK
39 | OK OK OK OK OK OK OK OK OK BAD OK OK
40 | OK OK OK OK OK OK OK OK OK OK OK OK BAD OK OK
41 | OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK
42 | OK OK OK OK OK OK BAD OK BAD OK
43 | OK OK OK OK OK OK BAD OK BAD BAD OK OK OK BAD OK OK OK BAD OK OK OK
44 | OK OK OK OK OK OK OK OK OK OK OK BAD OK OK OK
45 | OK OK OK OK OK OK OK BAD OK OK OK OK
46 | OK OK OK OK OK OK OK OK OK OK OK BAD OK OK OK OK OK OK OK OK OK OK OK BAD OK OK
47 | OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK OK BAD OK OK OK OK OK OK OK
48 | OK OK OK OK OK OK OK OK OK OK OK OK OK
49 | BAD BAD BAD OK OK OK OK OK BAD OK BAD OK BAD OK BAD BAD BAD OK
50 | OK OK OK OK OK OK OK OK OK OK
51 |
--------------------------------------------------------------------------------
/tests/toy-data/WMT18/word_level/en_de.nmt/train.hter:
--------------------------------------------------------------------------------
1 | 0.153846
2 | 0.230769
3 | 0.176471
4 | 0.000000
5 | 0.000000
6 | 0.083333
7 | 0.272727
8 | 0.000000
9 | 0.000000
10 | 0.142857
11 | 0.272727
12 | 0.600000
13 | 0.062500
14 | 0.636364
15 | 0.416667
16 | 0.041667
17 | 0.117647
18 | 0.000000
19 | 0.000000
20 | 0.111111
21 | 0.000000
22 | 0.000000
23 | 0.500000
24 | 0.041667
25 | 0.000000
26 | 0.052632
27 | 0.187500
28 | 0.705882
29 | 0.086957
30 | 0.117647
31 | 0.000000
32 | 0.000000
33 | 0.090909
34 | 0.090909
35 | 0.000000
36 | 0.080000
37 | 0.038462
38 | 0.000000
39 | 0.111111
40 | 0.038462
41 | 0.000000
42 | 0.000000
43 | 0.125000
44 | 0.150000
45 | 0.066667
46 | 0.047619
47 | 0.230769
48 | 0.333333
49 | 0.083333
50 | 0.000000
51 | 0.250000
52 | 0.052632
53 | 0.040000
54 | 0.000000
55 | 0.035714
56 | 0.066667
57 | 0.100000
58 | 0.733333
59 | 0.684211
60 | 0.812500
61 | 0.208333
62 | 0.090909
63 | 0.000000
64 | 0.294118
65 | 0.062500
66 | 0.285714
67 | 0.064516
68 | 0.000000
69 | 0.000000
70 | 0.285714
71 | 0.333333
72 | 0.739130
73 | 0.045455
74 | 0.250000
75 | 0.866667
76 | 0.045455
77 | 0.333333
78 | 0.080000
79 | 0.058824
80 | 0.083333
81 | 0.416667
82 | 0.000000
83 | 0.114286
84 | 0.000000
85 | 0.391304
86 | 0.000000
87 | 0.047619
88 | 0.043478
89 | 0.000000
90 | 0.000000
91 | 0.071429
92 | 0.000000
93 | 0.294118
94 | 0.038462
95 | 0.000000
96 | 0.200000
97 | 0.000000
98 | 0.100000
99 | 0.312500
100 | 0.210526
101 |
--------------------------------------------------------------------------------
/tests/toy-data/models/bert/vocab.txt:
--------------------------------------------------------------------------------
1 | [PAD]
2 | [UNK]
3 | [CLS]
4 | [SEP]
5 | [MASK]
6 |
7 |
8 | .
9 | ,
10 | the
11 | "
12 | sie
13 | die
14 | ##en
15 | to
16 | in
17 | -
18 | der
19 | ##s
20 | und
21 | and
22 | ##n
23 | )
24 | auf
25 | an
26 | (
27 | aus
28 | ##e
29 | you
30 | text
31 | a
32 | den
33 | of
34 | ##t
35 | or
36 | c
37 | for
38 | das
39 | mit
40 | im
41 | wird
42 | wa
43 | wenn
44 | zu
45 | option
46 | ze
47 | oder
48 | >
49 | um
50 | kl
51 | ##lick
52 | s
53 | ##er
54 | ein
55 | ##ick
56 | von
57 | ##be
58 | :
59 | eine
60 | ##ien
61 | is
62 | far
63 | fur
64 | ##hlen
65 | des
66 | werden
67 | ##wahl
68 | be
69 | sp
70 | ##te
71 | date
72 | konnen
73 | ##feld
74 | can
75 | select
76 | ##d
77 | ##ge
78 | menu
79 | ##fuge
80 | hinzu
81 | method
82 | ist
83 | ##f
84 | ##far
85 | ##l
86 | color
87 | that
88 | if
89 | from
90 | ##ten
91 | bil
92 | ang
93 | seit
94 | ##ieren
95 | ##zei
96 | ge
97 | links
98 | with
99 | selected
100 | add
101 | g
102 | ##ben
103 | ##ichen
104 | ##lach
105 | mocht
106 | format
107 | ##p
108 |
--------------------------------------------------------------------------------
/tests/toy-data/models/nuqe.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Unbabel/OpenKiwi/07d7cfed880457d95daf189dd8282e5b02ac2954/tests/toy-data/models/nuqe.ckpt
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | # tox (https://tox.readthedocs.io/) is a tool for running tests
2 | # in multiple virtualenvs. This configuration file will run the
3 | # test suite on all supported python versions. To use it, "pip install tox"
4 | # and then run "tox" from this directory.
5 |
6 | [tox]
7 | isolated_build = true
8 | skip_missing_interpreters = true
9 | parallel_show_output = true
10 | envlist = lint,py37,docs
11 |
12 | [testenv]
13 | whitelist_externals =
14 | poetry
15 | skip_install = true
16 | setenv =
17 | PYTHONHASHSEED=0
18 | PYTHONPATH={toxinidir}
19 | ;commands =
20 | ; poetry install -v
21 | ; {[testenv:test]commands}
22 |
23 | [testenv:test]
24 | skip_install = true
25 | usedevelop = true
26 | envdir = {toxworkdir}/py37
27 | commands =
28 | poetry install -v -E search -E mlflow
29 | poetry run coverage erase
30 | poetry run pytest --cov=kiwi --cov-report term --cov-report xml --cov-append {posargs:tests}
31 |
32 | [testenv:lint]
33 | skip_install = true
34 | usedevelop = true
35 | envdir = {toxworkdir}/py37
36 | commands =
37 | poetry install -v
38 | poetry run flake8 {posargs:kiwi}
39 | poetry run black --check {posargs:kiwi}
40 | poetry run isort --check-only --diff --recursive {posargs:kiwi}
41 |
42 | [testenv:py37]
43 | commands =
44 | poetry install -v
45 | {[testenv:test]commands}
46 |
47 | [testenv:docs]
48 | skip_install = true
49 | usedevelop = true
50 | ;envdir = {toxworkdir}/docs
51 | envdir = {toxworkdir}/py37
52 | changedir = {toxinidir}/docs
53 | commands =
54 | poetry install -v
55 | ; poetry run sphinx-apidoc -f -o source {toxinidir}/kiwi
56 | poetry run sphinx-build -b html -d ./.doctrees . ../public
57 |
58 |
59 | [testenv:gh-pages]
60 | skip_install = true
61 | usedevelop = true
62 | envdir = {toxworkdir}/py37
63 | commands =
64 | poetry install -v
65 | ; poetry run sphinx-apidoc -f -o docs/source {toxinidir}/kiwi
66 | poetry run sphinx-build -b html -d docs/.doctrees docs gh-pages
67 |
68 | # Other packages config
69 |
70 | [flake8]
71 | max_line_length = 88
72 | select = C,E,F,W,B,B950
73 | ignore = W503,E203
74 |
75 | [pytest]
76 | python_files =
77 | test_*.py
78 | *_test.py
79 | tests.py
80 | norecursedirs =
81 | .git
82 | .tox
83 | .env
84 | dist
85 | build
86 |
87 | [coverage:run]
88 | branch = true
89 | parallel = true
90 | omit =
91 | kiwi/__main__.py
92 |
93 | [coverage:report]
94 | exclude_lines =
95 | pragma: no cover
96 | if __name__ == .__main__.:
97 |
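98 | # Usage notes: "tox -e lint" runs only the lint checks, and arguments after
99 | # "--" are forwarded to the underlying command via {posargs}, e.g.
100 | # "tox -e test -- tests/test_nuqe.py" runs pytest on a single test file.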
--------------------------------------------------------------------------------