├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── PATENTS ├── README.md ├── docs ├── Makefile ├── _static │ └── theme_overrides.css ├── command_line_tools.rst ├── conf.py ├── criterions.rst ├── data.rst ├── docutils.conf ├── getting_started.rst ├── index.rst ├── lr_scheduler.rst ├── make.bat ├── models.rst ├── modules.rst ├── optim.rst ├── overview.rst ├── requirements.txt ├── tasks.rst ├── tutorial_classifying_names.rst └── tutorial_simple_lstm.rst ├── eval_lm.py ├── examples ├── .gitignore ├── backtranslation │ └── README.md ├── conv_seq2seq │ └── README.md ├── cross_lingual_language_model │ └── README.md ├── language_model │ ├── README.md │ ├── conv_lm │ │ └── README.md │ ├── prepare-wikitext-103.sh │ └── transformer_lm │ │ └── README.md ├── pay_less_attention_paper │ └── README.md ├── scaling_nmt │ └── README.md ├── stories │ └── README.md ├── translation │ ├── README.md │ ├── prepare-iwslt14.sh │ ├── prepare-iwslt17-multilingual.sh │ ├── prepare-wmt14en2de.sh │ └── prepare-wmt14en2fr.sh ├── translation_moe │ ├── README.md │ └── score.py └── wav2vec │ └── README.md ├── fairseq.gif ├── fairseq ├── __init__.py ├── binarizer.py ├── bleu.py ├── checkpoint_utils.py ├── clib │ └── libbleu │ │ ├── libbleu.cpp │ │ └── module.cpp ├── criterions │ ├── __init__.py │ ├── adaptive_loss.py │ ├── binary_cross_entropy.py │ ├── composite_loss.py │ ├── cross_entropy.py │ ├── fairseq_criterion.py │ ├── label_smoothed_cross_entropy.py │ └── masked_lm_loss.py ├── data │ ├── __init__.py │ ├── audio │ │ ├── __init__.py │ │ └── raw_audio_dataset.py │ ├── backtranslation_dataset.py │ ├── block_pair_dataset.py │ ├── concat_dataset.py │ ├── data_utils.py │ ├── dictionary.py │ ├── fairseq_dataset.py │ ├── indexed_dataset.py │ ├── iterators.py │ ├── language_pair_dataset.py │ ├── lm_context_window_dataset.py │ ├── masked_lm_dataset.py │ ├── masked_lm_dictionary.py │ ├── monolingual_dataset.py │ ├── multi_corpus_sampled_dataset.py │ ├── noising.py │ ├── round_robin_zip_datasets.py │ ├── token_block_dataset.py │ ├── transform_eos_dataset.py │ ├── transform_eos_lang_pair_dataset.py │ └── transforms │ │ ├── __init__.py │ │ ├── gpt2_bpe.py │ │ ├── moses_tokenizer.py │ │ ├── nltk_tokenizer.py │ │ ├── sentencepiece_bpe.py │ │ ├── space_tokenizer.py │ │ └── subword_nmt_bpe.py ├── distributed_utils.py ├── file_utils.py ├── hub_utils.py ├── legacy_distributed_data_parallel.py ├── meters.py ├── models │ ├── __init__.py │ ├── composite_encoder.py │ ├── distributed_fairseq_model.py │ ├── fairseq_decoder.py │ ├── fairseq_encoder.py │ ├── fairseq_incremental_decoder.py │ ├── fairseq_model.py │ ├── fconv.py │ ├── fconv_lm.py │ ├── fconv_self_att.py │ ├── lightconv.py │ ├── lightconv_lm.py │ ├── lstm.py │ ├── masked_lm.py │ ├── multilingual_transformer.py │ ├── transformer.py │ ├── transformer_from_pretrained_xlm.py │ ├── transformer_lm.py │ └── wav2vec.py ├── modules │ ├── __init__.py │ ├── adaptive_input.py │ ├── adaptive_softmax.py │ ├── beamable_mm.py │ ├── character_token_embedder.py │ ├── conv_tbc.py │ ├── downsampled_multihead_attention.py │ ├── dynamic_convolution.py │ ├── gelu.py │ ├── grad_multiply.py │ ├── highway.py │ ├── layer_norm.py │ ├── learned_positional_embedding.py │ ├── lightweight_convolution.py │ ├── linearized_convolution.py │ ├── logsumexp_moe.py │ ├── mean_pool_gating_network.py │ ├── multihead_attention.py │ ├── positional_embedding.py │ ├── scalar_bias.py │ ├── sinusoidal_positional_embedding.py │ ├── transformer_sentence_encoder.py │ ├── transformer_sentence_encoder_layer.py │ 
└── unfold.py ├── optim │ ├── __init__.py │ ├── adadelta.py │ ├── adafactor.py │ ├── adagrad.py │ ├── adam.py │ ├── bmuf.py │ ├── fairseq_optimizer.py │ ├── fp16_optimizer.py │ ├── lamb.py │ ├── lr_scheduler │ │ ├── __init__.py │ │ ├── cosine_lr_scheduler.py │ │ ├── fairseq_lr_scheduler.py │ │ ├── fixed_schedule.py │ │ ├── inverse_square_root_schedule.py │ │ ├── polynomial_decay_schedule.py │ │ ├── reduce_lr_on_plateau.py │ │ └── triangular_lr_scheduler.py │ ├── nag.py │ └── sgd.py ├── options.py ├── pdb.py ├── progress_bar.py ├── registry.py ├── search.py ├── sequence_generator.py ├── sequence_scorer.py ├── tasks │ ├── __init__.py │ ├── audio_pretraining.py │ ├── cross_lingual_lm.py │ ├── fairseq_task.py │ ├── language_modeling.py │ ├── masked_lm.py │ ├── multilingual_translation.py │ ├── semisupervised_translation.py │ ├── translation.py │ ├── translation_from_pretrained_xlm.py │ └── translation_moe.py ├── tokenizer.py ├── trainer.py └── utils.py ├── fairseq_cli ├── __init__.py ├── eval_lm.py ├── generate.py ├── interactive.py ├── preprocess.py ├── score.py ├── setup.py └── train.py ├── fairseq_logo.png ├── generate.py ├── hubconf.py ├── interactive.py ├── preprocess.py ├── score.py ├── scripts ├── __init__.py ├── average_checkpoints.py ├── build_sym_alignment.py ├── compare_namespaces.py ├── compound_split_bleu.sh ├── convert_dictionary.lua ├── convert_model.lua ├── count_docs.py ├── read_binarized.py ├── rm_pt.py ├── sacrebleu_pregen.sh ├── shard_docs.py ├── split_train_valid_docs.py ├── spm_decode.py ├── spm_encode.py ├── spm_train.py ├── wav2vec_featurize.py └── wav2vec_manifest.py ├── setup.py ├── tests ├── __init__.py ├── test_average_checkpoints.py ├── test_backtranslation_dataset.py ├── test_binaries.py ├── test_character_token_embedder.py ├── test_concat_dataset.py ├── test_convtbc.py ├── test_dictionary.py ├── test_iterators.py ├── test_label_smoothing.py ├── test_memory_efficient_fp16.py ├── test_multi_corpus_sampled_dataset.py ├── test_noising.py ├── test_reproducibility.py ├── test_sequence_generator.py ├── test_sequence_scorer.py ├── test_token_block_dataset.py ├── test_train.py ├── test_utils.py └── utils.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # JetBrains PyCharm IDE 2 | .idea/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # macOS dir files 13 | .DS_Store 14 | 15 | # Distribution / packaging 16 | .Python 17 | env/ 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | 34 | # Checkpoints 35 | checkpoints 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | .hypothesis/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # SageMath parsed files 89 | *.sage.py 90 | 91 | # dotenv 92 | .env 93 | 94 | # virtualenv 95 | .venv 96 | venv/ 97 | ENV/ 98 | 99 | # Spyder project settings 100 | .spyderproject 101 | .spyproject 102 | 103 | # Rope project settings 104 | .ropeproject 105 | 106 | # mkdocs documentation 107 | /site 108 | 109 | # mypy 110 | .mypy_cache/ 111 | 112 | # Generated files 113 | fairseq/temporal_convolution_tbc 114 | 115 | # data 116 | data-bin/ 117 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. Please [read the full text](https://code.fb.com/codeofconduct) so that you can understand what actions will and will not be tolerated. -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to FAIR Sequence-to-Sequence Toolkit (PyTorch) 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | ## Coding Style 26 | We try to follow the PEP style guidelines and encourage you to as well. 27 | 28 | ## License 29 | By contributing to FAIR Sequence-to-Sequence Toolkit, you agree that your contributions will be licensed 30 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For fairseq software 4 | 5 | Copyright (c) 2017-present, Facebook, Inc. All rights reserved. 
6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /PATENTS: -------------------------------------------------------------------------------- 1 | Additional Grant of Patent Rights Version 2 2 | 3 | "Software" means the fairseq software distributed by Facebook, Inc. 4 | 5 | Facebook, Inc. ("Facebook") hereby grants to each recipient of the Software 6 | ("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable 7 | (subject to the termination provision below) license under any Necessary 8 | Claims, to make, have made, use, sell, offer to sell, import, and otherwise 9 | transfer the Software. For avoidance of doubt, no license is granted under 10 | Facebook’s rights in any patent claims that are infringed by (i) modifications 11 | to the Software made by you or any third party or (ii) the Software in 12 | combination with any software or other technology. 13 | 14 | The license granted hereunder will terminate, automatically and without notice, 15 | if you (or any of your subsidiaries, corporate affiliates or agents) initiate 16 | directly or indirectly, or take a direct financial interest in, any Patent 17 | Assertion: (i) against Facebook or any of its subsidiaries or corporate 18 | affiliates, (ii) against any party if such Patent Assertion arises in whole or 19 | in part from any software, technology, product or service of Facebook or any of 20 | its subsidiaries or corporate affiliates, or (iii) against any party relating 21 | to the Software. Notwithstanding the foregoing, if Facebook or any of its 22 | subsidiaries or corporate affiliates files a lawsuit alleging patent 23 | infringement against you in the first instance, and you respond by filing a 24 | patent infringement counterclaim in that lawsuit against that party that is 25 | unrelated to the Software, the license granted hereunder will not terminate 26 | under section (i) of this paragraph due to such counterclaim. 
27 | 28 | A "Necessary Claim" is a claim of a patent owned by Facebook that is 29 | necessarily infringed by the Software standing alone. 30 | 31 | A "Patent Assertion" is any lawsuit or other action alleging direct, indirect, 32 | or contributory infringement or inducement to infringe any patent, including a 33 | cross-claim or counterclaim. 34 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = python -msphinx 7 | SPHINXPROJ = fairseq 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/_static/theme_overrides.css: -------------------------------------------------------------------------------- 1 | .wy-table-responsive table td kbd { 2 | white-space: nowrap; 3 | } 4 | .wy-table-responsive table td { 5 | white-space: normal !important; 6 | } 7 | .wy-table-responsive { 8 | overflow: visible !important; 9 | } 10 | -------------------------------------------------------------------------------- /docs/command_line_tools.rst: -------------------------------------------------------------------------------- 1 | .. _Command-line Tools: 2 | 3 | Command-line Tools 4 | ================== 5 | 6 | Fairseq provides several command-line tools for training and evaluating models: 7 | 8 | - :ref:`fairseq-preprocess`: Data pre-processing: build vocabularies and binarize training data 9 | - :ref:`fairseq-train`: Train a new model on one or multiple GPUs 10 | - :ref:`fairseq-generate`: Translate pre-processed data with a trained model 11 | - :ref:`fairseq-interactive`: Translate raw text with a trained model 12 | - :ref:`fairseq-score`: BLEU scoring of generated translations against reference translations 13 | - :ref:`fairseq-eval-lm`: Language model evaluation 14 | 15 | 16 | .. _fairseq-preprocess: 17 | 18 | fairseq-preprocess 19 | ~~~~~~~~~~~~~~~~~~ 20 | .. automodule:: preprocess 21 | 22 | .. argparse:: 23 | :module: fairseq.options 24 | :func: get_preprocessing_parser 25 | :prog: fairseq-preprocess 26 | 27 | 28 | .. _fairseq-train: 29 | 30 | fairseq-train 31 | ~~~~~~~~~~~~~ 32 | .. automodule:: train 33 | 34 | .. argparse:: 35 | :module: fairseq.options 36 | :func: get_training_parser 37 | :prog: fairseq-train 38 | 39 | 40 | .. _fairseq-generate: 41 | 42 | fairseq-generate 43 | ~~~~~~~~~~~~~~~~ 44 | .. automodule:: generate 45 | 46 | .. argparse:: 47 | :module: fairseq.options 48 | :func: get_generation_parser 49 | :prog: fairseq-generate 50 | 51 | 52 | .. _fairseq-interactive: 53 | 54 | fairseq-interactive 55 | ~~~~~~~~~~~~~~~~~~~ 56 | .. automodule:: interactive 57 | 58 | .. argparse:: 59 | :module: fairseq.options 60 | :func: get_interactive_generation_parser 61 | :prog: fairseq-interactive 62 | 63 | 64 | .. _fairseq-score: 65 | 66 | fairseq-score 67 | ~~~~~~~~~~~~~ 68 | .. automodule:: score 69 | 70 | .. 
argparse:: 71 | :module: fairseq_cli.score 72 | :func: get_parser 73 | :prog: fairseq-score 74 | 75 | 76 | .. _fairseq-eval-lm: 77 | 78 | fairseq-eval-lm 79 | ~~~~~~~~~~~~~~~ 80 | .. automodule:: eval_lm 81 | 82 | .. argparse:: 83 | :module: fairseq.options 84 | :func: get_eval_lm_parser 85 | :prog: fairseq-eval-lm 86 | -------------------------------------------------------------------------------- /docs/criterions.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | .. _Criterions: 5 | 6 | Criterions 7 | ========== 8 | 9 | Criterions compute the loss function given the model and batch, roughly:: 10 | 11 | loss = criterion(model, batch) 12 | 13 | .. automodule:: fairseq.criterions 14 | :members: 15 | 16 | .. autoclass:: fairseq.criterions.FairseqCriterion 17 | :members: 18 | :undoc-members: 19 | 20 | .. autoclass:: fairseq.criterions.adaptive_loss.AdaptiveLoss 21 | :members: 22 | :undoc-members: 23 | .. autoclass:: fairseq.criterions.composite_loss.CompositeLoss 24 | :members: 25 | :undoc-members: 26 | .. autoclass:: fairseq.criterions.cross_entropy.CrossEntropyCriterion 27 | :members: 28 | :undoc-members: 29 | .. autoclass:: fairseq.criterions.label_smoothed_cross_entropy.LabelSmoothedCrossEntropyCriterion 30 | :members: 31 | :undoc-members: 32 | -------------------------------------------------------------------------------- /docs/data.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | .. module:: fairseq.data 5 | 6 | Data Loading and Utilities 7 | ========================== 8 | 9 | .. _datasets: 10 | 11 | Datasets 12 | -------- 13 | 14 | **Datasets** define the data format and provide helpers for creating 15 | mini-batches. 16 | 17 | .. autoclass:: fairseq.data.FairseqDataset 18 | :members: 19 | .. autoclass:: fairseq.data.LanguagePairDataset 20 | :members: 21 | .. autoclass:: fairseq.data.MonolingualDataset 22 | :members: 23 | 24 | **Helper Datasets** 25 | 26 | These datasets wrap other :class:`fairseq.data.FairseqDataset` instances and 27 | provide additional functionality: 28 | 29 | .. autoclass:: fairseq.data.BacktranslationDataset 30 | :members: 31 | .. autoclass:: fairseq.data.ConcatDataset 32 | :members: 33 | .. autoclass:: fairseq.data.RoundRobinZipDatasets 34 | :members: 35 | .. autoclass:: fairseq.data.TransformEosDataset 36 | :members: 37 | 38 | 39 | Dictionary 40 | ---------- 41 | 42 | .. autoclass:: fairseq.data.Dictionary 43 | :members: 44 | 45 | 46 | Iterators 47 | --------- 48 | 49 | .. autoclass:: fairseq.data.CountingIterator 50 | :members: 51 | .. autoclass:: fairseq.data.EpochBatchIterator 52 | :members: 53 | .. autoclass:: fairseq.data.GroupedIterator 54 | :members: 55 | .. autoclass:: fairseq.data.ShardedIterator 56 | :members: 57 | -------------------------------------------------------------------------------- /docs/docutils.conf: -------------------------------------------------------------------------------- 1 | [writers] 2 | option-limit=0 3 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. fairseq documentation master file, created by 2 | sphinx-quickstart on Fri Aug 17 21:45:30 2018. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 
5 | 6 | :github_url: https://github.com/pytorch/fairseq 7 | 8 | 9 | fairseq documentation 10 | ===================== 11 | 12 | Fairseq is a sequence modeling toolkit written in `PyTorch 13 | `_ that allows researchers and developers to 14 | train custom models for translation, summarization, language modeling and other 15 | text generation tasks. 16 | 17 | .. toctree:: 18 | :maxdepth: 1 19 | :caption: Getting Started 20 | 21 | getting_started 22 | command_line_tools 23 | 24 | .. toctree:: 25 | :maxdepth: 1 26 | :caption: Extending Fairseq 27 | 28 | overview 29 | tutorial_simple_lstm 30 | tutorial_classifying_names 31 | 32 | .. toctree:: 33 | :maxdepth: 2 34 | :caption: Library Reference 35 | 36 | tasks 37 | models 38 | criterions 39 | optim 40 | lr_scheduler 41 | data 42 | modules 43 | 44 | 45 | Indices and tables 46 | ================== 47 | 48 | * :ref:`genindex` 49 | * :ref:`search` 50 | -------------------------------------------------------------------------------- /docs/lr_scheduler.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | .. _Learning Rate Schedulers: 5 | 6 | Learning Rate Schedulers 7 | ======================== 8 | 9 | Learning Rate Schedulers update the learning rate over the course of training. 10 | Learning rates can be updated after each update via :func:`step_update` or at 11 | epoch boundaries via :func:`step`. 12 | 13 | .. automodule:: fairseq.optim.lr_scheduler 14 | :members: 15 | 16 | .. autoclass:: fairseq.optim.lr_scheduler.FairseqLRScheduler 17 | :members: 18 | :undoc-members: 19 | 20 | .. autoclass:: fairseq.optim.lr_scheduler.cosine_lr_scheduler.CosineSchedule 21 | :members: 22 | :undoc-members: 23 | .. autoclass:: fairseq.optim.lr_scheduler.fixed_schedule.FixedSchedule 24 | :members: 25 | :undoc-members: 26 | .. autoclass:: fairseq.optim.lr_scheduler.inverse_square_root_schedule.InverseSquareRootSchedule 27 | :members: 28 | :undoc-members: 29 | .. autoclass:: fairseq.optim.lr_scheduler.reduce_lr_on_plateau.ReduceLROnPlateau 30 | :members: 31 | :undoc-members: 32 | .. autoclass:: fairseq.optim.lr_scheduler.triangular_lr_scheduler.TriangularSchedule 33 | :members: 34 | :undoc-members: 35 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=python -msphinx 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=fairseq 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed, 20 | echo.then set the SPHINXBUILD environment variable to point to the full 21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the 22 | echo.Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/models.rst: -------------------------------------------------------------------------------- 1 | .. 
role:: hidden 2 | :class: hidden-section 3 | 4 | .. module:: fairseq.models 5 | 6 | .. _Models: 7 | 8 | Models 9 | ====== 10 | 11 | A Model defines the neural network's ``forward()`` method and encapsulates all 12 | of the learnable parameters in the network. Each model also provides a set of 13 | named *architectures* that define the precise network configuration (e.g., 14 | embedding dimension, number of layers, etc.). 15 | 16 | Both the model type and architecture are selected via the ``--arch`` 17 | command-line argument. Once selected, a model may expose additional command-line 18 | arguments for further configuration. 19 | 20 | .. note:: 21 | 22 | All fairseq Models extend :class:`BaseFairseqModel`, which in turn extends 23 | :class:`torch.nn.Module`. Thus any fairseq Model can be used as a 24 | stand-alone Module in other PyTorch code. 25 | 26 | 27 | Convolutional Neural Networks (CNN) 28 | ----------------------------------- 29 | 30 | .. module:: fairseq.models.fconv 31 | .. autoclass:: fairseq.models.fconv.FConvModel 32 | :members: 33 | .. autoclass:: fairseq.models.fconv.FConvEncoder 34 | :members: 35 | :undoc-members: 36 | .. autoclass:: fairseq.models.fconv.FConvDecoder 37 | :members: 38 | 39 | 40 | Long Short-Term Memory (LSTM) networks 41 | -------------------------------------- 42 | 43 | .. module:: fairseq.models.lstm 44 | .. autoclass:: fairseq.models.lstm.LSTMModel 45 | :members: 46 | .. autoclass:: fairseq.models.lstm.LSTMEncoder 47 | :members: 48 | .. autoclass:: fairseq.models.lstm.LSTMDecoder 49 | :members: 50 | 51 | 52 | Transformer (self-attention) networks 53 | ------------------------------------- 54 | 55 | .. module:: fairseq.models.transformer 56 | .. autoclass:: fairseq.models.transformer.TransformerModel 57 | :members: 58 | .. autoclass:: fairseq.models.transformer.TransformerEncoder 59 | :members: 60 | .. autoclass:: fairseq.models.transformer.TransformerEncoderLayer 61 | :members: 62 | .. autoclass:: fairseq.models.transformer.TransformerDecoder 63 | :members: 64 | .. autoclass:: fairseq.models.transformer.TransformerDecoderLayer 65 | :members: 66 | 67 | 68 | Adding new models 69 | ----------------- 70 | 71 | .. currentmodule:: fairseq.models 72 | .. autofunction:: fairseq.models.register_model 73 | .. autofunction:: fairseq.models.register_model_architecture 74 | .. autoclass:: fairseq.models.BaseFairseqModel 75 | :members: 76 | :undoc-members: 77 | .. autoclass:: fairseq.models.FairseqEncoderDecoderModel 78 | :members: 79 | :undoc-members: 80 | .. autoclass:: fairseq.models.FairseqEncoderModel 81 | :members: 82 | :undoc-members: 83 | .. autoclass:: fairseq.models.FairseqLanguageModel 84 | :members: 85 | :undoc-members: 86 | .. autoclass:: fairseq.models.FairseqMultiModel 87 | :members: 88 | :undoc-members: 89 | .. autoclass:: fairseq.models.FairseqEncoder 90 | :members: 91 | .. autoclass:: fairseq.models.CompositeEncoder 92 | :members: 93 | .. autoclass:: fairseq.models.FairseqDecoder 94 | :members: 95 | 96 | 97 | .. _Incremental decoding: 98 | 99 | Incremental decoding 100 | -------------------- 101 | 102 | .. 
autoclass:: fairseq.models.FairseqIncrementalDecoder 103 | :members: 104 | :undoc-members: 105 | -------------------------------------------------------------------------------- /docs/modules.rst: -------------------------------------------------------------------------------- 1 | Modules 2 | ======= 3 | 4 | Fairseq provides several stand-alone :class:`torch.nn.Module` classes that may 5 | be helpful when implementing a new :class:`~fairseq.models.BaseFairseqModel`. 6 | 7 | .. automodule:: fairseq.modules 8 | :members: 9 | :undoc-members: 10 | -------------------------------------------------------------------------------- /docs/optim.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | .. _optimizers: 5 | 6 | Optimizers 7 | ========== 8 | 9 | Optimizers update the Model parameters based on the gradients. 10 | 11 | .. automodule:: fairseq.optim 12 | :members: 13 | 14 | .. autoclass:: fairseq.optim.FairseqOptimizer 15 | :members: 16 | :undoc-members: 17 | 18 | .. autoclass:: fairseq.optim.adadelta.Adadelta 19 | :members: 20 | :undoc-members: 21 | .. autoclass:: fairseq.optim.adagrad.Adagrad 22 | :members: 23 | :undoc-members: 24 | .. autoclass:: fairseq.optim.adafactor.FairseqAdafactor 25 | :members: 26 | :undoc-members: 27 | .. autoclass:: fairseq.optim.adam.FairseqAdam 28 | :members: 29 | :undoc-members: 30 | .. autoclass:: fairseq.optim.fp16_optimizer.FP16Optimizer 31 | :members: 32 | :undoc-members: 33 | .. autoclass:: fairseq.optim.nag.FairseqNAG 34 | :members: 35 | :undoc-members: 36 | .. autoclass:: fairseq.optim.sgd.SGD 37 | :members: 38 | :undoc-members: 39 | -------------------------------------------------------------------------------- /docs/overview.rst: -------------------------------------------------------------------------------- 1 | Overview 2 | ======== 3 | 4 | Fairseq can be extended through user-supplied `plug-ins 5 | `_. We support five kinds of 6 | plug-ins: 7 | 8 | - :ref:`Models` define the neural network architecture and encapsulate all of the 9 | learnable parameters. 10 | - :ref:`Criterions` compute the loss function given the model outputs and targets. 11 | - :ref:`Tasks` store dictionaries and provide helpers for loading/iterating over 12 | Datasets, initializing the Model/Criterion and calculating the loss. 13 | - :ref:`Optimizers` update the Model parameters based on the gradients. 14 | - :ref:`Learning Rate Schedulers` update the learning rate over the course of 15 | training. 16 | 17 | **Training Flow** 18 | 19 | Given a ``model``, ``criterion``, ``task``, ``optimizer`` and ``lr_scheduler``, 20 | fairseq implements the following high-level training flow:: 21 | 22 | for epoch in range(num_epochs): 23 | itr = task.get_batch_iterator(task.dataset('train')) 24 | for num_updates, batch in enumerate(itr): 25 | task.train_step(batch, model, criterion, optimizer) 26 | average_and_clip_gradients() 27 | optimizer.step() 28 | lr_scheduler.step_update(num_updates) 29 | lr_scheduler.step(epoch) 30 | 31 | where the default implementation for ``task.train_step`` is roughly:: 32 | 33 | def train_step(self, batch, model, criterion, optimizer): 34 | loss = criterion(model, batch) 35 | optimizer.backward(loss) 36 | return loss 37 | 38 | **Registering new plug-ins** 39 | 40 | New plug-ins are *registered* through a set of ``@register`` function 41 | decorators, for example:: 42 | 43 | @register_model('my_lstm') 44 | class MyLSTM(FairseqEncoderDecoderModel): 45 | (...) 
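A slightly fuller sketch of such a registered plug-in is shown below. This is a minimal, illustrative sketch only: the names ``my_lstm``, ``MyLSTM`` and ``my_lstm_tiny`` are hypothetical and not part of fairseq, and the encoder/decoder are simply borrowed from ``fairseq.models.lstm`` to keep the example self-contained::

    from fairseq.models import (
        FairseqEncoderDecoderModel,
        register_model,
        register_model_architecture,
    )
    from fairseq.models.lstm import LSTMEncoder, LSTMDecoder

    @register_model('my_lstm')
    class MyLSTM(FairseqEncoderDecoderModel):

        @staticmethod
        def add_args(parser):
            # model-specific command-line arguments, exposed once --arch selects this model
            parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                                help='encoder/decoder embedding dimension')

        @classmethod
        def build_model(cls, args, task):
            # build the encoder/decoder from the task's dictionaries and return the model
            encoder = LSTMEncoder(task.source_dictionary, embed_dim=args.encoder_embed_dim)
            decoder = LSTMDecoder(task.target_dictionary, embed_dim=args.encoder_embed_dim)
            return cls(encoder, decoder)

    @register_model_architecture('my_lstm', 'my_lstm_tiny')
    def my_lstm_tiny(args):
        # a named architecture only fills in defaults for unset hyperparameters
        args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)

With such a module on the Python path, passing ``--arch my_lstm_tiny`` on the command line would select this (hypothetical) model.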
46 | 47 | Once registered, new plug-ins can be used with the existing :ref:`Command-line 48 | Tools`. See the Tutorial sections for more detailed walkthroughs of how to add 49 | new plug-ins. 50 | 51 | **Loading plug-ins from another directory** 52 | 53 | New plug-ins can be defined in a custom module stored in the user system. In 54 | order to import the module, and make the plugin available to *fairseq*, the 55 | command line supports the ``--user-dir`` flag that can be used to specify a 56 | custom location for additional modules to load into *fairseq*. 57 | 58 | For example, assuming this directory tree:: 59 | 60 | /home/user/my-module/ 61 | └── __init__.py 62 | 63 | with ``__init__.py``:: 64 | 65 | from fairseq.models import register_model_architecture 66 | from fairseq.models.transformer import transformer_vaswani_wmt_en_de_big 67 | 68 | @register_model_architecture('transformer', 'my_transformer') 69 | def transformer_mmt_big(args): 70 | transformer_vaswani_wmt_en_de_big(args) 71 | 72 | it is possible to invoke the :ref:`fairseq-train` script with the new architecture with:: 73 | 74 | fairseq-train ... --user-dir /home/user/my-module -a my_transformer --task translation 75 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx<2.0 2 | sphinx-argparse 3 | -------------------------------------------------------------------------------- /docs/tasks.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | .. module:: fairseq.tasks 5 | 6 | .. _Tasks: 7 | 8 | Tasks 9 | ===== 10 | 11 | Tasks store dictionaries and provide helpers for loading/iterating over 12 | Datasets, initializing the Model/Criterion and calculating the loss. 13 | 14 | Tasks can be selected via the ``--task`` command-line argument. Once selected, a 15 | task may expose additional command-line arguments for further configuration. 16 | 17 | Example usage:: 18 | 19 | # setup the task (e.g., load dictionaries) 20 | task = fairseq.tasks.setup_task(args) 21 | 22 | # build model and criterion 23 | model = task.build_model(args) 24 | criterion = task.build_criterion(args) 25 | 26 | # load datasets 27 | task.load_dataset('train') 28 | task.load_dataset('valid') 29 | 30 | # iterate over mini-batches of data 31 | batch_itr = task.get_batch_iterator( 32 | task.dataset('train'), max_tokens=4096, 33 | ) 34 | for batch in batch_itr: 35 | # compute the loss 36 | loss, sample_size, logging_output = task.get_loss( 37 | model, criterion, batch, 38 | ) 39 | loss.backward() 40 | 41 | 42 | Translation 43 | ----------- 44 | 45 | .. autoclass:: fairseq.tasks.translation.TranslationTask 46 | 47 | .. _language modeling: 48 | 49 | Language Modeling 50 | ----------------- 51 | 52 | .. autoclass:: fairseq.tasks.language_modeling.LanguageModelingTask 53 | 54 | 55 | Adding new tasks 56 | ---------------- 57 | 58 | .. autofunction:: fairseq.tasks.register_task 59 | .. 
autoclass:: fairseq.tasks.FairseqTask 60 | :members: 61 | :undoc-members: 62 | -------------------------------------------------------------------------------- /examples/.gitignore: -------------------------------------------------------------------------------- 1 | !*/*.sh 2 | !*/*.md 3 | -------------------------------------------------------------------------------- /examples/backtranslation/README.md: -------------------------------------------------------------------------------- 1 | # Understanding Back-Translation at Scale (Edunov et al., 2018) 2 | 3 | This page includes pre-trained models from the paper [Understanding Back-Translation at Scale (Edunov et al., 2018)](https://arxiv.org/abs/1808.09381). 4 | 5 | ## Pre-trained models 6 | 7 | Description | Dataset | Model | Test set(s) 8 | ---|---|---|--- 9 | Transformer
([Edunov et al., 2018](https://arxiv.org/abs/1808.09381); WMT'18 winner) | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz) | See NOTE in the archive 10 | 11 | ## Example usage 12 | 13 | Interactive generation from the full ensemble via PyTorch Hub: 14 | ``` 15 | >>> import torch 16 | >>> en2de_ensemble = torch.hub.load( 17 | ... 'pytorch/fairseq', 18 | ... 'transformer', 19 | ... model_name_or_path='transformer.wmt18.en-de', 20 | ... checkpoint_file='wmt18.model1.pt:wmt18.model2.pt:wmt18.model3.pt:wmt18.model4.pt:wmt18.model5.pt', 21 | ... data_name_or_path='.', 22 | ... tokenizer='moses', 23 | ... aggressive_dash_splits=True, 24 | ... bpe='subword_nmt', 25 | ... ) 26 | >>> len(en2de_ensemble.models) 27 | 5 28 | >>> print(en2de_ensemble.generate('Hello world!')) 29 | Hallo Welt! 30 | ``` 31 | 32 | ## Citation 33 | ```bibtex 34 | @inproceedings{edunov2018backtranslation, 35 | title = {Understanding Back-Translation at Scale}, 36 | author = {Edunov, Sergey and Ott, Myle and Auli, Michael and Grangier, David}, 37 | booktitle = {Conference of the Association for Computational Linguistics (ACL)}, 38 | year = 2018, 39 | } 40 | ``` 41 | -------------------------------------------------------------------------------- /examples/conv_seq2seq/README.md: -------------------------------------------------------------------------------- 1 | # Convolutional Sequence to Sequence Learning (Gehring et al., 2017) 2 | 3 | ## Pre-trained models 4 | 5 | Description | Dataset | Model | Test set(s) 6 | ---|---|---|--- 7 | Convolutional
([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2) | newstest2014:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.newstest2014.tar.bz2)
newstest2012/2013:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.ntst1213.tar.bz2) 8 | Convolutional
([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-German](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2) | newstest2014:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-de.newstest2014.tar.bz2) 9 | Convolutional
([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT17 English-German](http://statmt.org/wmt17/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2) | newstest2014:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt17.v2.en-de.newstest2014.tar.bz2) 10 | 11 | ## Example usage 12 | 13 | See the [translation README](../translation/README.md) for instructions on reproducing results for WMT'14 En-De and 14 | WMT'14 En-Fr using the `fconv_wmt_en_de` and `fconv_wmt_en_fr` model architectures. 15 | 16 | ## Citation 17 | 18 | ```bibtex 19 | @inproceedings{gehring2017convs2s, 20 | title = {Convolutional Sequence to Sequence Learning}, 21 | author = {Gehring, Jonas and Auli, Michael and Grangier, David and Yarats, Denis and Dauphin, Yann N}, 22 | booktitle = {Proc. of ICML}, 23 | year = 2017, 24 | } 25 | ``` 26 | -------------------------------------------------------------------------------- /examples/cross_lingual_language_model/README.md: -------------------------------------------------------------------------------- 1 | # Cross-Lingual Language Model Pre-training 2 | 3 | Below are some details for training Cross-Lingual Language Models (XLM) - similar to the ones presented in [Lample & Conneau, 2019](https://arxiv.org/pdf/1901.07291.pdf) - in Fairseq. The current implementation only supports the Masked Language Model (MLM) from the paper above. 4 | 5 | ## Downloading and Tokenizing Monolingual Data 6 | 7 | Pointers to the monolingual data from Wikipedia used for training the XLM-style MLM model, as well as details on processing it (tokenization and BPE), can be found in the [XLM Github Repository](https://github.com/facebookresearch/XLM#download--preprocess-monolingual-data). 8 | 9 | Let's assume the following for the code snippets in later sections to work: 10 | - Processed data is in the folder: monolingual_data/processed 11 | - Each language has 3 files for train, test and validation. For example we have the following files for English: 12 | train.en, valid.en and test.en 13 | - We are training a model for 5 languages: Arabic (ar), German (de), English (en), Hindi (hi) and French (fr) 14 | - The vocabulary file is monolingual_data/processed/vocab_mlm 15 | 16 | 17 | ## Fairseq Pre-processing and Binarization 18 | 19 | Pre-process and binarize the data with the MaskedLMDictionary and the cross_lingual_lm task: 20 | 21 | ```bash 22 | # Ensure the output directory exists 23 | DATA_DIR=monolingual_data/fairseq_processed 24 | mkdir -p "$DATA_DIR" 25 | 26 | for lg in ar de en hi fr 27 | do 28 | 29 | fairseq-preprocess \ 30 | --task cross_lingual_lm \ 31 | --srcdict monolingual_data/processed/vocab_mlm \ 32 | --only-source \ 33 | --trainpref monolingual_data/processed/train \ 34 | --validpref monolingual_data/processed/valid \ 35 | --testpref monolingual_data/processed/test \ 36 | --destdir monolingual_data/fairseq_processed \ 37 | --workers 20 \ 38 | --source-lang $lg 39 | 40 | # Since we only have a source language, the output files have a None for the 41 | # target language. Rename them to remove it 42 | 43 | for stage in train test valid 44 | do 45 | mv "$DATA_DIR/$stage.$lg-None.$lg.bin" "$DATA_DIR/$stage.$lg.bin" 46 | mv "$DATA_DIR/$stage.$lg-None.$lg.idx" "$DATA_DIR/$stage.$lg.idx" 47 | 48 | done 49 | 50 | done 51 | ``` 52 | 53 | ## Train a Cross-lingual Language Model similar to the XLM MLM model 54 | 55 | Use the following command to train the model on 5 languages. 
56 | 57 | ``` 58 | fairseq-train \ 59 | --task cross_lingual_lm monolingual_data/fairseq_processed \ 60 | --save-dir checkpoints/mlm \ 61 | --max-update 2400000 --save-interval 1 --no-epoch-checkpoints \ 62 | --arch xlm_base \ 63 | --optimizer adam --lr-scheduler reduce_lr_on_plateau \ 64 | --lr-shrink 0.5 --lr 0.0001 --min-lr 1e-09 \ 65 | --dropout 0.1 \ 66 | --criterion masked_lm_loss \ 67 | --max-tokens 2048 --tokens-per-sample 256 --attention-dropout 0.1 \ 68 | --dataset-impl lazy --seed 0 \ 69 | --masked-lm-only \ 70 | --monolingual-langs 'ar,de,en,hi,fr' --num-segment 5 \ 71 | --ddp-backend=no_c10d 72 | ``` 73 | 74 | Some Notes: 75 | - Using tokens_per_sample greater than 256 can cause OOM (out-of-memory) issues. Usually since MLM packs in streams of text, this parameter doesn't need much tuning. 76 | - The Evaluation workflow for computing MLM Perplexity on test data is in progress. 77 | - Finetuning this model on a downstream task is something which is not currently available. 78 | -------------------------------------------------------------------------------- /examples/language_model/conv_lm/README.md: -------------------------------------------------------------------------------- 1 | # Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017) 2 | 3 | ## Example usage 4 | 5 | See the [language modeling README](../README.md) for instructions on reproducing results for WikiText-103 6 | using the `fconv_lm_dauphin_wikitext103` model architecture. 7 | 8 | ## Citation 9 | 10 | ```bibtex 11 | @inproceedings{dauphin2017language, 12 | title={Language Modeling with Gated Convolutional Networks}, 13 | author={Dauphin, Yann N and Fan, Angela and Auli, Michael and Grangier, David}, 14 | booktitle={Proceedings of the 34th International Conference on Machine Learning-Volume 70}, 15 | pages={933--941}, 16 | year={2017}, 17 | organization={JMLR} 18 | } 19 | ``` 20 | -------------------------------------------------------------------------------- /examples/language_model/prepare-wikitext-103.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh 3 | 4 | URLS=( 5 | "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip" 6 | ) 7 | FILES=( 8 | "wikitext-103-v1.zip" 9 | ) 10 | 11 | for ((i=0;i<${#URLS[@]};++i)); do 12 | file=${FILES[i]} 13 | if [ -f $file ]; then 14 | echo "$file already exists, skipping download" 15 | else 16 | url=${URLS[i]} 17 | wget "$url" 18 | if [ -f $file ]; then 19 | echo "$url successfully downloaded." 20 | else 21 | echo "$url not successfully downloaded." 22 | exit -1 23 | fi 24 | if [ ${file: -4} == ".tgz" ]; then 25 | tar zxvf $file 26 | elif [ ${file: -4} == ".tar" ]; then 27 | tar xvf $file 28 | elif [ ${file: -4} == ".zip" ]; then 29 | unzip $file 30 | fi 31 | fi 32 | done 33 | cd .. 34 | -------------------------------------------------------------------------------- /examples/language_model/transformer_lm/README.md: -------------------------------------------------------------------------------- 1 | # Adaptive Input Representations for Neural Language Modeling (Baevski and Auli; 2018) 2 | 3 | ## Pre-trained models 4 | 5 | Description | Parameters | Dataset | Model and Test set(s) 6 | ---|---:|---|--- 7 | Adaptive Inputs
([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 1026M | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.bz2) 8 | Adaptive Inputs
([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 247M | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.bz2) 9 | 10 | ## Example usage 11 | 12 | See the [language modeling README](../README.md) for instructions on reproducing results for WikiText-103 13 | using the `transformer_lm_wiki103` model architecture. 14 | 15 | ## Citation 16 | 17 | ```bibtex 18 | @inproceedings{ 19 | baevski2018adaptive, 20 | title={Adaptive Input Representations for Neural Language Modeling}, 21 | author={Alexei Baevski and Michael Auli}, 22 | booktitle={International Conference on Learning Representations}, 23 | year={2019}, 24 | url={https://openreview.net/forum?id=ByxZX20qFQ}, 25 | } 26 | ``` 27 | -------------------------------------------------------------------------------- /examples/scaling_nmt/README.md: -------------------------------------------------------------------------------- 1 | # Scaling Neural Machine Translation (Ott et al., 2018) 2 | 3 | This page includes instructions for reproducing results from the paper [Scaling Neural Machine Translation (Ott et al., 2018)](https://arxiv.org/abs/1806.00187). 4 | 5 | ## Pre-trained models 6 | 7 | Description | Dataset | Model | Test set(s) 8 | ---|---|---|--- 9 | Transformer
([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2) | newstest2014 (shared vocab):
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2) 10 | Transformer
([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2) | newstest2014 (shared vocab):
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2) 11 | 12 | ## Training a new model on WMT'16 En-De 13 | 14 | Please first download the [preprocessed WMT'16 En-De data provided by Google](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8). 15 | Then: 16 | 17 | 1. Extract the WMT'16 En-De data: 18 | ``` 19 | $ TEXT=wmt16_en_de_bpe32k 20 | $ mkdir $TEXT 21 | $ tar -xzvf wmt16_en_de.tar.gz -C $TEXT 22 | ``` 23 | 24 | 2. Preprocess the dataset with a joined dictionary: 25 | ``` 26 | $ fairseq-preprocess --source-lang en --target-lang de \ 27 | --trainpref $TEXT/train.tok.clean.bpe.32000 \ 28 | --validpref $TEXT/newstest2013.tok.bpe.32000 \ 29 | --testpref $TEXT/newstest2014.tok.bpe.32000 \ 30 | --destdir data-bin/wmt16_en_de_bpe32k \ 31 | --nwordssrc 32768 --nwordstgt 32768 \ 32 | --joined-dictionary 33 | ``` 34 | 35 | 3. Train a model: 36 | ``` 37 | $ fairseq-train data-bin/wmt16_en_de_bpe32k \ 38 | --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \ 39 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ 40 | --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \ 41 | --lr 0.0005 --min-lr 1e-09 \ 42 | --dropout 0.3 --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ 43 | --max-tokens 3584 \ 44 | --fp16 45 | ``` 46 | 47 | Note that the `--fp16` flag requires you have CUDA 9.1 or greater and a Volta GPU. 48 | 49 | If you want to train the above model with big batches (assuming your machine has 8 GPUs): 50 | - add `--update-freq 16` to simulate training on 8*16=128 GPUs 51 | - increase the learning rate; 0.001 works well for big batches 52 | 53 | ## Citation 54 | 55 | ```bibtex 56 | @inproceedings{ott2018scaling, 57 | title = {Scaling Neural Machine Translation}, 58 | author = {Ott, Myle and Edunov, Sergey and Grangier, David and Auli, Michael}, 59 | booktitle = {Proceedings of the Third Conference on Machine Translation (WMT)}, 60 | year = 2018, 61 | } 62 | ``` 63 | -------------------------------------------------------------------------------- /examples/translation/prepare-iwslt14.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh 4 | 5 | echo 'Cloning Moses github repository (for tokenization scripts)...' 6 | git clone https://github.com/moses-smt/mosesdecoder.git 7 | 8 | echo 'Cloning Subword NMT repository (for BPE pre-processing)...' 9 | git clone https://github.com/rsennrich/subword-nmt.git 10 | 11 | SCRIPTS=mosesdecoder/scripts 12 | TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl 13 | LC=$SCRIPTS/tokenizer/lowercase.perl 14 | CLEAN=$SCRIPTS/training/clean-corpus-n.perl 15 | BPEROOT=subword-nmt 16 | BPE_TOKENS=10000 17 | 18 | URL="https://wit3.fbk.eu/archive/2014-01/texts/de/en/de-en.tgz" 19 | GZ=de-en.tgz 20 | 21 | if [ ! -d "$SCRIPTS" ]; then 22 | echo "Please set SCRIPTS variable correctly to point to Moses scripts." 23 | exit 24 | fi 25 | 26 | src=de 27 | tgt=en 28 | lang=de-en 29 | prep=iwslt14.tokenized.de-en 30 | tmp=$prep/tmp 31 | orig=orig 32 | 33 | mkdir -p $orig $tmp $prep 34 | 35 | echo "Downloading data from ${URL}..." 36 | cd $orig 37 | wget "$URL" 38 | 39 | if [ -f $GZ ]; then 40 | echo "Data successfully downloaded." 41 | else 42 | echo "Data not successfully downloaded." 43 | exit 44 | fi 45 | 46 | tar zxvf $GZ 47 | cd .. 
48 | 49 | echo "pre-processing train data..." 50 | for l in $src $tgt; do 51 | f=train.tags.$lang.$l 52 | tok=train.tags.$lang.tok.$l 53 | 54 | cat $orig/$lang/$f | \ 55 | grep -v '<url>' | \ 56 | grep -v '<talkid>' | \ 57 | grep -v '<keywords>' | \ 58 | sed -e 's/<title>//g' | \ 59 | sed -e 's/<\/title>//g' | \ 60 | sed -e 's/<description>//g' | \ 61 | sed -e 's/<\/description>//g' | \ 62 | perl $TOKENIZER -threads 8 -l $l > $tmp/$tok 63 | echo "" 64 | done 65 | perl $CLEAN -ratio 1.5 $tmp/train.tags.$lang.tok $src $tgt $tmp/train.tags.$lang.clean 1 175 66 | for l in $src $tgt; do 67 | perl $LC < $tmp/train.tags.$lang.clean.$l > $tmp/train.tags.$lang.$l 68 | done 69 | 70 | echo "pre-processing valid/test data..." 71 | for l in $src $tgt; do 72 | for o in `ls $orig/$lang/IWSLT14.TED*.$l.xml`; do 73 | fname=${o##*/} 74 | f=$tmp/${fname%.*} 75 | echo $o $f 76 | grep '<seg id' $o | \ 77 | sed -e 's/<seg id="[0-9]*">\s*//g' | \ 78 | sed -e 's/\s*<\/seg>\s*//g' | \ 79 | sed -e "s/\’/\'/g" | \ 80 | perl $TOKENIZER -threads 8 -l $l | \ 81 | perl $LC > $f 82 | echo "" 83 | done 84 | done 85 | 86 | 87 | echo "creating train, valid, test..." 88 | for l in $src $tgt; do 89 | awk '{if (NR%23 == 0) print $0; }' $tmp/train.tags.de-en.$l > $tmp/valid.$l 90 | awk '{if (NR%23 != 0) print $0; }' $tmp/train.tags.de-en.$l > $tmp/train.$l 91 | 92 | cat $tmp/IWSLT14.TED.dev2010.de-en.$l \ 93 | $tmp/IWSLT14.TEDX.dev2012.de-en.$l \ 94 | $tmp/IWSLT14.TED.tst2010.de-en.$l \ 95 | $tmp/IWSLT14.TED.tst2011.de-en.$l \ 96 | $tmp/IWSLT14.TED.tst2012.de-en.$l \ 97 | > $tmp/test.$l 98 | done 99 | 100 | TRAIN=$tmp/train.en-de 101 | BPE_CODE=$prep/code 102 | rm -f $TRAIN 103 | for l in $src $tgt; do 104 | cat $tmp/train.$l >> $TRAIN 105 | done 106 | 107 | echo "learn_bpe.py on ${TRAIN}..." 108 | python $BPEROOT/learn_bpe.py -s $BPE_TOKENS < $TRAIN > $BPE_CODE 109 | 110 | for L in $src $tgt; do 111 | for f in train.$L valid.$L test.$L; do 112 | echo "apply_bpe.py to ${f}..." 113 | python $BPEROOT/apply_bpe.py -c $BPE_CODE < $tmp/$f > $prep/$f 114 | done 115 | done 116 | -------------------------------------------------------------------------------- /examples/wav2vec/README.md: -------------------------------------------------------------------------------- 1 | # wav2vec 2 | 3 | Example to train a wav2vec model as described in [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](https://arxiv.org/abs/1904.05862). 
4 | 5 | ## Training a new model with the CLI tools 6 | 7 | Given a directory containing wav files to be used for pretraining (we recommend splitting each file into separate files, each 10 to 30 seconds in length): 8 | 9 | ### Prepare training data manifest: 10 | 11 | ``` 12 | $ python scripts/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext wav 13 | ``` 14 | 15 | ### Train a wav2vec model: 16 | 17 | ``` 18 | $ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 --save-interval 1 --no-epoch-checkpoints \ 19 | --arch wav2vec --task audio_pretraining --lr 1e-06 --min-lr 1e-09 --optimizer adam --max-lr 0.005 --lr-scheduler cosine \ 20 | --conv-feature-layers '[(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1)]' \ 21 | --conv-aggregator-layers '[(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)]' \ 22 | --skip-connections-agg --residual-scale 0.5 --log-compression --warmup-updates 500 --warmup-init-lr 1e-07 --criterion binary_cross_entropy --num-negatives 10 \ 23 | --max-sample-size 150000 --max-tokens 1500000 --skip-invalid-size-inputs-valid-test 24 | ``` 25 | 26 | ### Extract embeddings from the downstream task data: 27 | 28 | ``` 29 | $ PYTHONPATH=/path/to/fairseq python scripts/wav2vec_featurize.py --input /path/to/task/waves --output /path/to/output \ 30 | --model /model/path/checkpoint_best.pt --split train valid test 31 | ``` 32 | -------------------------------------------------------------------------------- /fairseq.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-tpu/fairseq/7efde2261f78e9e8d20e637e252bdd9977ec9290/fairseq.gif -------------------------------------------------------------------------------- /fairseq/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | __all__ = ['pdb'] 9 | __version__ = '0.7.2' 10 | 11 | import fairseq.criterions 12 | import fairseq.models 13 | import fairseq.modules 14 | import fairseq.optim 15 | import fairseq.optim.lr_scheduler 16 | import fairseq.pdb 17 | import fairseq.tasks 18 | -------------------------------------------------------------------------------- /fairseq/binarizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
7 | 8 | from collections import Counter 9 | import os 10 | 11 | from fairseq.tokenizer import tokenize_line 12 | 13 | 14 | def safe_readline(f): 15 | pos = f.tell() 16 | while True: 17 | try: 18 | return f.readline() 19 | except UnicodeDecodeError: 20 | pos -= 1 21 | f.seek(pos) # search where this character begins 22 | 23 | 24 | class Binarizer: 25 | 26 | @staticmethod 27 | def binarize(filename, dict, consumer, tokenize=tokenize_line, append_eos=True, reverse_order=False, 28 | offset=0, end=-1): 29 | nseq, ntok = 0, 0 30 | replaced = Counter() 31 | 32 | def replaced_consumer(word, idx): 33 | if idx == dict.unk_index and word != dict.unk_word: 34 | replaced.update([word]) 35 | 36 | with open(filename, 'r', encoding='utf-8') as f: 37 | f.seek(offset) 38 | # next(f) breaks f.tell(), hence readline() must be used 39 | line = safe_readline(f) 40 | while line: 41 | if end > 0 and f.tell() > end: 42 | break 43 | ids = dict.encode_line( 44 | line=line, 45 | line_tokenizer=tokenize, 46 | add_if_not_exist=False, 47 | consumer=replaced_consumer, 48 | append_eos=append_eos, 49 | reverse_order=reverse_order, 50 | ) 51 | nseq += 1 52 | ntok += len(ids) 53 | consumer(ids) 54 | line = f.readline() 55 | return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': replaced} 56 | 57 | @staticmethod 58 | def find_offsets(filename, num_chunks): 59 | with open(filename, 'r', encoding='utf-8') as f: 60 | size = os.fstat(f.fileno()).st_size 61 | chunk_size = size // num_chunks 62 | offsets = [0 for _ in range(num_chunks + 1)] 63 | for i in range(1, num_chunks): 64 | f.seek(chunk_size * i) 65 | safe_readline(f) 66 | offsets[i] = f.tell() 67 | return offsets 68 | -------------------------------------------------------------------------------- /fairseq/clib/libbleu/libbleu.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2017-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include <map> 10 | #include <array> 11 | #include <cstring> 12 | #include <cstdio> 13 | 14 | typedef struct 15 | { 16 | size_t reflen; 17 | size_t predlen; 18 | size_t match1; 19 | size_t count1; 20 | size_t match2; 21 | size_t count2; 22 | size_t match3; 23 | size_t count3; 24 | size_t match4; 25 | size_t count4; 26 | } bleu_stat; 27 | 28 | // left trim (remove pad) 29 | void bleu_ltrim(size_t* len, int** sent, int pad) { 30 | size_t start = 0; 31 | while(start < *len) { 32 | if (*(*sent + start) != pad) { break; } 33 | start++; 34 | } 35 | *sent += start; 36 | *len -= start; 37 | } 38 | 39 | // right trim remove (eos) 40 | void bleu_rtrim(size_t* len, int** sent, int pad, int eos) { 41 | size_t end = *len - 1; 42 | while (end > 0) { 43 | if (*(*sent + end) != eos && *(*sent + end) != pad) { break; } 44 | end--; 45 | } 46 | *len = end + 1; 47 | } 48 | 49 | // left and right trim 50 | void bleu_trim(size_t* len, int** sent, int pad, int eos) { 51 | bleu_ltrim(len, sent, pad); 52 | bleu_rtrim(len, sent, pad, eos); 53 | } 54 | 55 | size_t bleu_hash(int len, int* data) { 56 | size_t h = 14695981039346656037ul; 57 | size_t prime = 0x100000001b3; 58 | char* b = (char*) data; 59 | size_t blen = sizeof(int) * len; 60 | 61 | while (blen-- > 0) { 62 | h ^= *b++; 63 | h *= prime; 64 | } 65 | 66 | return h; 67 | } 68 | 69 | void bleu_addngram( 70 | size_t *ntotal, size_t *nmatch, size_t n, 71 | size_t reflen, int* ref, size_t predlen, int* pred) { 72 | 73 | if (predlen < n) { return; } 74 | 75 | predlen = predlen - n + 1; 76 | (*ntotal) += predlen; 77 | 78 | if (reflen < n) { return; } 79 | 80 | reflen = reflen - n + 1; 81 | 82 | std::map<size_t, size_t> count; 83 | while (predlen > 0) { 84 | size_t w = bleu_hash(n, pred++); 85 | count[w]++; 86 | predlen--; 87 | } 88 | 89 | while (reflen > 0) { 90 | size_t w = bleu_hash(n, ref++); 91 | if (count[w] > 0) { 92 | (*nmatch)++; 93 | count[w] -=1; 94 | } 95 | reflen--; 96 | } 97 | } 98 | 99 | extern "C" { 100 | 101 | void bleu_zero_init(bleu_stat* stat) { 102 | std::memset(stat, 0, sizeof(bleu_stat)); 103 | } 104 | 105 | void bleu_one_init(bleu_stat* stat) { 106 | bleu_zero_init(stat); 107 | stat->count1 = 0; 108 | stat->count2 = 1; 109 | stat->count3 = 1; 110 | stat->count4 = 1; 111 | stat->match1 = 0; 112 | stat->match2 = 1; 113 | stat->match3 = 1; 114 | stat->match4 = 1; 115 | } 116 | 117 | void bleu_add( 118 | bleu_stat* stat, 119 | size_t reflen, int* ref, size_t predlen, int* pred, int pad, int eos) { 120 | 121 | bleu_trim(&reflen, &ref, pad, eos); 122 | bleu_trim(&predlen, &pred, pad, eos); 123 | stat->reflen += reflen; 124 | stat->predlen += predlen; 125 | 126 | bleu_addngram(&stat->count1, &stat->match1, 1, reflen, ref, predlen, pred); 127 | bleu_addngram(&stat->count2, &stat->match2, 2, reflen, ref, predlen, pred); 128 | bleu_addngram(&stat->count3, &stat->match3, 3, reflen, ref, predlen, pred); 129 | bleu_addngram(&stat->count4, &stat->match4, 4, reflen, ref, predlen, pred); 130 | } 131 | 132 | } 133 | -------------------------------------------------------------------------------- /fairseq/clib/libbleu/module.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2017-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include <Python.h> 10 | 11 | 12 | static PyMethodDef method_def[] = { 13 | {NULL, NULL, 0, NULL} 14 | }; 15 | 16 | static struct PyModuleDef module_def = { 17 | PyModuleDef_HEAD_INIT, 18 | "libbleu", /* name of module */ 19 | NULL, /* module documentation, may be NULL */ 20 | -1, /* size of per-interpreter state of the module, 21 | or -1 if the module keeps state in global variables. */ 22 | method_def 23 | }; 24 | 25 | 26 | #if PY_MAJOR_VERSION == 2 27 | PyMODINIT_FUNC init_libbleu() 28 | #else 29 | PyMODINIT_FUNC PyInit_libbleu() 30 | #endif 31 | { 32 | PyObject *m = PyModule_Create(&module_def); 33 | if (!m) { 34 | return NULL; 35 | } 36 | return m; 37 | } 38 | -------------------------------------------------------------------------------- /fairseq/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import importlib 9 | import os 10 | 11 | from fairseq import registry 12 | from fairseq.criterions.fairseq_criterion import FairseqCriterion 13 | 14 | 15 | build_criterion, register_criterion, CRITERION_REGISTRY = registry.setup_registry( 16 | '--criterion', 17 | base_class=FairseqCriterion, 18 | default='cross_entropy', 19 | ) 20 | 21 | 22 | # automatically import any Python files in the criterions/ directory 23 | for file in os.listdir(os.path.dirname(__file__)): 24 | if file.endswith('.py') and not file.startswith('_'): 25 | module = file[:file.find('.py')] 26 | importlib.import_module('fairseq.criterions.' + module) 27 | -------------------------------------------------------------------------------- /fairseq/criterions/binary_cross_entropy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import math 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | from fairseq import utils 13 | 14 | from . import FairseqCriterion, register_criterion 15 | 16 | 17 | @register_criterion('binary_cross_entropy') 18 | class BinaryCrossEntropyCriterion(FairseqCriterion): 19 | 20 | def __init__(self, args, task): 21 | super().__init__(args, task) 22 | 23 | def forward(self, model, sample, reduce=True): 24 | """Compute the loss for the given sample. 25 | 26 | Returns a tuple with three elements: 27 | 1) the loss 28 | 2) the sample size, which is used as the denominator for the gradient 29 | 3) logging outputs to display while training 30 | """ 31 | net_output = model(**sample['net_input']) 32 | logits = model.get_logits(net_output).float() 33 | target = model.get_targets(sample, net_output, expand_steps=False).float() 34 | 35 | if hasattr(model, 'get_target_weights'): 36 | weights = model.get_target_weights(target, net_output) 37 | if torch.is_tensor(weights): 38 | weights = weights.float() 39 | else: 40 | weights = 1. 
41 | 42 | loss = F.binary_cross_entropy_with_logits(logits, target, reduce=False) 43 | 44 | loss = loss * weights 45 | 46 | if reduce: 47 | loss = loss.sum() 48 | 49 | sample_size = target.numel() 50 | logging_output = { 51 | 'loss': utils.item(loss.data) if reduce else loss.data, 52 | 'ntokens': sample_size, 53 | 'nsentences': logits.size(0), 54 | 'sample_size': sample_size, 55 | } 56 | return loss, sample_size, logging_output 57 | 58 | @staticmethod 59 | def aggregate_logging_outputs(logging_outputs): 60 | """Aggregate logging outputs from data parallel training.""" 61 | loss_sum = sum(log.get('loss', 0) for log in logging_outputs) 62 | ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) 63 | nsentences = sum(log.get('nsentences', 0) for log in logging_outputs) 64 | sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) 65 | agg_output = { 66 | 'loss': loss_sum / sample_size / math.log(2), 67 | 'ntokens': ntokens, 68 | 'nsentences': nsentences, 69 | 'sample_size': sample_size, 70 | } 71 | if sample_size != ntokens: 72 | agg_output['nll_loss'] = loss_sum / ntokens / math.log(2) 73 | return agg_output -------------------------------------------------------------------------------- /fairseq/criterions/composite_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from torch import nn 9 | 10 | from fairseq import utils 11 | from . import FairseqCriterion, register_criterion 12 | 13 | 14 | @register_criterion('composite_loss') 15 | class CompositeLoss(FairseqCriterion): 16 | """This is a composite loss that, given a list of model outputs and a list of targets, 17 | computes an average of losses for each output-target pair""" 18 | 19 | @staticmethod 20 | def add_args(parser): 21 | """Add criterion-specific arguments to the parser.""" 22 | # fmt: off 23 | parser.add_argument('--underlying-criterion', type=str, metavar='VAL', required=True, 24 | help='underlying criterion to use for the composite loss') 25 | # fmt: on 26 | 27 | @staticmethod 28 | def build_underlying_criterion(args, task): 29 | saved_criterion = args.criterion 30 | args.criterion = args.underlying_criterion 31 | assert saved_criterion != args.underlying_criterion 32 | underlying_criterion = task.build_criterion(args) 33 | args.criterion = saved_criterion 34 | return underlying_criterion 35 | 36 | @classmethod 37 | def build_criterion(cls, args, task): 38 | underlying_criterion = CompositeLoss.build_underlying_criterion(args, task) 39 | 40 | class FakeModel(nn.Module): 41 | 42 | def __init__(self, model, net_out, target): 43 | super().__init__() 44 | self.model = model 45 | self.net_out = net_out 46 | self.target = target 47 | 48 | def forward(self, **unused): 49 | return self.net_out 50 | 51 | def get_normalized_probs(self, net_output, log_probs, sample=None): 52 | return self.model.get_normalized_probs(net_output, log_probs, sample=sample) 53 | 54 | def get_targets(self, *unused): 55 | return self.target 56 | 57 | @property 58 | def decoder(self): 59 | return self.model.decoder 60 | 61 | class _CompositeLoss(FairseqCriterion): 62 | 63 | def __init__(self, args, task, underlying_criterion): 64 | super().__init__(args, task) 65 | 
self.underlying_criterion = underlying_criterion 66 | 67 | def forward(self, model, sample, reduce=True): 68 | net_outputs = model(**sample['net_input']) 69 | targets = sample['target'] 70 | 71 | bsz = targets[0].size(0) 72 | loss = net_outputs[0][0].new(1 if reduce else bsz).float().zero_() 73 | 74 | sample_size = 0 75 | logging_output = {} 76 | for o, t in zip(net_outputs[0], targets): 77 | m = FakeModel(model, (o, net_outputs[1]), t) 78 | sample['target'] = t 79 | l, ss, logging_output = self.underlying_criterion(m, sample, reduce) 80 | loss += l 81 | sample_size += ss 82 | 83 | loss.div_(len(targets)) 84 | sample_size /= len(targets) 85 | 86 | logging_output['loss'] = utils.item(loss.data) if reduce else loss.data 87 | return loss, sample_size, logging_output 88 | 89 | @staticmethod 90 | def aggregate_logging_outputs(logging_outputs): 91 | return underlying_criterion.__class__.aggregate_logging_outputs(logging_outputs) 92 | 93 | return _CompositeLoss(args, task, underlying_criterion) 94 | -------------------------------------------------------------------------------- /fairseq/criterions/cross_entropy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import math 9 | import torch.nn.functional as F 10 | 11 | from fairseq import utils 12 | 13 | from . import FairseqCriterion, register_criterion 14 | 15 | 16 | @register_criterion('cross_entropy') 17 | class CrossEntropyCriterion(FairseqCriterion): 18 | 19 | def __init__(self, args, task): 20 | super().__init__(args, task) 21 | 22 | def forward(self, model, sample, reduce=True): 23 | """Compute the loss for the given sample. 
24 | 25 | Returns a tuple with three elements: 26 | 1) the loss 27 | 2) the sample size, which is used as the denominator for the gradient 28 | 3) logging outputs to display while training 29 | """ 30 | net_output = model(**sample['net_input']) 31 | loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce) 32 | sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] 33 | logging_output = { 34 | 'loss': utils.item(loss.data) if reduce else loss.data, 35 | 'ntokens': sample['ntokens'], 36 | 'nsentences': sample['target'].size(0), 37 | 'sample_size': sample_size, 38 | } 39 | return loss, sample_size, logging_output 40 | 41 | def compute_loss(self, model, net_output, sample, reduce=True): 42 | lprobs = model.get_normalized_probs(net_output, log_probs=True) 43 | lprobs = lprobs.view(-1, lprobs.size(-1)) 44 | target = model.get_targets(sample, net_output).view(-1) 45 | loss = F.nll_loss( 46 | lprobs, 47 | target, 48 | ignore_index=self.padding_idx, 49 | reduction='sum' if reduce else 'none', 50 | ) 51 | return loss, loss 52 | 53 | @staticmethod 54 | def aggregate_logging_outputs(logging_outputs): 55 | """Aggregate logging outputs from data parallel training.""" 56 | loss_sum = sum(log.get('loss', 0) for log in logging_outputs) 57 | ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) 58 | nsentences = sum(log.get('nsentences', 0) for log in logging_outputs) 59 | sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) 60 | agg_output = { 61 | 'loss': loss_sum / sample_size / math.log(2) if sample_size > 0 else 0., 62 | 'ntokens': ntokens, 63 | 'nsentences': nsentences, 64 | 'sample_size': sample_size, 65 | } 66 | if sample_size != ntokens: 67 | agg_output['nll_loss'] = loss_sum / ntokens / math.log(2) 68 | return agg_output 69 | -------------------------------------------------------------------------------- /fairseq/criterions/fairseq_criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from torch.nn.modules.loss import _Loss 9 | 10 | 11 | class FairseqCriterion(_Loss): 12 | 13 | def __init__(self, args, task): 14 | super().__init__() 15 | self.args = args 16 | self.task = task 17 | self.padding_idx = task.target_dictionary.pad() if task.target_dictionary is not None else -100 18 | 19 | @staticmethod 20 | def add_args(parser): 21 | """Add criterion-specific arguments to the parser.""" 22 | pass 23 | 24 | @classmethod 25 | def build_criterion(cls, args, task): 26 | return cls(args, task) 27 | 28 | def forward(self, model, sample, reduce=True): 29 | """Compute the loss for the given sample. 
30 | 31 | Returns a tuple with three elements: 32 | 1) the loss 33 | 2) the sample size, which is used as the denominator for the gradient 34 | 3) logging outputs to display while training 35 | """ 36 | raise NotImplementedError 37 | 38 | @staticmethod 39 | def aggregate_logging_outputs(logging_outputs): 40 | """Aggregate logging outputs from data parallel training.""" 41 | raise NotImplementedError 42 | 43 | @staticmethod 44 | def grad_denom(sample_sizes): 45 | """Compute the gradient denominator for a set of sample sizes.""" 46 | return sum(sample_sizes) 47 | -------------------------------------------------------------------------------- /fairseq/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from .dictionary import Dictionary, TruncatedDictionary 9 | from .masked_lm_dictionary import BertDictionary, MaskedLMDictionary 10 | 11 | from .fairseq_dataset import FairseqDataset 12 | 13 | from .audio.raw_audio_dataset import RawAudioDataset 14 | from .backtranslation_dataset import BacktranslationDataset 15 | from .block_pair_dataset import BlockPairDataset 16 | from .concat_dataset import ConcatDataset 17 | from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset, MMapIndexedDataset 18 | from .language_pair_dataset import LanguagePairDataset 19 | from .lm_context_window_dataset import LMContextWindowDataset 20 | from .masked_lm_dataset import MaskedLMDataset 21 | from .monolingual_dataset import MonolingualDataset 22 | from .noising import NoisingDataset 23 | from .round_robin_zip_datasets import RoundRobinZipDatasets 24 | from .token_block_dataset import TokenBlockDataset 25 | from .transform_eos_dataset import TransformEosDataset 26 | from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset 27 | 28 | from .iterators import ( 29 | CountingIterator, 30 | EpochBatchIterator, 31 | GroupedIterator, 32 | ShardedIterator, 33 | ) 34 | 35 | __all__ = [ 36 | 'BacktranslationDataset', 37 | 'BertDictionary', 38 | 'BlockPairDataset', 39 | 'ConcatDataset', 40 | 'CountingIterator', 41 | 'Dictionary', 42 | 'EpochBatchIterator', 43 | 'FairseqDataset', 44 | 'GroupedIterator', 45 | 'IndexedCachedDataset', 46 | 'IndexedDataset', 47 | 'IndexedRawTextDataset', 48 | 'LanguagePairDataset', 49 | 'LMContextWindowDataset', 50 | 'MaskedLMDataset', 51 | 'MaskedLMDictionary', 52 | 'MMapIndexedDataset', 53 | 'MonolingualDataset', 54 | 'NoisingDataset', 55 | 'RawAudioDataset', 56 | 'RoundRobinZipDatasets', 57 | 'ShardedIterator', 58 | 'TokenBlockDataset', 59 | 'TransformEosDataset', 60 | 'TransformEosLangPairDataset', 61 | 'TruncatedDictionary', 62 | ] 63 | -------------------------------------------------------------------------------- /fairseq/data/audio/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-tpu/fairseq/7efde2261f78e9e8d20e637e252bdd9977ec9290/fairseq/data/audio/__init__.py -------------------------------------------------------------------------------- /fairseq/data/concat_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 
2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import bisect 9 | 10 | import numpy as np 11 | 12 | from . import FairseqDataset 13 | 14 | 15 | class ConcatDataset(FairseqDataset): 16 | @staticmethod 17 | def cumsum(sequence, sample_ratios): 18 | r, s = [], 0 19 | for e, ratio in zip(sequence, sample_ratios): 20 | curr_len = int(ratio * len(e)) 21 | r.append(curr_len + s) 22 | s += curr_len 23 | return r 24 | 25 | def __init__(self, datasets, sample_ratios=1): 26 | super(ConcatDataset, self).__init__() 27 | assert len(datasets) > 0, "datasets should not be an empty iterable" 28 | self.datasets = list(datasets) 29 | if isinstance(sample_ratios, int): 30 | sample_ratios = [sample_ratios] * len(self.datasets) 31 | self.sample_ratios = sample_ratios 32 | self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios) 33 | self.real_sizes = [len(d) for d in self.datasets] 34 | 35 | def __len__(self): 36 | return self.cumulative_sizes[-1] 37 | 38 | def __getitem__(self, idx): 39 | dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx) 40 | return self.datasets[dataset_idx][sample_idx] 41 | 42 | def _get_dataset_and_sample_index(self, idx: int): 43 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 44 | if dataset_idx == 0: 45 | sample_idx = idx 46 | else: 47 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 48 | sample_idx = sample_idx % self.real_sizes[dataset_idx] 49 | return dataset_idx, sample_idx 50 | 51 | def collater(self, samples): 52 | # For now only supports datasets with same underlying collater implementations 53 | return self.datasets[0].collater(samples) 54 | 55 | def size(self, idx: int): 56 | """ 57 | Return an example's size as a float or tuple. 58 | """ 59 | dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx) 60 | return self.datasets[dataset_idx].size(sample_idx) 61 | 62 | def num_tokens(self, index: int): 63 | return np.max(self.size(index)) 64 | 65 | @property 66 | def sizes(self): 67 | return np.concatenate( 68 | [np.tile(ds.sizes, sr) for ds, sr in zip(self.datasets, self.sample_ratios)] 69 | ) 70 | 71 | @property 72 | def supports_prefetch(self): 73 | return all(d.supports_prefetch for d in self.datasets) 74 | 75 | def ordered_indices(self): 76 | """ 77 | Returns indices sorted by length. So less padding is needed. 78 | """ 79 | return np.argsort(self.sizes) 80 | 81 | def prefetch(self, indices): 82 | frm = 0 83 | for to, ds in zip(self.cumulative_sizes, self.datasets): 84 | real_size = len(ds) 85 | if getattr(ds, 'supports_prefetch', False): 86 | ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to]) 87 | frm = to 88 | -------------------------------------------------------------------------------- /fairseq/data/fairseq_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
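# Note: FairseqDataset below is the abstract base class for every dataset in fairseq.
# Besides the usual __getitem__/__len__, subclasses implement collater() to merge
# samples into a mini-batch, num_tokens() and size() so --max-tokens/--max-positions
# filtering works, and ordered_indices() so batches can be built from length-sorted
# indices with minimal padding. (Descriptive comment added for orientation; the
# interface itself is defined in the code that follows.)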
7 | 8 | import torch.utils.data 9 | 10 | 11 | class FairseqDataset(torch.utils.data.Dataset): 12 | """A dataset that provides helpers for batching.""" 13 | 14 | def __getitem__(self, index): 15 | raise NotImplementedError 16 | 17 | def __len__(self): 18 | raise NotImplementedError 19 | 20 | def collater(self, samples): 21 | """Merge a list of samples to form a mini-batch. 22 | 23 | Args: 24 | samples (List[dict]): samples to collate 25 | 26 | Returns: 27 | dict: a mini-batch suitable for forwarding with a Model 28 | """ 29 | raise NotImplementedError 30 | 31 | def num_tokens(self, index): 32 | """Return the number of tokens in a sample. This value is used to 33 | enforce ``--max-tokens`` during batching.""" 34 | raise NotImplementedError 35 | 36 | def size(self, index): 37 | """Return an example's size as a float or tuple. This value is used when 38 | filtering a dataset with ``--max-positions``.""" 39 | raise NotImplementedError 40 | 41 | def ordered_indices(self): 42 | """Return an ordered list of indices. Batches will be constructed based 43 | on this order.""" 44 | raise NotImplementedError 45 | 46 | @property 47 | def supports_prefetch(self): 48 | """Whether this dataset supports prefetching.""" 49 | return False 50 | 51 | def prefetch(self, indices): 52 | """Prefetch the data required for this epoch.""" 53 | raise NotImplementedError 54 | -------------------------------------------------------------------------------- /fairseq/data/lm_context_window_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from fairseq.data.monolingual_dataset import MonolingualDataset 12 | 13 | from . 
import FairseqDataset 14 | 15 | 16 | class LMContextWindowDataset(FairseqDataset): 17 | """Wraps a MonolingualDataset and provides more context for evaluation.""" 18 | 19 | def __init__(self, dataset, tokens_per_sample, context_window, pad_idx): 20 | assert isinstance(dataset, MonolingualDataset) 21 | assert context_window > 0 22 | self.dataset = dataset 23 | self.tokens_per_sample = tokens_per_sample 24 | self.context_window = context_window 25 | self.pad_idx = pad_idx 26 | self.prev_tokens = np.empty([0]) 27 | 28 | def __getitem__(self, index): 29 | return self.dataset[index] 30 | 31 | def __len__(self): 32 | return len(self.dataset) 33 | 34 | def collater(self, samples): 35 | sample = self.dataset.collater(samples) 36 | 37 | pad = self.pad_idx 38 | max_sample_len = self.tokens_per_sample + self.context_window 39 | 40 | bsz, tsz = sample['net_input']['src_tokens'].shape 41 | start_idxs = [0] * bsz 42 | toks = sample['net_input']['src_tokens'] 43 | lengths = sample['net_input']['src_lengths'] 44 | tgt = sample['target'] 45 | new_toks = np.empty([bsz, tsz + self.context_window], dtype=np.int64) 46 | new_tgt = np.full([bsz, tsz + self.context_window], pad, dtype=np.int64) 47 | sample_lens = toks.ne(pad).long().sum(dim=1).cpu() 48 | for i in range(bsz): 49 | sample_len = sample_lens[i] 50 | extra = len(self.prev_tokens) + sample_len - max_sample_len 51 | if extra > 0: 52 | self.prev_tokens = self.prev_tokens[extra:] 53 | pads = np.full(self.context_window - len(self.prev_tokens), pad) 54 | new_toks[i] = np.concatenate([self.prev_tokens, toks[i].numpy(), pads]) 55 | new_tgt[i, len(self.prev_tokens):len(self.prev_tokens) + len(tgt[i])] = tgt[i] 56 | start_idxs[i] = len(self.prev_tokens) 57 | lengths[i] += len(self.prev_tokens) 58 | self.prev_tokens = new_toks[i][new_toks[i] != pad][-self.context_window:] 59 | sample['net_input']['src_tokens'] = torch.from_numpy(new_toks) 60 | sample['target'] = torch.from_numpy(new_tgt) 61 | sample['start_indices'] = start_idxs 62 | 63 | return sample 64 | 65 | def num_tokens(self, index): 66 | return self.dataset.num_tokens(index) 67 | 68 | def size(self, index): 69 | return self.dataset.size(index) 70 | 71 | def ordered_indices(self): 72 | # NOTE we don't shuffle the data to retain access to the previous dataset elements 73 | return np.arange(len(self.dataset)) 74 | 75 | @property 76 | def supports_prefetch(self): 77 | return getattr(self.dataset, 'supports_prefetch', False) 78 | 79 | def prefetch(self, indices): 80 | return self.dataset.prefetch(indices) 81 | -------------------------------------------------------------------------------- /fairseq/data/masked_lm_dictionary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from fairseq.data import Dictionary 9 | 10 | 11 | class MaskedLMDictionary(Dictionary): 12 | """ 13 | Dictionary for Masked Language Modelling tasks. This extends Dictionary by 14 | adding the mask symbol. 
15 | """ 16 | def __init__( 17 | self, 18 | pad='<pad>', 19 | eos='</s>', 20 | unk='<unk>', 21 | mask='<mask>', 22 | ): 23 | super().__init__(pad, eos, unk) 24 | self.mask_word = mask 25 | self.mask_index = self.add_symbol(mask) 26 | self.nspecial = len(self.symbols) 27 | 28 | def mask(self): 29 | """Helper to get index of mask symbol""" 30 | return self.mask_index 31 | 32 | 33 | class BertDictionary(MaskedLMDictionary): 34 | """ 35 | Dictionary for BERT task. This extends MaskedLMDictionary by adding support 36 | for cls and sep symbols. 37 | """ 38 | def __init__( 39 | self, 40 | pad='<pad>', 41 | eos='</s>', 42 | unk='<unk>', 43 | mask='<mask>', 44 | cls='<cls>', 45 | sep='<sep>' 46 | ): 47 | super().__init__(pad, eos, unk, mask) 48 | self.cls_word = cls 49 | self.sep_word = sep 50 | self.cls_index = self.add_symbol(cls) 51 | self.sep_index = self.add_symbol(sep) 52 | self.nspecial = len(self.symbols) 53 | 54 | def cls(self): 55 | """Helper to get index of cls symbol""" 56 | return self.cls_index 57 | 58 | def sep(self): 59 | """Helper to get index of sep symbol""" 60 | return self.sep_index 61 | -------------------------------------------------------------------------------- /fairseq/data/transform_eos_lang_pair_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | 9 | from . import FairseqDataset 10 | from typing import Optional 11 | 12 | 13 | class TransformEosLangPairDataset(FairseqDataset): 14 | """A :class:`~fairseq.data.FairseqDataset` wrapper that transform bos on 15 | collated samples of language pair dataset. 16 | 17 | Note that the transformation is applied in :func:`collater`. 
18 | 19 | Args: 20 | dataset (~fairseq.data.FairseqDataset): dataset that collates sample into 21 | LanguagePairDataset schema 22 | src_eos (int): original source end-of-sentence symbol index to be replaced 23 | new_src_eos (int, optional): new end-of-sentence symbol index to replace source eos symbol 24 | tgt_bos (int, optional): original target beginning-of-sentence symbol index to be replaced 25 | new_tgt_bos (int, optional): new beginning-of-sentence symbol index to replace at the 26 | beginning of 'prev_output_tokens' 27 | """ 28 | 29 | def __init__( 30 | self, 31 | dataset: FairseqDataset, 32 | src_eos: int, 33 | new_src_eos: Optional[int] = None, 34 | tgt_bos: Optional[int] = None, 35 | new_tgt_bos: Optional[int] = None, 36 | ): 37 | self.dataset = dataset 38 | self.src_eos = src_eos 39 | self.new_src_eos = new_src_eos 40 | self.tgt_bos = tgt_bos 41 | self.new_tgt_bos = new_tgt_bos 42 | 43 | def __getitem__(self, index): 44 | return self.dataset[index] 45 | 46 | def __len__(self): 47 | return len(self.dataset) 48 | 49 | def collater(self, samples): 50 | samples = self.dataset.collater(samples) 51 | 52 | # TODO: support different padding direction 53 | if self.new_src_eos is not None: 54 | assert(samples['net_input']['src_tokens'][:, -1] != self.src_eos).sum() == 0 55 | samples['net_input']['src_tokens'][:, -1] = self.new_src_eos 56 | 57 | if self.new_tgt_bos is not None: 58 | assert (samples['net_input']['prev_output_tokens'][:, 0] != self.tgt_bos).sum() == 0 59 | samples['net_input']['prev_output_tokens'][:, 0] = self.new_tgt_bos 60 | 61 | return samples 62 | 63 | def num_tokens(self, index): 64 | return self.dataset.num_tokens(index) 65 | 66 | def size(self, index): 67 | return self.dataset.size(index) 68 | 69 | def ordered_indices(self): 70 | return self.dataset.ordered_indices() 71 | 72 | @property 73 | def supports_prefetch(self): 74 | return getattr(self.dataset, 'supports_prefetch', False) 75 | 76 | def prefetch(self, indices): 77 | return self.dataset.prefetch(indices) 78 | -------------------------------------------------------------------------------- /fairseq/data/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | 9 | import importlib 10 | import os 11 | 12 | from fairseq import registry 13 | 14 | 15 | build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY = registry.setup_registry( 16 | '--tokenizer', 17 | default=None, 18 | ) 19 | 20 | 21 | build_bpe, register_bpe, BPE_REGISTRY = registry.setup_registry( 22 | '--bpe', 23 | default=None, 24 | ) 25 | 26 | 27 | # automatically import any Python files in the transforms/ directory 28 | for file in os.listdir(os.path.dirname(__file__)): 29 | if file.endswith('.py') and not file.startswith('_'): 30 | module = file[:file.find('.py')] 31 | importlib.import_module('fairseq.data.transforms.' + module) 32 | -------------------------------------------------------------------------------- /fairseq/data/transforms/moses_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from fairseq.data.transforms import register_tokenizer 9 | 10 | 11 | @register_tokenizer('moses') 12 | class MosesTokenizer(object): 13 | 14 | @staticmethod 15 | def add_args(parser): 16 | # fmt: off 17 | parser.add_argument('--moses-source-lang', default='en', metavar='SRC', 18 | help='source language') 19 | parser.add_argument('--moses-target-lang', default='en', metavar='TARGET', 20 | help='target language') 21 | parser.add_argument('--moses-no-dash-splits', action='store_true', default=False, 22 | help='don\'t apply dash split rules') 23 | parser.add_argument('--moses-no-escape', action='store_true', default=False, 24 | help='don\'t perform HTML escaping on apostrophy, quotes, etc.') 25 | # fmt: on 26 | 27 | def __init__(self, args): 28 | self.args = args 29 | try: 30 | from sacremoses import MosesTokenizer, MosesDetokenizer 31 | self.tok = MosesTokenizer(args.moses_source_lang) 32 | self.detok = MosesDetokenizer(args.moses_target_lang) 33 | except ImportError: 34 | raise ImportError('Please install Moses tokenizer with: pip install sacremoses') 35 | 36 | def encode(self, x: str) -> str: 37 | return self.tok.tokenize( 38 | x, 39 | aggressive_dash_splits=(not self.args.moses_no_dash_splits), 40 | return_str=True, 41 | escape=(not self.args.moses_no_escape), 42 | ) 43 | 44 | def decode(self, x: str) -> str: 45 | return self.detok.detokenize(x.split()) 46 | -------------------------------------------------------------------------------- /fairseq/data/transforms/nltk_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from fairseq.data.transforms import register_tokenizer 9 | 10 | 11 | @register_tokenizer('nltk') 12 | class NLTKTokenizer(object): 13 | 14 | def __init__(self, source_lang=None, target_lang=None): 15 | try: 16 | from nltk.tokenize import word_tokenize 17 | self.word_tokenize = word_tokenize 18 | except ImportError: 19 | raise ImportError('Please install nltk with: pip install nltk') 20 | 21 | def encode(self, x: str) -> str: 22 | return ' '.join(self.word_tokenize(x)) 23 | 24 | def decode(self, x: str) -> str: 25 | return x 26 | -------------------------------------------------------------------------------- /fairseq/data/transforms/sentencepiece_bpe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
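# Note: SentencepieceBPE below wraps a trained SentencePiece model. encode() maps raw
# text to space-separated subword pieces and decode() reverses it by stripping the
# U+2581 word-boundary markers. A minimal usage sketch (illustrative only; the model
# path is hypothetical and not part of the original source):
#
#   import argparse
#   args = argparse.Namespace(sentencepiece_vocab='sentencepiece.bpe.model')
#   bpe = SentencepieceBPE(args)
#   pieces = bpe.encode('Hello world!')
#   text = bpe.decode(pieces)  # round-trips back to the original string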
7 | 8 | from fairseq import file_utils 9 | from fairseq.data.transforms import register_bpe 10 | 11 | 12 | @register_bpe('sentencepiece') 13 | class SentencepieceBPE(object): 14 | 15 | @staticmethod 16 | def add_args(parser): 17 | # fmt: off 18 | parser.add_argument('--sentencepiece-vocab', type=str, 19 | help='path to sentencepiece vocab') 20 | # fmt: on 21 | 22 | def __init__(self, args): 23 | vocab = file_utils.cached_path(args.sentencepiece_vocab) 24 | try: 25 | import sentencepiece as spm 26 | self.sp = spm.SentencePieceProcessor() 27 | self.sp.Load(vocab) 28 | except ImportError: 29 | raise ImportError('Please install sentencepiece with: pip install sentencepiece') 30 | 31 | def encode(self, x: str) -> str: 32 | return ' '.join(self.sp.EncodeAsPieces(x)) 33 | 34 | def decode(self, x: str) -> str: 35 | return x.replace(' ', '').replace('\u2581', ' ').strip() 36 | -------------------------------------------------------------------------------- /fairseq/data/transforms/space_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import re 9 | 10 | from fairseq.data.transforms import register_tokenizer 11 | 12 | 13 | @register_tokenizer('space') 14 | class SpaceTokenizer(object): 15 | 16 | def __init__(self, source_lang=None, target_lang=None): 17 | self.space_tok = re.compile(r"\s+") 18 | 19 | def encode(self, x: str) -> str: 20 | return self.space_tok.sub(' ', x) 21 | 22 | def decode(self, x: str) -> str: 23 | return x 24 | -------------------------------------------------------------------------------- /fairseq/data/transforms/subword_nmt_bpe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
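# Note: SubwordNMTBPE below applies BPE codes learned with the subword-nmt package.
# encode() runs apply_bpe over a whitespace-tokenized line, marking word-internal
# splits with the configured separator (default '@@'), and decode() removes those
# '@@ ' continuation markers to recover the original tokens. --bpe-codes must point
# to a codes file produced by subword-nmt's learn_bpe. (Descriptive comment added
# for orientation.)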
7 | 8 | from fairseq import file_utils 9 | from fairseq.data.transforms import register_bpe 10 | 11 | 12 | @register_bpe('subword_nmt') 13 | class SubwordNMTBPE(object): 14 | 15 | @staticmethod 16 | def add_args(parser): 17 | # fmt: off 18 | parser.add_argument('--bpe-codes', type=str, 19 | help='path to subword NMT BPE') 20 | parser.add_argument('--bpe-separator', default='@@', 21 | help='BPE separator') 22 | # fmt: on 23 | 24 | def __init__(self, args): 25 | if args.bpe_codes is None: 26 | raise ValueError('--bpe-codes is required for --bpe=subword_nmt') 27 | codes = file_utils.cached_path(args.bpe_codes) 28 | try: 29 | from subword_nmt import apply_bpe 30 | bpe_parser = apply_bpe.create_parser() 31 | bpe_args = bpe_parser.parse_args([ 32 | '--codes', codes, 33 | '--separator', args.bpe_separator, 34 | ]) 35 | self.bpe = apply_bpe.BPE( 36 | bpe_args.codes, 37 | bpe_args.merges, 38 | bpe_args.separator, 39 | None, 40 | bpe_args.glossaries, 41 | ) 42 | self.bpe_symbol = bpe_args.separator + ' ' 43 | except ImportError: 44 | raise ImportError('Please install subword_nmt with: pip install subword-nmt') 45 | 46 | def encode(self, x: str) -> str: 47 | return self.bpe.process_line(x) 48 | 49 | def decode(self, x: str) -> str: 50 | return (x + ' ').replace(self.bpe_symbol, '').rstrip() 51 | -------------------------------------------------------------------------------- /fairseq/meters.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import time 9 | 10 | 11 | class AverageMeter(object): 12 | """Computes and stores the average and current value""" 13 | def __init__(self): 14 | self.reset() 15 | 16 | def reset(self): 17 | self.val = 0 18 | self.avg = 0 19 | self.sum = 0 20 | self.count = 0 21 | 22 | def update(self, val, n=1): 23 | self.val = val 24 | self.sum += val * n 25 | self.count += n 26 | self.avg = self.sum / self.count 27 | 28 | 29 | class TimeMeter(object): 30 | """Computes the average occurrence of some event per second""" 31 | def __init__(self, init=0): 32 | self.reset(init) 33 | 34 | def reset(self, init=0): 35 | self.init = init 36 | self.start = time.time() 37 | self.n = 0 38 | 39 | def update(self, val=1): 40 | self.n += val 41 | 42 | @property 43 | def avg(self): 44 | return self.n / self.elapsed_time 45 | 46 | @property 47 | def elapsed_time(self): 48 | return self.init + (time.time() - self.start) 49 | 50 | 51 | class StopwatchMeter(object): 52 | """Computes the sum/avg duration of some event in seconds""" 53 | def __init__(self): 54 | self.reset() 55 | 56 | def start(self): 57 | self.start_time = time.time() 58 | 59 | def stop(self, n=1): 60 | if self.start_time is not None: 61 | delta = time.time() - self.start_time 62 | self.sum += delta 63 | self.n += n 64 | self.start_time = None 65 | 66 | def reset(self): 67 | self.sum = 0 68 | self.n = 0 69 | self.start_time = None 70 | 71 | @property 72 | def avg(self): 73 | return self.sum / self.n 74 | -------------------------------------------------------------------------------- /fairseq/models/composite_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from fairseq.models import FairseqEncoder 9 | 10 | 11 | class CompositeEncoder(FairseqEncoder): 12 | """ 13 | A wrapper around a dictionary of :class:`FairseqEncoder` objects. 14 | 15 | We run forward on each encoder and return a dictionary of outputs. The first 16 | encoder's dictionary is used for initialization. 17 | 18 | Args: 19 | encoders (dict): a dictionary of :class:`FairseqEncoder` objects. 20 | """ 21 | 22 | def __init__(self, encoders): 23 | super().__init__(next(iter(encoders.values())).dictionary) 24 | self.encoders = encoders 25 | for key in self.encoders: 26 | self.add_module(key, self.encoders[key]) 27 | 28 | def forward(self, src_tokens, src_lengths): 29 | """ 30 | Args: 31 | src_tokens (LongTensor): tokens in the source language of shape 32 | `(batch, src_len)` 33 | src_lengths (LongTensor): lengths of each source sentence of shape 34 | `(batch)` 35 | 36 | Returns: 37 | dict: 38 | the outputs from each Encoder 39 | """ 40 | encoder_out = {} 41 | for key in self.encoders: 42 | encoder_out[key] = self.encoders[key](src_tokens, src_lengths) 43 | return encoder_out 44 | 45 | def reorder_encoder_out(self, encoder_out, new_order): 46 | """Reorder encoder output according to new_order.""" 47 | for key in self.encoders: 48 | encoder_out[key] = self.encoders[key].reorder_encoder_out(encoder_out[key], new_order) 49 | return encoder_out 50 | 51 | def max_positions(self): 52 | return min([self.encoders[key].max_positions() for key in self.encoders]) 53 | 54 | def upgrade_state_dict(self, state_dict): 55 | for key in self.encoders: 56 | self.encoders[key].upgrade_state_dict(state_dict) 57 | return state_dict 58 | -------------------------------------------------------------------------------- /fairseq/models/distributed_fairseq_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import inspect 9 | 10 | from torch.nn import parallel 11 | 12 | from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel 13 | from fairseq.models import BaseFairseqModel 14 | 15 | 16 | def DistributedFairseqModel(args, model): 17 | """ 18 | Wrap a *model* to support distributed data parallel training. 19 | 20 | This is similar to the built-in DistributedDataParallel, but allows 21 | additional configuration of the DistributedDataParallel class to 22 | use, and also provides easier access to the wrapped model by 23 | forwarding requests for missing attributes to the wrapped model. 
24 | 25 | Args: 26 | args (argparse.Namespace): fairseq args 27 | model (BaseFairseqModel): model to wrap 28 | """ 29 | # determine which DDP class to extend 30 | assert isinstance(model, BaseFairseqModel) 31 | if args.ddp_backend == 'c10d': 32 | ddp_class = parallel.DistributedDataParallel 33 | init_kwargs = dict( 34 | module=model, 35 | device_ids=[args.device_id], 36 | output_device=args.device_id, 37 | broadcast_buffers=False, 38 | bucket_cap_mb=args.bucket_cap_mb, 39 | ) 40 | # Maintain backward compatibility 41 | if 'check_reduction' in inspect.getargspec(ddp_class)[0]: 42 | init_kwargs['check_reduction'] = True 43 | if 'find_unused_parameters' in inspect.getargspec(ddp_class)[0]: 44 | init_kwargs['find_unused_parameters'] = args.find_unused_parameters 45 | elif args.ddp_backend == 'no_c10d': 46 | ddp_class = LegacyDistributedDataParallel 47 | init_kwargs = dict( 48 | module=model, 49 | world_size=args.distributed_world_size, 50 | buffer_size=2**28, 51 | ) 52 | else: 53 | raise ValueError('Unknown --ddp-backend: ' + args.ddp_backend) 54 | 55 | class _DistributedFairseqModel(ddp_class): 56 | """Extend DistributedDataParallel to check for missing 57 | attributes in the wrapped module.""" 58 | 59 | def __init__(self, *args, **kwargs): 60 | super().__init__(*args, **kwargs) 61 | 62 | def __getattr__(self, name): 63 | wrapped_module = super().__getattr__('module') 64 | if hasattr(wrapped_module, name): 65 | return getattr(wrapped_module, name) 66 | return super().__getattr__(name) 67 | 68 | return _DistributedFairseqModel(**init_kwargs) 69 | -------------------------------------------------------------------------------- /fairseq/models/fairseq_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.nn as nn 9 | 10 | from fairseq import utils 11 | 12 | 13 | class FairseqDecoder(nn.Module): 14 | """Base class for decoders.""" 15 | 16 | def __init__(self, dictionary): 17 | super().__init__() 18 | self.dictionary = dictionary 19 | self.onnx_trace = False 20 | 21 | def forward(self, prev_output_tokens, encoder_out=None, **kwargs): 22 | """ 23 | Args: 24 | prev_output_tokens (LongTensor): shifted output tokens of shape 25 | `(batch, tgt_len)`, for input feeding/teacher forcing 26 | encoder_out (dict, optional): output from the encoder, used for 27 | encoder-side attention 28 | 29 | Returns: 30 | tuple: 31 | - the decoder's output of shape `(batch, tgt_len, vocab)` 32 | - a dictionary with any model-specific outputs 33 | """ 34 | x, extra = self.extract_features(prev_output_tokens, encoder_out=encoder_out, **kwargs) 35 | x = self.output_layer(x) 36 | return x, extra 37 | 38 | def extract_features(self, prev_output_tokens, encoder_out=None, **kwargs): 39 | """ 40 | Returns: 41 | tuple: 42 | - the decoder's features of shape `(batch, tgt_len, embed_dim)` 43 | - a dictionary with any model-specific outputs 44 | """ 45 | raise NotImplementedError 46 | 47 | def output_layer(self, features, **kwargs): 48 | """ 49 | Project features to the default output size, e.g., vocabulary size. 50 | 51 | Args: 52 | features (Tensor): features returned by *extract_features*. 
53 | """ 54 | raise NotImplementedError 55 | 56 | def get_normalized_probs(self, net_output, log_probs, sample): 57 | """Get normalized probabilities (or log probs) from a net's output.""" 58 | 59 | if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None: 60 | if sample is not None: 61 | assert 'target' in sample 62 | target = sample['target'] 63 | else: 64 | target = None 65 | out = self.adaptive_softmax.get_log_prob(net_output[0], target=target) 66 | return out.exp_() if not log_probs else out 67 | 68 | logits = net_output[0] 69 | if log_probs: 70 | return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace) 71 | else: 72 | return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace) 73 | 74 | def max_positions(self): 75 | """Maximum input length supported by the decoder.""" 76 | return 1e6 # an arbitrary large number 77 | 78 | def upgrade_state_dict(self, state_dict): 79 | """Upgrade a (possibly old) state dict for new versions of fairseq.""" 80 | return state_dict 81 | 82 | def prepare_for_onnx_export_(self): 83 | self.onnx_trace = True 84 | -------------------------------------------------------------------------------- /fairseq/models/fairseq_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.nn as nn 9 | 10 | 11 | class FairseqEncoder(nn.Module): 12 | """Base class for encoders.""" 13 | 14 | def __init__(self, dictionary): 15 | super().__init__() 16 | self.dictionary = dictionary 17 | 18 | def forward(self, src_tokens, src_lengths=None, **kwargs): 19 | """ 20 | Args: 21 | src_tokens (LongTensor): tokens in the source language of shape 22 | `(batch, src_len)` 23 | src_lengths (LongTensor): lengths of each source sentence of shape 24 | `(batch)` 25 | """ 26 | raise NotImplementedError 27 | 28 | def reorder_encoder_out(self, encoder_out, new_order): 29 | """ 30 | Reorder encoder output according to `new_order`. 31 | 32 | Args: 33 | encoder_out: output from the ``forward()`` method 34 | new_order (LongTensor): desired order 35 | 36 | Returns: 37 | `encoder_out` rearranged according to `new_order` 38 | """ 39 | raise NotImplementedError 40 | 41 | def max_positions(self): 42 | """Maximum input length supported by the encoder.""" 43 | return 1e6 # an arbitrary large number 44 | 45 | def upgrade_state_dict(self, state_dict): 46 | """Upgrade a (possibly old) state dict for new versions of fairseq.""" 47 | return state_dict 48 | -------------------------------------------------------------------------------- /fairseq/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
7 | 8 | from .adaptive_input import AdaptiveInput 9 | from .adaptive_softmax import AdaptiveSoftmax 10 | from .beamable_mm import BeamableMM 11 | from .character_token_embedder import CharacterTokenEmbedder 12 | from .conv_tbc import ConvTBC 13 | from .downsampled_multihead_attention import DownsampledMultiHeadAttention 14 | from .dynamic_convolution import DynamicConv1dTBC 15 | from .gelu import gelu, gelu_accurate 16 | from .grad_multiply import GradMultiply 17 | from .highway import Highway 18 | from .layer_norm import LayerNorm 19 | from .learned_positional_embedding import LearnedPositionalEmbedding 20 | from .lightweight_convolution import LightweightConv1dTBC 21 | from .linearized_convolution import LinearizedConvolution 22 | from .logsumexp_moe import LogSumExpMoE 23 | from .mean_pool_gating_network import MeanPoolGatingNetwork 24 | from .multihead_attention import MultiheadAttention 25 | from .positional_embedding import PositionalEmbedding 26 | from .scalar_bias import ScalarBias 27 | from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding 28 | from .transformer_sentence_encoder_layer import TransformerSentenceEncoderLayer 29 | from .transformer_sentence_encoder import TransformerSentenceEncoder 30 | from .unfold import unfold1d 31 | 32 | __all__ = [ 33 | 'AdaptiveInput', 34 | 'AdaptiveSoftmax', 35 | 'BeamableMM', 36 | 'CharacterTokenEmbedder', 37 | 'ConvTBC', 38 | 'DownsampledMultiHeadAttention', 39 | 'DynamicConv1dTBC', 40 | 'gelu', 41 | 'gelu_accurate', 42 | 'GradMultiply', 43 | 'Highway', 44 | 'LayerNorm', 45 | 'LearnedPositionalEmbedding', 46 | 'LightweightConv1dTBC', 47 | 'LinearizedConvolution', 48 | 'LogSumExpMoE', 49 | 'MeanPoolGatingNetwork', 50 | 'MultiheadAttention', 51 | 'PositionalEmbedding', 52 | 'ScalarBias', 53 | 'SinusoidalPositionalEmbedding', 54 | 'TransformerSentenceEncoderLayer', 55 | 'TransformerSentenceEncoder', 56 | 'unfold1d', 57 | ] 58 | -------------------------------------------------------------------------------- /fairseq/modules/adaptive_input.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
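# Note: AdaptiveInput below implements adaptive input embeddings: the vocabulary is
# split into frequency bands at the given cutoff points, band i gets an embedding of
# dimension initial_dim // factor**i, and a per-band linear layer projects everything
# back to a common output_dim. A minimal construction sketch (illustrative sizes only,
# not taken from any fairseq config):
#
#   emb = AdaptiveInput(vocab_size=50000, padding_idx=1, initial_dim=1024,
#                       factor=4.0, output_dim=1024, cutoff=[20000, 40000])
#   # ids < 20000 use 1024-dim embeddings, ids in [20000, 40000) use 256-dim
#   # embeddings, and the remaining ids use 64-dim embeddings.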
7 | 8 | 9 | import torch 10 | from torch import nn 11 | 12 | from typing import List 13 | 14 | 15 | class AdaptiveInput(nn.Module): 16 | 17 | def __init__( 18 | self, 19 | vocab_size: int, 20 | padding_idx: int, 21 | initial_dim: int, 22 | factor: float, 23 | output_dim: int, 24 | cutoff: List[int], 25 | ): 26 | super().__init__() 27 | 28 | if vocab_size > cutoff[-1]: 29 | cutoff = cutoff + [vocab_size] 30 | else: 31 | assert vocab_size == cutoff[ 32 | -1], 'cannot specify cutoff larger than vocab size' 33 | 34 | self.cutoff = cutoff 35 | self.embedding_dim = output_dim 36 | self.padding_idx = padding_idx 37 | 38 | self.embeddings = nn.ModuleList() 39 | for i in range(len(self.cutoff)): 40 | prev = self.cutoff[i - 1] if i > 0 else 0 41 | size = self.cutoff[i] - prev 42 | dim = int(initial_dim // (factor ** i)) 43 | seq = nn.Sequential( 44 | nn.Embedding(size, dim, padding_idx), 45 | nn.Linear(dim, output_dim, bias=False) 46 | ) 47 | self.embeddings.append(seq) 48 | 49 | def init_weights(m): 50 | if isinstance(m, nn.Embedding): 51 | nn.init.normal_(m.weight, mean=0, std=m.weight.shape[1] ** -0.5) 52 | nn.init.constant_(m.weight[padding_idx], 0) 53 | elif hasattr(m, 'weight'): 54 | nn.init.xavier_uniform_(m.weight) 55 | 56 | self.apply(init_weights) 57 | 58 | self.register_buffer('_float_tensor', torch.FloatTensor(1)) 59 | 60 | def weights_for_band(self, band: int): 61 | return self.embeddings[band][0].weight, self.embeddings[band][1].weight 62 | 63 | def forward(self, input: torch.Tensor): 64 | result = self._float_tensor.new(input.shape + (self.embedding_dim,)) 65 | for i in range(len(self.cutoff)): 66 | mask = input.lt(self.cutoff[i]) 67 | if i > 0: 68 | mask.mul_(input.ge(self.cutoff[i - 1])) 69 | chunk_input = input[mask] - self.cutoff[i - 1] 70 | else: 71 | chunk_input = input[mask] 72 | if mask.any(): 73 | result[mask] = self.embeddings[i](chunk_input) 74 | return result 75 | -------------------------------------------------------------------------------- /fairseq/modules/beamable_mm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | class BeamableMM(nn.Module): 13 | """This module provides an optimized MM for beam decoding with attention. 14 | 15 | It leverage the fact that the source-side of the input is replicated beam 16 | times and the target-side of the input is of width one. This layer speeds up 17 | inference by replacing the inputs {(bsz x 1 x nhu), (bsz x sz2 x nhu)} 18 | with smaller inputs {(bsz/beam x beam x nhu), (bsz/beam x sz2 x nhu)}. 
19 | """ 20 | def __init__(self, beam_size=None): 21 | super(BeamableMM, self).__init__() 22 | self.beam_size = beam_size 23 | 24 | def forward(self, input1, input2): 25 | if ( 26 | not self.training and # test mode 27 | self.beam_size is not None and # beam size is set 28 | input1.dim() == 3 and # only support batched input 29 | input1.size(1) == 1 # single time step update 30 | ): 31 | bsz, beam = input1.size(0), self.beam_size 32 | 33 | # bsz x 1 x nhu --> bsz/beam x beam x nhu 34 | input1 = input1[:, 0, :].unfold(0, beam, beam).transpose(2, 1) 35 | 36 | # bsz x sz2 x nhu --> bsz/beam x sz2 x nhu 37 | input2 = input2.unfold(0, beam, beam)[:, :, :, 0] 38 | 39 | # use non batched operation if bsz = beam 40 | if input1.size(0) == 1: 41 | output = torch.mm(input1[0, :, :], input2[0, :, :]) 42 | else: 43 | output = input1.bmm(input2) 44 | return output.view(bsz, 1, -1) 45 | else: 46 | return input1.bmm(input2) 47 | 48 | def set_beam_size(self, beam_size): 49 | self.beam_size = beam_size 50 | -------------------------------------------------------------------------------- /fairseq/modules/conv_tbc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | from torch.nn.modules.utils import _single 10 | 11 | 12 | class ConvTBC(torch.nn.Module): 13 | """1D convolution over an input of shape (time x batch x channel) 14 | 15 | The implementation uses gemm to perform the convolution. This implementation 16 | is faster than cuDNN for small kernel sizes. 17 | """ 18 | def __init__(self, in_channels, out_channels, kernel_size, padding=0): 19 | super(ConvTBC, self).__init__() 20 | self.in_channels = in_channels 21 | self.out_channels = out_channels 22 | self.kernel_size = _single(kernel_size) 23 | self.padding = _single(padding) 24 | 25 | self.weight = torch.nn.Parameter(torch.Tensor( 26 | self.kernel_size[0], in_channels, out_channels)) 27 | self.bias = torch.nn.Parameter(torch.Tensor(out_channels)) 28 | 29 | def forward(self, input): 30 | return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding[0]) 31 | 32 | def __repr__(self): 33 | s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}' 34 | ', padding={padding}') 35 | if self.bias is None: 36 | s += ', bias=False' 37 | s += ')' 38 | return s.format(name=self.__class__.__name__, **self.__dict__) 39 | -------------------------------------------------------------------------------- /fairseq/modules/gelu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
7 | """ 8 | See "Gaussian Error Linear Units (GELUs)" by Dan Hendrycks and Kevin Gimpel with 9 | the corresponding GitHub repo: https://github.com/hendrycks/GELUs 10 | """ 11 | 12 | import math 13 | 14 | import torch 15 | 16 | 17 | def gelu_accurate(x): 18 | if not hasattr(gelu_accurate, "_a"): 19 | gelu_accurate._a = math.sqrt(2 / math.pi) 20 | return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) 21 | 22 | 23 | def gelu(x: torch.Tensor) -> torch.Tensor: 24 | if hasattr(torch.nn.functional, 'gelu'): 25 | return torch.nn.functional.gelu(x.float()).type_as(x) 26 | else: 27 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 28 | -------------------------------------------------------------------------------- /fairseq/modules/grad_multiply.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | 10 | 11 | class GradMultiply(torch.autograd.Function): 12 | @staticmethod 13 | def forward(ctx, x, scale): 14 | ctx.scale = scale 15 | res = x.new(x) 16 | return res 17 | 18 | @staticmethod 19 | def backward(ctx, grad): 20 | return grad * ctx.scale, None 21 | -------------------------------------------------------------------------------- /fairseq/modules/highway.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | 10 | from torch import nn 11 | 12 | 13 | class Highway(torch.nn.Module): 14 | """ 15 | A `Highway layer <https://arxiv.org/abs/1505.00387>`_. 16 | Adopted from the AllenNLP implementation. 17 | """ 18 | 19 | def __init__( 20 | self, 21 | input_dim: int, 22 | num_layers: int = 1 23 | ): 24 | super(Highway, self).__init__() 25 | self.input_dim = input_dim 26 | self.layers = nn.ModuleList([nn.Linear(input_dim, input_dim * 2) 27 | for _ in range(num_layers)]) 28 | self.activation = nn.ReLU() 29 | 30 | self.reset_parameters() 31 | 32 | def reset_parameters(self): 33 | for layer in self.layers: 34 | # As per comment in AllenNLP: 35 | # We should bias the highway layer to just carry its input forward. We do that by 36 | # setting the bias on `B(x)` to be positive, because that means `g` will be biased to 37 | # be high, so we will carry the input forward. The bias on `B(x)` is the second half 38 | # of the bias vector in each Linear layer. 
39 | nn.init.constant_(layer.bias[self.input_dim:], 1) 40 | 41 | nn.init.constant_(layer.bias[:self.input_dim], 0) 42 | nn.init.xavier_normal_(layer.weight) 43 | 44 | def forward( 45 | self, 46 | x: torch.Tensor 47 | ): 48 | for layer in self.layers: 49 | projection = layer(x) 50 | proj_x, gate = projection.chunk(2, dim=-1) 51 | proj_x = self.activation(proj_x) 52 | gate = torch.sigmoid(gate) 53 | x = gate * x + (gate.new_tensor([1]) - gate) * proj_x 54 | return x 55 | -------------------------------------------------------------------------------- /fairseq/modules/layer_norm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | 10 | 11 | def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False): 12 | if not export and torch.cuda.is_available(): 13 | try: 14 | from apex.normalization import FusedLayerNorm 15 | return FusedLayerNorm(normalized_shape, eps, elementwise_affine) 16 | except ImportError: 17 | pass 18 | return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) 19 | -------------------------------------------------------------------------------- /fairseq/modules/learned_positional_embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.nn as nn 9 | 10 | from fairseq import utils 11 | 12 | 13 | class LearnedPositionalEmbedding(nn.Embedding): 14 | """ 15 | This module learns positional embeddings up to a fixed maximum size. 16 | Padding ids are ignored by either offsetting based on padding_idx 17 | or by setting padding_idx to None and ensuring that the appropriate 18 | position ids are passed to the forward function. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | num_embeddings: int, 24 | embedding_dim: int, 25 | padding_idx: int, 26 | ): 27 | super().__init__(num_embeddings, embedding_dim, padding_idx) 28 | self.onnx_trace = False 29 | 30 | def forward(self, input, incremental_state=None, positions=None): 31 | """Input is expected to be of size [bsz x seqlen].""" 32 | assert ( 33 | (positions is None) or (self.padding_idx is None) 34 | ), "If positions is pre-computed then padding_idx should not be set." 
35 | 36 | if positions is None: 37 | if incremental_state is not None: 38 | # positions is the same for every token when decoding a single step 39 | # Without the int() cast, it doesn't work in some cases when exporting to ONNX 40 | positions = input.data.new(1, 1).fill_(int(self.padding_idx + input.size(1))) 41 | else: 42 | positions = utils.make_positions( 43 | input.data, self.padding_idx, onnx_trace=self.onnx_trace, 44 | ) 45 | return super().forward(positions) 46 | 47 | def max_positions(self): 48 | """Maximum number of supported positions.""" 49 | if self.padding_idx is not None: 50 | return self.num_embeddings - self.padding_idx - 1 51 | else: 52 | return self.num_embeddings 53 | -------------------------------------------------------------------------------- /fairseq/modules/logsumexp_moe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | 10 | 11 | class LogSumExpMoE(torch.autograd.Function): 12 | """Standard LogSumExp forward pass, but use *posterior* for the backward. 13 | 14 | See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade" 15 | (Shen et al., 2019) <https://arxiv.org/abs/1902.07816>`_. 16 | """ 17 | 18 | @staticmethod 19 | def forward(ctx, logp, posterior, dim=-1): 20 | ctx.save_for_backward(posterior) 21 | ctx.dim = dim 22 | return torch.logsumexp(logp, dim=dim) 23 | 24 | @staticmethod 25 | def backward(ctx, grad_output): 26 | posterior, = ctx.saved_tensors 27 | grad_logp = grad_output.unsqueeze(ctx.dim) * posterior 28 | return grad_logp, None, None 29 | -------------------------------------------------------------------------------- /fairseq/modules/mean_pool_gating_network.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | 12 | class MeanPoolGatingNetwork(torch.nn.Module): 13 | """A simple mean-pooling gating network for selecting experts. 14 | 15 | This module applies mean pooling over an encoder's output and returns 16 | responsibilities for each expert. The encoder format is expected to match 17 | :class:`fairseq.models.transformer.TransformerEncoder`.
18 | """ 19 | 20 | def __init__(self, embed_dim, num_experts, dropout=None): 21 | super().__init__() 22 | self.embed_dim = embed_dim 23 | self.num_experts = num_experts 24 | 25 | self.fc1 = torch.nn.Linear(embed_dim, embed_dim) 26 | self.dropout = torch.nn.Dropout(dropout) if dropout is not None else None 27 | self.fc2 = torch.nn.Linear(embed_dim, num_experts) 28 | 29 | def forward(self, encoder_out): 30 | if not ( 31 | isinstance(encoder_out, dict) 32 | and 'encoder_out' in encoder_out 33 | and 'encoder_padding_mask' in encoder_out 34 | and encoder_out['encoder_out'].size(2) == self.embed_dim 35 | ): 36 | raise ValueError('Unexpected format for encoder_out') 37 | 38 | # mean pooling over time 39 | encoder_padding_mask = encoder_out['encoder_padding_mask'] # B x T 40 | encoder_out = encoder_out['encoder_out'].transpose(0, 1) # B x T x C 41 | if encoder_padding_mask is not None: 42 | encoder_out = encoder_out.clone() # required because of transpose above 43 | encoder_out[encoder_padding_mask] = 0 44 | ntokens = torch.sum(1 - encoder_padding_mask, dim=1, keepdim=True) 45 | x = torch.sum(encoder_out, dim=1) / ntokens.type_as(encoder_out) 46 | else: 47 | x = torch.mean(encoder_out, dim=1) 48 | 49 | x = torch.tanh(self.fc1(x)) 50 | if self.dropout is not None: 51 | x = self.dropout(x) 52 | x = self.fc2(x) 53 | return F.log_softmax(x, dim=-1, dtype=torch.float32).type_as(x) 54 | -------------------------------------------------------------------------------- /fairseq/modules/positional_embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.nn as nn 9 | 10 | from .learned_positional_embedding import LearnedPositionalEmbedding 11 | from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding 12 | 13 | 14 | def PositionalEmbedding( 15 | num_embeddings: int, 16 | embedding_dim: int, 17 | padding_idx: int, 18 | learned: bool = False, 19 | ): 20 | if learned: 21 | # if padding_idx is specified then offset the embedding ids by 22 | # this index and adjust num_embeddings appropriately 23 | # TODO: The right place for this offset would be inside 24 | # LearnedPositionalEmbedding. Move this there for a cleaner implementation. 25 | if padding_idx is not None: 26 | num_embeddings = num_embeddings + padding_idx + 1 27 | m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx) 28 | nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) 29 | if padding_idx is not None: 30 | nn.init.constant_(m.weight[padding_idx], 0) 31 | else: 32 | m = SinusoidalPositionalEmbedding( 33 | embedding_dim, padding_idx, init_size=num_embeddings + padding_idx + 1, 34 | ) 35 | return m 36 | -------------------------------------------------------------------------------- /fairseq/modules/scalar_bias.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
7 | # 8 | 9 | import torch 10 | 11 | 12 | class ScalarBias(torch.autograd.Function): 13 | """ 14 | Adds a vector of scalars, used in self-attention mechanism to allow 15 | the model to optionally attend to this vector instead of the past 16 | """ 17 | 18 | @staticmethod 19 | def forward(ctx, input, dim, bias_init): 20 | size = list(input.size()) 21 | size[dim] += 1 22 | output = input.new(*size).fill_(bias_init) 23 | output.narrow(dim, 1, size[dim] - 1).copy_(input) 24 | ctx.dim = dim 25 | return output 26 | 27 | @staticmethod 28 | def backward(ctx, grad): 29 | return grad.narrow(ctx.dim, 1, grad.size(ctx.dim) - 1), None, None 30 | 31 | 32 | def scalar_bias(input, dim, bias_init=0): 33 | return ScalarBias.apply(input, dim, bias_init) 34 | -------------------------------------------------------------------------------- /fairseq/modules/transformer_sentence_encoder_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from fairseq import utils 13 | from fairseq.modules import ( 14 | LayerNorm, 15 | MultiheadAttention, 16 | ) 17 | 18 | 19 | class TransformerSentenceEncoderLayer(nn.Module): 20 | """ 21 | Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained 22 | models. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | embedding_dim: float = 768, 28 | ffn_embedding_dim: float = 3072, 29 | num_attention_heads: float = 8, 30 | dropout: float = 0.1, 31 | attention_dropout: float = 0.1, 32 | activation_dropout: float = 0.1, 33 | activation_fn: str = 'relu', 34 | add_bias_kv: bool = False, 35 | add_zero_attn: bool = False, 36 | export: bool = False, 37 | ) -> None: 38 | 39 | super().__init__() 40 | # Initialize parameters 41 | self.embedding_dim = embedding_dim 42 | self.dropout = dropout 43 | self.activation_dropout = activation_dropout 44 | 45 | # Initialize blocks 46 | self.activation_fn = utils.get_activation_fn(activation_fn) 47 | self.self_attn = MultiheadAttention( 48 | self.embedding_dim, 49 | num_attention_heads, 50 | dropout=attention_dropout, 51 | add_bias_kv=add_bias_kv, 52 | add_zero_attn=add_zero_attn, 53 | self_attention=True 54 | ) 55 | 56 | # layer norm associated with the self attention layer 57 | self.self_attn_layer_norm = LayerNorm(self.embedding_dim, export=export) 58 | self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) 59 | self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) 60 | 61 | # layer norm associated with the position wise feed-forward NN 62 | self.final_layer_norm = LayerNorm(self.embedding_dim, export=export) 63 | 64 | def forward( 65 | self, 66 | x: torch.Tensor, 67 | self_attn_mask: torch.Tensor = None, 68 | self_attn_padding_mask: torch.Tensor = None, 69 | ): 70 | """ 71 | LayerNorm is applied either before or after the self-attention/ffn 72 | modules similar to the original Transformer implementation.
73 | """ 74 | residual = x 75 | x, attn = self.self_attn( 76 | query=x, 77 | key=x, 78 | value=x, 79 | key_padding_mask=self_attn_padding_mask, 80 | need_weights=False, 81 | attn_mask=self_attn_mask, 82 | ) 83 | x = F.dropout(x, p=self.dropout, training=self.training) 84 | x = residual + x 85 | x = self.self_attn_layer_norm(x) 86 | 87 | residual = x 88 | x = self.activation_fn(self.fc1(x)) 89 | x = F.dropout(x, p=self.activation_dropout, training=self.training) 90 | x = self.fc2(x) 91 | x = F.dropout(x, p=self.dropout, training=self.training) 92 | x = residual + x 93 | x = self.final_layer_norm(x) 94 | return x, attn 95 | -------------------------------------------------------------------------------- /fairseq/modules/unfold.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.nn.functional as F 9 | 10 | 11 | def unfold1d(x, kernel_size, padding_l, pad_value=0): 12 | '''unfold T x B x C to T x B x C x K''' 13 | if kernel_size > 1: 14 | T, B, C = x.size() 15 | x = F.pad(x, (0, 0, 0, 0, padding_l, kernel_size - 1 - padding_l), value=pad_value) 16 | x = x.as_strided((T, B, C, kernel_size), (B*C, C, 1, B*C)) 17 | else: 18 | x = x.unsqueeze(3) 19 | return x 20 | -------------------------------------------------------------------------------- /fairseq/optim/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import importlib 9 | import os 10 | 11 | from fairseq import registry 12 | from fairseq.optim.fairseq_optimizer import FairseqOptimizer 13 | from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer 14 | from fairseq.optim.bmuf import FairseqBMUF 15 | 16 | 17 | __all__ = [ 18 | 'FairseqOptimizer', 19 | 'FP16Optimizer', 20 | 'MemoryEfficientFP16Optimizer', 21 | ] 22 | 23 | 24 | _build_optimizer, register_optimizer, OPTIMIZER_REGISTRY = registry.setup_registry( 25 | '--optimizer', 26 | base_class=FairseqOptimizer, 27 | default='nag', 28 | ) 29 | 30 | 31 | def build_optimizer(args, params, *extra_args, **extra_kwargs): 32 | params = list(filter(lambda p: p.requires_grad, params)) 33 | return _build_optimizer(args, params, *extra_args, **extra_kwargs) 34 | 35 | 36 | # automatically import any Python files in the optim/ directory 37 | for file in os.listdir(os.path.dirname(__file__)): 38 | if file.endswith('.py') and not file.startswith('_'): 39 | module = file[:file.find('.py')] 40 | importlib.import_module('fairseq.optim.' + module) 41 | -------------------------------------------------------------------------------- /fairseq/optim/adadelta.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. 
An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.optim 9 | 10 | from . import FairseqOptimizer, register_optimizer 11 | 12 | 13 | @register_optimizer('adadelta') 14 | class Adadelta(FairseqOptimizer): 15 | def __init__(self, args, params): 16 | super().__init__(args, params) 17 | self._optimizer = torch.optim.Adadelta(params, **self.optimizer_config) 18 | 19 | @staticmethod 20 | def add_args(parser): 21 | """Add optimizer-specific arguments to the parser.""" 22 | # fmt: off 23 | parser.add_argument('--adadelta-rho', type=float, default=0.9, metavar='RHO', 24 | help='coefficient used for computing a running average of squared gradients') 25 | parser.add_argument('--adadelta-eps', type=float, default=1e-6, metavar='EPS', 26 | help='term added to the denominator to improve numerical stability') 27 | parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', 28 | help='weight decay') 29 | parser.add_argument('--anneal-eps', action='store_true', help='flag to anneal eps') 30 | # fmt: on 31 | 32 | @property 33 | def optimizer_config(self): 34 | """ 35 | Return a kwarg dictionary that will be used to override optimizer 36 | args stored in checkpoints. This allows us to load a checkpoint and 37 | resume training using a different set of optimizer args, e.g., with a 38 | different learning rate. 39 | """ 40 | return { 41 | 'lr': self.args.lr[0], 42 | 'rho': self.args.adadelta_rho, 43 | 'eps': self.args.adadelta_eps, 44 | 'weight_decay': self.args.weight_decay, 45 | } 46 | -------------------------------------------------------------------------------- /fairseq/optim/adagrad.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.optim 9 | 10 | from . import FairseqOptimizer, register_optimizer 11 | 12 | 13 | @register_optimizer('adagrad') 14 | class Adagrad(FairseqOptimizer): 15 | def __init__(self, args, params): 16 | super().__init__(args, params) 17 | self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config) 18 | 19 | @staticmethod 20 | def add_args(parser): 21 | """Add optimizer-specific arguments to the parser.""" 22 | # fmt: off 23 | parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', 24 | help='weight decay') 25 | # fmt: on 26 | 27 | @property 28 | def optimizer_config(self): 29 | """ 30 | Return a kwarg dictionary that will be used to override optimizer 31 | args stored in checkpoints. This allows us to load a checkpoint and 32 | resume training using a different set of optimizer args, e.g., with a 33 | different learning rate. 34 | """ 35 | return { 36 | 'lr': self.args.lr[0], 37 | 'weight_decay': self.args.weight_decay, 38 | } 39 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. 
An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import importlib 9 | import os 10 | 11 | from fairseq import registry 12 | from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import FairseqLRScheduler 13 | 14 | 15 | build_lr_scheduler, register_lr_scheduler, LR_SCHEDULER_REGISTRY = registry.setup_registry( 16 | '--lr-scheduler', 17 | base_class=FairseqLRScheduler, 18 | default='fixed', 19 | ) 20 | 21 | # automatically import any Python files in the optim/lr_scheduler/ directory 22 | for file in os.listdir(os.path.dirname(__file__)): 23 | if file.endswith('.py') and not file.startswith('_'): 24 | module = file[:file.find('.py')] 25 | importlib.import_module('fairseq.optim.lr_scheduler.' + module) 26 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from .. import FairseqOptimizer 9 | 10 | 11 | class FairseqLRScheduler(object): 12 | 13 | def __init__(self, args, optimizer): 14 | super().__init__() 15 | if not isinstance(optimizer, FairseqOptimizer): 16 | raise ValueError('optimizer must be an instance of FairseqOptimizer') 17 | self.args = args 18 | self.optimizer = optimizer 19 | self.best = None 20 | 21 | @staticmethod 22 | def add_args(parser): 23 | """Add arguments to the parser for this LR scheduler.""" 24 | pass 25 | 26 | def state_dict(self): 27 | """Return the LR scheduler state dict.""" 28 | return {'best': self.best} 29 | 30 | def load_state_dict(self, state_dict): 31 | """Load an LR scheduler state dict.""" 32 | self.best = state_dict['best'] 33 | 34 | def step(self, epoch, val_loss=None): 35 | """Update the learning rate at the end of the given epoch.""" 36 | if val_loss is not None: 37 | if self.best is None: 38 | self.best = val_loss 39 | else: 40 | self.best = min(self.best, val_loss) 41 | 42 | def step_update(self, num_updates): 43 | """Update the learning rate after each update.""" 44 | return self.optimizer.get_lr() 45 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/fixed_schedule.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from . import FairseqLRScheduler, register_lr_scheduler 9 | 10 | 11 | @register_lr_scheduler('fixed') 12 | class FixedSchedule(FairseqLRScheduler): 13 | """Decay the LR on a fixed schedule.""" 14 | 15 | def __init__(self, args, optimizer): 16 | super().__init__(args, optimizer) 17 | 18 | # set defaults 19 | args.warmup_updates = getattr(args, 'warmup_updates', 0) or 0 20 | 21 | self.lr = args.lr[0] 22 | if args.warmup_updates > 0: 23 | self.warmup_factor = 1. 
/ args.warmup_updates 24 | else: 25 | self.warmup_factor = 1 26 | 27 | @staticmethod 28 | def add_args(parser): 29 | """Add arguments to the parser for this LR scheduler.""" 30 | # fmt: off 31 | parser.add_argument('--force-anneal', '--fa', type=int, metavar='N', 32 | help='force annealing at specified epoch') 33 | parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS', 34 | help='shrink factor for annealing, lr_new = (lr * lr_shrink)') 35 | parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', 36 | help='warmup the learning rate linearly for the first N updates') 37 | # fmt: on 38 | 39 | def get_next_lr(self, epoch): 40 | lrs = self.args.lr 41 | if self.args.force_anneal is None or epoch < self.args.force_anneal: 42 | # use fixed LR schedule 43 | next_lr = lrs[min(epoch, len(lrs) - 1)] 44 | else: 45 | # annneal based on lr_shrink 46 | next_lr = lrs[-1] * self.args.lr_shrink ** (epoch + 1 - self.args.force_anneal) 47 | return next_lr 48 | 49 | def step(self, epoch, val_loss=None): 50 | """Update the learning rate at the end of the given epoch.""" 51 | super().step(epoch, val_loss) 52 | self.lr = self.get_next_lr(epoch) 53 | self.optimizer.set_lr(self.warmup_factor * self.lr) 54 | return self.optimizer.get_lr() 55 | 56 | def step_update(self, num_updates): 57 | """Update the learning rate after each update.""" 58 | if self.args.warmup_updates > 0 and num_updates <= self.args.warmup_updates: 59 | self.warmup_factor = num_updates / float(self.args.warmup_updates) 60 | self.optimizer.set_lr(self.warmup_factor * self.lr) 61 | return self.optimizer.get_lr() 62 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/inverse_square_root_schedule.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from . import FairseqLRScheduler, register_lr_scheduler 9 | 10 | 11 | @register_lr_scheduler('inverse_sqrt') 12 | class InverseSquareRootSchedule(FairseqLRScheduler): 13 | """Decay the LR based on the inverse square root of the update number. 14 | 15 | We also support a warmup phase where we linearly increase the learning rate 16 | from some initial learning rate (``--warmup-init-lr``) until the configured 17 | learning rate (``--lr``). Thereafter we decay proportional to the number of 18 | updates, with a decay factor set to align with the configured learning rate. 19 | 20 | During warmup:: 21 | 22 | lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates) 23 | lr = lrs[update_num] 24 | 25 | After warmup:: 26 | 27 | decay_factor = args.lr * sqrt(args.warmup_updates) 28 | lr = decay_factor / sqrt(update_num) 29 | """ 30 | 31 | def __init__(self, args, optimizer): 32 | super().__init__(args, optimizer) 33 | if len(args.lr) > 1: 34 | raise ValueError( 35 | 'Cannot use a fixed learning rate schedule with inverse_sqrt.' 36 | ' Consider --lr-scheduler=fixed instead.' 
37 | ) 38 | warmup_end_lr = args.lr[0] 39 | if args.warmup_init_lr < 0: 40 | args.warmup_init_lr = 0 if args.warmup_updates > 0 else warmup_end_lr 41 | 42 | # linearly warmup for the first args.warmup_updates 43 | self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates 44 | 45 | # then, decay prop. to the inverse square root of the update number 46 | self.decay_factor = warmup_end_lr * args.warmup_updates**0.5 47 | 48 | # initial learning rate 49 | self.lr = args.warmup_init_lr 50 | self.optimizer.set_lr(self.lr) 51 | 52 | @staticmethod 53 | def add_args(parser): 54 | """Add arguments to the parser for this LR scheduler.""" 55 | # fmt: off 56 | parser.add_argument('--warmup-updates', default=4000, type=int, metavar='N', 57 | help='warmup the learning rate linearly for the first N updates') 58 | parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR', 59 | help='initial learning rate during warmup phase; default is args.lr') 60 | # fmt: on 61 | 62 | def step(self, epoch, val_loss=None): 63 | """Update the learning rate at the end of the given epoch.""" 64 | super().step(epoch, val_loss) 65 | # we don't change the learning rate at epoch boundaries 66 | return self.optimizer.get_lr() 67 | 68 | def step_update(self, num_updates): 69 | """Update the learning rate after each update.""" 70 | if num_updates < self.args.warmup_updates: 71 | self.lr = self.args.warmup_init_lr + num_updates*self.lr_step 72 | else: 73 | self.lr = self.decay_factor * num_updates**-0.5 74 | self.optimizer.set_lr(self.lr) 75 | return self.lr 76 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/polynomial_decay_schedule.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from . import FairseqLRScheduler, register_lr_scheduler 9 | 10 | 11 | @register_lr_scheduler('polynomial_decay') 12 | class PolynomialDecaySchedule(FairseqLRScheduler): 13 | """Decay the LR on a fixed schedule.""" 14 | 15 | def __init__(self, args, optimizer): 16 | super().__init__(args, optimizer) 17 | 18 | # set defaults 19 | args.warmup_updates = getattr(args, 'warmup_updates', 0) or 0 20 | 21 | self.lr = args.lr[0] 22 | if args.warmup_updates > 0: 23 | self.warmup_factor = 1. 
/ args.warmup_updates 24 | else: 25 | self.warmup_factor = 1 26 | self.end_learning_rate = args.end_learning_rate 27 | self.total_num_update = args.total_num_update 28 | self.power = args.power 29 | self.optimizer.set_lr(self.warmup_factor * self.lr) 30 | 31 | @staticmethod 32 | def add_args(parser): 33 | """Add arguments to the parser for this LR scheduler.""" 34 | parser.add_argument('--force-anneal', '--fa', type=int, metavar='N', 35 | help='force annealing at specified epoch') 36 | parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', 37 | help='warmup the learning rate linearly for the first N updates') 38 | parser.add_argument('--end-learning-rate', default=0.0, type=float) 39 | parser.add_argument('--power', default=1.0, type=float) 40 | parser.add_argument('--total-num-update', default=1000000, type=int) 41 | 42 | def get_next_lr(self, epoch): 43 | lrs = self.args.lr 44 | if self.args.force_anneal is None or epoch < self.args.force_anneal: 45 | # use fixed LR schedule 46 | next_lr = lrs[min(epoch, len(lrs) - 1)] 47 | else: 48 | # annneal based on lr_shrink 49 | next_lr = self.optimizer.get_lr() 50 | return next_lr 51 | 52 | def step(self, epoch, val_loss=None): 53 | """Update the learning rate at the end of the given epoch.""" 54 | super().step(epoch, val_loss) 55 | self.lr = self.get_next_lr(epoch) 56 | self.optimizer.set_lr(self.warmup_factor * self.lr) 57 | return self.optimizer.get_lr() 58 | 59 | def step_update(self, num_updates): 60 | """Update the learning rate after each update.""" 61 | if self.args.warmup_updates > 0 and num_updates <= self.args.warmup_updates: 62 | self.warmup_factor = num_updates / float(self.args.warmup_updates) 63 | lr = self.warmup_factor * self.lr 64 | elif num_updates >= self.total_num_update: 65 | lr = self.end_learning_rate 66 | else: 67 | warmup = self.args.warmup_updates 68 | lr_range = self.lr - self.end_learning_rate 69 | pct_remaining = 1 - (num_updates - warmup) / (self.total_num_update - warmup) 70 | lr = lr_range * pct_remaining ** (self.power) + self.end_learning_rate 71 | self.optimizer.set_lr(lr) 72 | return self.optimizer.get_lr() 73 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.optim.lr_scheduler 9 | 10 | from . import FairseqLRScheduler, register_lr_scheduler 11 | 12 | 13 | @register_lr_scheduler('reduce_lr_on_plateau') 14 | class ReduceLROnPlateau(FairseqLRScheduler): 15 | """Decay the LR by a factor every time the validation loss plateaus.""" 16 | 17 | def __init__(self, args, optimizer): 18 | super().__init__(args, optimizer) 19 | if len(args.lr) > 1: 20 | raise ValueError( 21 | 'Cannot use a fixed learning rate schedule with reduce_lr_on_plateau.' 22 | ' Consider --lr-scheduler=fixed instead.' 
23 | ) 24 | self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 25 | self.optimizer.optimizer, patience=0, factor=args.lr_shrink, 26 | threshold=args.lr_threshold) 27 | 28 | @staticmethod 29 | def add_args(parser): 30 | """Add arguments to the parser for this LR scheduler.""" 31 | # fmt: off 32 | parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS', 33 | help='shrink factor for annealing, lr_new = (lr * lr_shrink)') 34 | parser.add_argument('--lr-threshold', default=1e-4, type=float, metavar='LT', 35 | help='Threshold for measuring the new optimum, \ 36 | to only focus on significant changes') 37 | # fmt: on 38 | 39 | def state_dict(self): 40 | """Return the LR scheduler state dict.""" 41 | return { 42 | 'best': self.lr_scheduler.best, 43 | 'last_epoch': self.lr_scheduler.last_epoch, 44 | } 45 | 46 | def load_state_dict(self, state_dict): 47 | """Load an LR scheduler state dict.""" 48 | self.lr_scheduler.best = state_dict['best'] 49 | if 'last_epoch' in state_dict: 50 | self.lr_scheduler.last_epoch = state_dict['last_epoch'] 51 | 52 | def step(self, epoch, val_loss=None): 53 | """Update the learning rate at the end of the given epoch.""" 54 | if val_loss is not None: 55 | self.lr_scheduler.step(val_loss, epoch) 56 | else: 57 | self.lr_scheduler.last_epoch = epoch 58 | return self.optimizer.get_lr() 59 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/triangular_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import math 9 | 10 | from . import FairseqLRScheduler, register_lr_scheduler 11 | 12 | 13 | @register_lr_scheduler('triangular') 14 | class TriangularSchedule(FairseqLRScheduler): 15 | """Assign LR based on a triangular cyclical schedule. 16 | 17 | See https://arxiv.org/pdf/1506.01186.pdf for details. 18 | """ 19 | 20 | def __init__(self, args, optimizer): 21 | super().__init__(args, optimizer) 22 | if len(args.lr) > 1: 23 | raise ValueError( 24 | 'Cannot use a fixed learning rate schedule with triangular.' 25 | ' Consider --lr-scheduler=fixed instead.' 
26 | ) 27 | 28 | lr = args.lr[0] 29 | 30 | assert args.max_lr > lr, 'max_lr must be more than lr' 31 | self.min_lr = lr 32 | self.max_lr = args.max_lr 33 | self.stepsize = args.lr_period_updates // 2 34 | self.lr_shrink = args.lr_shrink 35 | self.shrink_min = args.shrink_min 36 | 37 | # initial learning rate 38 | self.lr = self.min_lr 39 | self.optimizer.set_lr(self.lr) 40 | 41 | @staticmethod 42 | def add_args(parser): 43 | """Add arguments to the parser for this LR scheduler.""" 44 | # fmt: off 45 | parser.add_argument('--max-lr', required=True, type=float, metavar='LR', 46 | help='max learning rate, must be more than args.lr') 47 | parser.add_argument('--lr-period-updates', default=5000, type=float, metavar='LR', 48 | help='initial number of updates per period (cycle length)') 49 | parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS', 50 | help='shrink factor for annealing') 51 | parser.add_argument('--shrink-min', action='store_true', 52 | help='if set, also shrinks min lr') 53 | # fmt: on 54 | 55 | def step(self, epoch, val_loss=None): 56 | """Update the learning rate at the end of the given epoch.""" 57 | super().step(epoch, val_loss) 58 | # we don't change the learning rate at epoch boundaries 59 | return self.optimizer.get_lr() 60 | 61 | def step_update(self, num_updates): 62 | """Update the learning rate after each update.""" 63 | cycle = math.floor(num_updates / (2 * self.stepsize)) 64 | 65 | lr_shrink = self.lr_shrink ** cycle 66 | max_lr = self.max_lr * lr_shrink 67 | if self.shrink_min: 68 | min_lr = self.min_lr * lr_shrink 69 | else: 70 | min_lr = self.min_lr 71 | 72 | x = abs(num_updates / self.stepsize - 2 * (cycle + 1) + 1) 73 | self.lr = min_lr + (max_lr - min_lr) * max(0, (1 - x)) 74 | 75 | self.optimizer.set_lr(self.lr) 76 | return self.lr 77 | -------------------------------------------------------------------------------- /fairseq/optim/nag.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | from torch.optim.optimizer import Optimizer, required 10 | 11 | from . import FairseqOptimizer, register_optimizer 12 | 13 | 14 | @register_optimizer('nag') 15 | class FairseqNAG(FairseqOptimizer): 16 | def __init__(self, args, params): 17 | super().__init__(args, params) 18 | self._optimizer = NAG(params, **self.optimizer_config) 19 | 20 | @staticmethod 21 | def add_args(parser): 22 | """Add optimizer-specific arguments to the parser.""" 23 | # fmt: off 24 | parser.add_argument('--momentum', default=0.99, type=float, metavar='M', 25 | help='momentum factor') 26 | parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', 27 | help='weight decay') 28 | # fmt: on 29 | 30 | @property 31 | def optimizer_config(self): 32 | """ 33 | Return a kwarg dictionary that will be used to override optimizer 34 | args stored in checkpoints. This allows us to load a checkpoint and 35 | resume training using a different set of optimizer args, e.g., with a 36 | different learning rate. 
37 | """ 38 | return { 39 | 'lr': self.args.lr[0], 40 | 'momentum': self.args.momentum, 41 | 'weight_decay': self.args.weight_decay, 42 | } 43 | 44 | 45 | class NAG(Optimizer): 46 | def __init__(self, params, lr=required, momentum=0, weight_decay=0): 47 | defaults = dict(lr=lr, lr_old=lr, momentum=momentum, weight_decay=weight_decay) 48 | super(NAG, self).__init__(params, defaults) 49 | 50 | @property 51 | def supports_memory_efficient_fp16(self): 52 | return True 53 | 54 | def step(self, closure=None): 55 | """Performs a single optimization step. 56 | 57 | Arguments: 58 | closure (callable, optional): A closure that reevaluates the model 59 | and returns the loss. 60 | """ 61 | loss = None 62 | if closure is not None: 63 | loss = closure() 64 | 65 | for group in self.param_groups: 66 | weight_decay = group['weight_decay'] 67 | momentum = group['momentum'] 68 | lr = group['lr'] 69 | lr_old = group.get('lr_old', lr) 70 | lr_correct = lr / lr_old 71 | 72 | for p in group['params']: 73 | if p.grad is None: 74 | continue 75 | 76 | p_data_fp32 = p.data.float() 77 | 78 | d_p = p.grad.data.float() 79 | param_state = self.state[p] 80 | if 'momentum_buffer' not in param_state: 81 | param_state['momentum_buffer'] = torch.zeros_like(d_p) 82 | else: 83 | param_state['momentum_buffer'] = param_state['momentum_buffer'].type_as(d_p) 84 | 85 | buf = param_state['momentum_buffer'] 86 | 87 | if weight_decay != 0: 88 | p_data_fp32.mul_(1 - lr * weight_decay) 89 | p_data_fp32.add_(momentum * momentum * lr_correct, buf) 90 | p_data_fp32.add_(-(1 + momentum) * lr, d_p) 91 | 92 | buf.mul_(momentum * lr_correct).add_(-lr, d_p) 93 | 94 | p.data.copy_(p_data_fp32) 95 | 96 | group['lr_old'] = lr 97 | 98 | return loss 99 | -------------------------------------------------------------------------------- /fairseq/optim/sgd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.optim 9 | 10 | from . import FairseqOptimizer, register_optimizer 11 | 12 | 13 | @register_optimizer('sgd') 14 | class SGD(FairseqOptimizer): 15 | def __init__(self, args, params): 16 | super().__init__(args, params) 17 | self._optimizer = torch.optim.SGD(params, **self.optimizer_config) 18 | 19 | @staticmethod 20 | def add_args(parser): 21 | """Add optimizer-specific arguments to the parser.""" 22 | # fmt: off 23 | parser.add_argument('--momentum', default=0.0, type=float, metavar='M', 24 | help='momentum factor') 25 | parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', 26 | help='weight decay') 27 | # fmt: on 28 | 29 | @property 30 | def optimizer_config(self): 31 | """ 32 | Return a kwarg dictionary that will be used to override optimizer 33 | args stored in checkpoints. This allows us to load a checkpoint and 34 | resume training using a different set of optimizer args, e.g., with a 35 | different learning rate. 
36 | """ 37 | return { 38 | 'lr': self.args.lr[0], 39 | 'momentum': self.args.momentum, 40 | 'weight_decay': self.args.weight_decay, 41 | } 42 | -------------------------------------------------------------------------------- /fairseq/pdb.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import multiprocessing 9 | import os 10 | import pdb 11 | import sys 12 | 13 | 14 | __all__ = ['set_trace'] 15 | 16 | 17 | _stdin = [None] 18 | _stdin_lock = multiprocessing.Lock() 19 | try: 20 | _stdin_fd = sys.stdin.fileno() 21 | except Exception: 22 | _stdin_fd = None 23 | 24 | 25 | class MultiprocessingPdb(pdb.Pdb): 26 | """A Pdb wrapper that works in a multiprocessing environment. 27 | 28 | Usage: `from fairseq import pdb; pdb.set_trace()` 29 | """ 30 | 31 | def __init__(self): 32 | pdb.Pdb.__init__(self, nosigint=True) 33 | 34 | def _cmdloop(self): 35 | stdin_bak = sys.stdin 36 | with _stdin_lock: 37 | try: 38 | if _stdin_fd is not None: 39 | if not _stdin[0]: 40 | _stdin[0] = os.fdopen(_stdin_fd) 41 | sys.stdin = _stdin[0] 42 | self.cmdloop() 43 | finally: 44 | sys.stdin = stdin_bak 45 | 46 | 47 | def set_trace(): 48 | pdb = MultiprocessingPdb() 49 | pdb.set_trace(sys._getframe().f_back) 50 | -------------------------------------------------------------------------------- /fairseq/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
7 | 8 | 9 | REGISTRIES = {} 10 | 11 | 12 | def setup_registry( 13 | registry_name: str, 14 | base_class=None, 15 | default=None, 16 | ): 17 | assert registry_name.startswith('--') 18 | registry_name = registry_name[2:].replace('-', '_') 19 | 20 | REGISTRY = {} 21 | REGISTRY_CLASS_NAMES = set() 22 | 23 | # maintain a registry of all registries 24 | if registry_name in REGISTRIES: 25 | raise ValueError('Cannot setup duplicate registry: {}'.format(registry_name)) 26 | REGISTRIES[registry_name] = { 27 | 'registry': REGISTRY, 28 | 'default': default, 29 | } 30 | 31 | def build_x(args, *extra_args, **extra_kwargs): 32 | choice = getattr(args, registry_name, None) 33 | if choice is None: 34 | return None 35 | cls = REGISTRY[choice] 36 | if hasattr(cls, 'build_' + registry_name): 37 | builder = getattr(cls, 'build_' + registry_name) 38 | else: 39 | builder = cls 40 | return builder(args, *extra_args, **extra_kwargs) 41 | 42 | def register_x(name): 43 | 44 | def register_x_cls(cls): 45 | if name in REGISTRY: 46 | raise ValueError('Cannot register duplicate {} ({})'.format(registry_name, name)) 47 | if cls.__name__ in REGISTRY_CLASS_NAMES: 48 | raise ValueError( 49 | 'Cannot register {} with duplicate class name ({})'.format( 50 | registry_name, cls.__name__, 51 | ) 52 | ) 53 | if base_class is not None and not issubclass(cls, base_class): 54 | raise ValueError('{} must extend {}'.format(cls.__name__, base_class.__name__)) 55 | REGISTRY[name] = cls 56 | REGISTRY_CLASS_NAMES.add(cls.__name__) 57 | return cls 58 | 59 | return register_x_cls 60 | 61 | return build_x, register_x, REGISTRY 62 | -------------------------------------------------------------------------------- /fairseq/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import argparse 9 | import importlib 10 | import os 11 | 12 | from .fairseq_task import FairseqTask 13 | 14 | TASK_REGISTRY = {} 15 | TASK_CLASS_NAMES = set() 16 | 17 | 18 | def setup_task(args, **kwargs): 19 | return TASK_REGISTRY[args.task].setup_task(args, **kwargs) 20 | 21 | 22 | def register_task(name): 23 | """ 24 | New tasks can be added to fairseq with the 25 | :func:`~fairseq.tasks.register_task` function decorator. 26 | 27 | For example:: 28 | 29 | @register_task('classification') 30 | class ClassificationTask(FairseqTask): 31 | (...) 32 | 33 | .. note:: 34 | 35 | All Tasks must implement the :class:`~fairseq.tasks.FairseqTask` 36 | interface.
37 | 38 | Please see the 39 | 40 | Args: 41 | name (str): the name of the task 42 | """ 43 | 44 | def register_task_cls(cls): 45 | if name in TASK_REGISTRY: 46 | raise ValueError('Cannot register duplicate task ({})'.format(name)) 47 | if not issubclass(cls, FairseqTask): 48 | raise ValueError('Task ({}: {}) must extend FairseqTask'.format(name, cls.__name__)) 49 | if cls.__name__ in TASK_CLASS_NAMES: 50 | raise ValueError('Cannot register task with duplicate class name ({})'.format(cls.__name__)) 51 | TASK_REGISTRY[name] = cls 52 | TASK_CLASS_NAMES.add(cls.__name__) 53 | return cls 54 | 55 | return register_task_cls 56 | 57 | 58 | # automatically import any Python files in the tasks/ directory 59 | for file in os.listdir(os.path.dirname(__file__)): 60 | if file.endswith('.py') and not file.startswith('_'): 61 | task_name = file[:file.find('.py')] 62 | importlib.import_module('fairseq.tasks.' + task_name) 63 | 64 | # expose `task_parser` for sphinx 65 | if task_name in TASK_REGISTRY: 66 | parser = argparse.ArgumentParser(add_help=False) 67 | group_task = parser.add_argument_group('Task name') 68 | # fmt: off 69 | group_task.add_argument('--task', metavar=task_name, 70 | help='Enable this task with: ``--task=' + task_name + '``') 71 | # fmt: on 72 | group_args = parser.add_argument_group('Additional command-line arguments') 73 | TASK_REGISTRY[task_name].add_args(group_args) 74 | globals()[task_name + '_parser'] = parser 75 | 76 | 77 | def get_task(name): 78 | return TASK_REGISTRY[name] 79 | -------------------------------------------------------------------------------- /fairseq/tasks/audio_pretraining.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import os 9 | 10 | from fairseq.data import RawAudioDataset 11 | from . import FairseqTask, register_task 12 | 13 | 14 | @register_task('audio_pretraining') 15 | class AudioPretrainingTask(FairseqTask): 16 | """ 17 | 18 | """ 19 | 20 | @staticmethod 21 | def add_args(parser): 22 | """Add task-specific arguments to the parser.""" 23 | parser.add_argument('data', help='path to data directory') 24 | parser.add_argument('--sample-rate', default=16000, type=int, 25 | help='target sample rate. audio files will be up/down sampled to this rate') 26 | parser.add_argument('--max-sample-size', default=None, type=int, 27 | help='max sample size to crop to for batching. default = min sample length') 28 | parser.add_argument('--min-sample-size', default=None, type=int, 29 | help='min sample size to crop to for batching. default = same as --max-sample-size') 30 | 31 | def __init__(self, args): 32 | super().__init__(args) 33 | 34 | @classmethod 35 | def setup_task(cls, args, **kwargs): 36 | """Setup the task (e.g., load dictionaries). 37 | 38 | Args: 39 | args (argparse.Namespace): parsed command-line arguments 40 | """ 41 | return cls(args) 42 | 43 | def load_dataset(self, split, **kwargs): 44 | """Load a given dataset split. 
45 | 46 | Args: 47 | split (str): name of the split (e.g., train, valid, test) 48 | """ 49 | 50 | manifest = os.path.join(self.args.data, '{}.tsv'.format(split)) 51 | self.datasets[split] = RawAudioDataset(manifest, 52 | sample_rate=self.args.sample_rate, 53 | max_sample_size=self.args.max_sample_size, 54 | min_sample_size=self.args.min_sample_size) 55 | 56 | @property 57 | def target_dictionary(self): 58 | """Return the :class:`~fairseq.data.Dictionary` for the language 59 | model.""" 60 | return None -------------------------------------------------------------------------------- /fairseq/tasks/translation_from_pretrained_xlm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from fairseq.data.masked_lm_dictionary import MaskedLMDictionary 9 | from fairseq.tasks.translation import TranslationTask 10 | 11 | from . import register_task 12 | 13 | 14 | @register_task("translation_from_pretrained_xlm") 15 | class TranslationFromPretrainedXLMTask(TranslationTask): 16 | """ 17 | Same as TranslationTask except use the MaskedLMDictionary class so that 18 | we can load data that was binarized with the MaskedLMDictionary class. 19 | 20 | This task should be used for the entire training pipeline when we want to 21 | train an NMT model from a pretrained XLM checkpoint: binarizing NMT data, 22 | training NMT with the pretrained XLM checkpoint, and subsequent evaluation 23 | of that trained model. 24 | """ 25 | 26 | @classmethod 27 | def load_dictionary(cls, filename): 28 | """Load the masked LM dictionary from the filename 29 | 30 | Args: 31 | filename (str): the filename 32 | """ 33 | return MaskedLMDictionary.load(filename) 34 | -------------------------------------------------------------------------------- /fairseq/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
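# NOTE: a tiny example of the whitespace tokenizer defined below; the input
# string is illustrative:
#
#   tokenize_line("Hello   world \n")   # -> ['Hello', 'world']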
7 | 8 | import re 9 | 10 | SPACE_NORMALIZER = re.compile(r"\s+") 11 | 12 | 13 | def tokenize_line(line): 14 | line = SPACE_NORMALIZER.sub(" ", line) 15 | line = line.strip() 16 | return line.split() 17 | -------------------------------------------------------------------------------- /fairseq_cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-tpu/fairseq/7efde2261f78e9e8d20e637e252bdd9977ec9290/fairseq_cli/__init__.py -------------------------------------------------------------------------------- /fairseq_cli/eval_lm.py: -------------------------------------------------------------------------------- 1 | ../eval_lm.py -------------------------------------------------------------------------------- /fairseq_cli/generate.py: -------------------------------------------------------------------------------- 1 | ../generate.py -------------------------------------------------------------------------------- /fairseq_cli/interactive.py: -------------------------------------------------------------------------------- 1 | ../interactive.py -------------------------------------------------------------------------------- /fairseq_cli/preprocess.py: -------------------------------------------------------------------------------- 1 | ../preprocess.py -------------------------------------------------------------------------------- /fairseq_cli/score.py: -------------------------------------------------------------------------------- 1 | ../score.py -------------------------------------------------------------------------------- /fairseq_cli/setup.py: -------------------------------------------------------------------------------- 1 | ../setup.py -------------------------------------------------------------------------------- /fairseq_cli/train.py: -------------------------------------------------------------------------------- 1 | ../train.py -------------------------------------------------------------------------------- /fairseq_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-tpu/fairseq/7efde2261f78e9e8d20e637e252bdd9977ec9290/fairseq_logo.png -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from fairseq.models import MODEL_REGISTRY 9 | 10 | 11 | dependencies = [ 12 | 'regex', 13 | 'requests', 14 | 'sacremoses', 15 | 'sentencepiece', 16 | 'subword_nmt', 17 | 'torch', 18 | ] 19 | 20 | 21 | for model, cls in MODEL_REGISTRY.items(): 22 | globals()[model] = cls.from_pretrained 23 | -------------------------------------------------------------------------------- /score.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 
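# Illustrative invocations (file names are hypothetical), based on the
# arguments defined in get_parser() below:
#     python score.py --sys gen.out.sys --ref gen.out.ref --order 4
#     python score.py --ref gen.out.ref --sacrebleu < gen.out.sys    # --sys defaults to '-' (stdin)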
8 | """ 9 | BLEU scoring of generated translations against reference translations. 10 | """ 11 | 12 | import argparse 13 | import os 14 | import sys 15 | 16 | from fairseq import bleu 17 | from fairseq.data import dictionary 18 | 19 | 20 | def get_parser(): 21 | parser = argparse.ArgumentParser(description='Command-line script for BLEU scoring.') 22 | # fmt: off 23 | parser.add_argument('-s', '--sys', default='-', help='system output') 24 | parser.add_argument('-r', '--ref', required=True, help='references') 25 | parser.add_argument('-o', '--order', default=4, metavar='N', 26 | type=int, help='consider ngrams up to this order') 27 | parser.add_argument('--ignore-case', action='store_true', 28 | help='case-insensitive scoring') 29 | parser.add_argument('--sacrebleu', action='store_true', 30 | help='score with sacrebleu') 31 | parser.add_argument('--sentence-bleu', action='store_true', 32 | help='report sentence-level BLEUs (i.e., with +1 smoothing)') 33 | # fmt: on 34 | return parser 35 | 36 | 37 | def main(): 38 | parser = get_parser() 39 | args = parser.parse_args() 40 | print(args) 41 | 42 | assert args.sys == '-' or os.path.exists(args.sys), \ 43 | "System output file {} does not exist".format(args.sys) 44 | assert os.path.exists(args.ref), \ 45 | "Reference file {} does not exist".format(args.ref) 46 | 47 | dict = dictionary.Dictionary() 48 | 49 | def readlines(fd): 50 | for line in fd.readlines(): 51 | if args.ignore_case: 52 | yield line.lower() 53 | else: 54 | yield line 55 | 56 | if args.sacrebleu: 57 | import sacrebleu 58 | 59 | def score(fdsys): 60 | with open(args.ref) as fdref: 61 | print(sacrebleu.corpus_bleu(fdsys, [fdref])) 62 | elif args.sentence_bleu: 63 | def score(fdsys): 64 | with open(args.ref) as fdref: 65 | scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk()) 66 | for i, (sys_tok, ref_tok) in enumerate(zip(readlines(fdsys), readlines(fdref))): 67 | scorer.reset(one_init=True) 68 | sys_tok = dict.encode_line(sys_tok) 69 | ref_tok = dict.encode_line(ref_tok) 70 | scorer.add(ref_tok, sys_tok) 71 | print(i, scorer.result_string(args.order)) 72 | else: 73 | def score(fdsys): 74 | with open(args.ref) as fdref: 75 | scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk()) 76 | for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)): 77 | sys_tok = dict.encode_line(sys_tok) 78 | ref_tok = dict.encode_line(ref_tok) 79 | scorer.add(ref_tok, sys_tok) 80 | print(scorer.result_string(args.order)) 81 | 82 | if args.sys == '-': 83 | score(sys.stdin) 84 | else: 85 | with open(args.sys, 'r') as f: 86 | score(f) 87 | 88 | 89 | if __name__ == '__main__': 90 | main() 91 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-tpu/fairseq/7efde2261f78e9e8d20e637e252bdd9977ec9290/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/compare_namespaces.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Helper script to compare two argparse.Namespace objects.""" 3 | 4 | from argparse import Namespace 5 | 6 | 7 | def main(): 8 | 9 | ns1 = eval(input('Namespace 1: ')) 10 | ns2 = eval(input('Namespace 2: ')) 11 | 12 | def keys(ns): 13 | ks = set() 14 | for k in dir(ns): 15 | if not k.startswith('_'): 16 | ks.add(k) 17 | return ks 18 | 19 | k1 = keys(ns1) 20 | k2 = keys(ns2) 21 | 22 | def 
print_keys(ks, ns1, ns2=None): 23 | for k in ks: 24 | if ns2 is None: 25 | print('{}\t{}'.format(k, getattr(ns1, k, None))) 26 | else: 27 | print('{}\t{}\t{}'.format(k, getattr(ns1, k, None), getattr(ns2, k, None))) 28 | 29 | print('Keys unique to namespace 1:') 30 | print_keys(k1 - k2, ns1) 31 | print() 32 | 33 | print('Keys unique to namespace 2:') 34 | print_keys(k2 - k1, ns2) 35 | print() 36 | 37 | print('Overlapping keys with different values:') 38 | ks = [k for k in k1 & k2 if getattr(ns1, k, 'None') != getattr(ns2, k, 'None')] 39 | print_keys(ks, ns1, ns2) 40 | print() 41 | 42 | 43 | if __name__ == '__main__': 44 | main() 45 | -------------------------------------------------------------------------------- /scripts/compound_split_bleu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ $# -ne 1 ]; then 4 | echo "usage: $0 GENERATE_PY_OUTPUT" 5 | exit 1 6 | fi 7 | 8 | GEN=$1 9 | 10 | SYS=$GEN.sys 11 | REF=$GEN.ref 12 | 13 | if [ $(tail -n 1 $GEN | grep BLEU | wc -l) -ne 1 ]; then 14 | echo "not done generating" 15 | exit 16 | fi 17 | 18 | grep ^H $GEN | cut -f3- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $SYS 19 | grep ^T $GEN | cut -f2- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $REF 20 | fairseq-score --sys $SYS --ref $REF 21 | -------------------------------------------------------------------------------- /scripts/convert_dictionary.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | -- Usage: convert_dictionary.lua <dict.th7> 9 | require 'fairseq' 10 | require 'torch' 11 | require 'paths' 12 | 13 | if #arg < 1 then 14 | print('usage: convert_dictionary.lua <dict.th7>') 15 | os.exit(1) 16 | end 17 | if not paths.filep(arg[1]) then 18 | print('error: file does not exit: ' .. arg[1]) 19 | os.exit(1) 20 | end 21 | 22 | dict = torch.load(arg[1]) 23 | dst = paths.basename(arg[1]):gsub('.th7', '.txt') 24 | assert(dst:match('.txt$')) 25 | 26 | f = io.open(dst, 'w') 27 | for idx, symbol in ipairs(dict.index_to_symbol) do 28 | if idx > dict.cutoff then 29 | break 30 | end 31 | f:write(symbol) 32 | f:write(' ') 33 | f:write(dict.index_to_freq[idx]) 34 | f:write('\n') 35 | end 36 | f:close() 37 | -------------------------------------------------------------------------------- /scripts/count_docs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 8 | """ 9 | Count the number of documents and average number of lines and tokens per 10 | document in a large file. Documents should be separated by a single empty line. 
11 | """ 12 | 13 | import argparse 14 | import gzip 15 | import random 16 | import sys 17 | 18 | import numpy as np 19 | 20 | 21 | def main(): 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('input') 24 | parser.add_argument('--gzip', action='store_true') 25 | args = parser.parse_args() 26 | 27 | def gopen(): 28 | if args.gzip: 29 | return gzip.open(args.input, 'r') 30 | else: 31 | return open(args.input, 'r', encoding='utf-8') 32 | 33 | num_lines = [] 34 | num_toks = [] 35 | with gopen() as h: 36 | num_docs = 1 37 | num_lines_in_doc = 0 38 | num_toks_in_doc = 0 39 | for i, line in enumerate(h): 40 | if len(line.strip()) == 0: # empty line indicates new document 41 | num_docs += 1 42 | num_lines.append(num_lines_in_doc) 43 | num_toks.append(num_toks_in_doc) 44 | num_lines_in_doc = 0 45 | num_toks_in_doc = 0 46 | else: 47 | num_lines_in_doc += 1 48 | num_toks_in_doc += len(line.rstrip().split()) 49 | if i % 1000000 == 0: 50 | print(i, file=sys.stderr, end="", flush=True) 51 | elif i % 100000 == 0: 52 | print(".", file=sys.stderr, end="", flush=True) 53 | print(file=sys.stderr, flush=True) 54 | 55 | print("found {} docs".format(num_docs)) 56 | print("average num lines per doc: {}".format(np.mean(num_lines))) 57 | print("average num toks per doc: {}".format(np.mean(num_toks))) 58 | 59 | 60 | if __name__ == '__main__': 61 | main() 62 | -------------------------------------------------------------------------------- /scripts/read_binarized.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 8 | 9 | import argparse 10 | 11 | from fairseq.data import Dictionary 12 | from fairseq.data import indexed_dataset 13 | 14 | 15 | def get_parser(): 16 | parser = argparse.ArgumentParser( 17 | description='writes text from binarized file to stdout') 18 | # fmt: off 19 | parser.add_argument('--dataset-impl', help='dataset implementation', 20 | choices=['raw', 'lazy', 'cached', 'mmap'], default='lazy') 21 | parser.add_argument('--dict', metavar='FP', help='dictionary containing known words', default=None) 22 | parser.add_argument('--input', metavar='FP', required=True, help='binarized file to read') 23 | # fmt: on 24 | 25 | return parser 26 | 27 | 28 | def main(): 29 | parser = get_parser() 30 | args = parser.parse_args() 31 | 32 | dictionary = Dictionary.load(args.dict) if args.dict is not None else None 33 | dataset = indexed_dataset.make_dataset(args.input, impl=args.dataset_impl, 34 | fix_lua_indexing=True, dictionary=dictionary) 35 | 36 | for tensor_line in dataset: 37 | if dictionary is None: 38 | line = ' '.join([str(int(x)) for x in tensor_line]) 39 | else: 40 | line = dictionary.string(tensor_line) 41 | 42 | print(line) 43 | 44 | 45 | if __name__ == '__main__': 46 | main() 47 | -------------------------------------------------------------------------------- /scripts/sacrebleu_pregen.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ $# -ne 4 ]; then 4 | echo "usage: $0 TESTSET SRCLANG TGTLANG GEN" 5 | exit 1 6 | fi 7 | 8 | TESTSET=$1 9 | SRCLANG=$2 10 | TGTLANG=$3 11 | 12 | GEN=$4 13 | 14 | echo 'Cloning Moses github repository (for tokenization scripts)...' 
15 | git clone https://github.com/moses-smt/mosesdecoder.git 16 | 17 | SCRIPTS=mosesdecoder/scripts 18 | DETOKENIZER=$SCRIPTS/tokenizer/detokenizer.perl 19 | 20 | grep ^H $GEN \ 21 | | sed 's/^H\-//' \ 22 | | sort -n -k 1 \ 23 | | cut -f 3 \ 24 | | perl $DETOKENIZER -l $TGTLANG \ 25 | | sed "s/ - /-/g" \ 26 | > $GEN.sorted.detok 27 | 28 | sacrebleu --test-set $TESTSET --language-pair "${SRCLANG}-${TGTLANG}" < $GEN.sorted.detok 29 | -------------------------------------------------------------------------------- /scripts/shard_docs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 8 | """ 9 | Split a large file into shards while respecting document boundaries. Documents 10 | should be separated by a single empty line. 11 | """ 12 | 13 | import argparse 14 | import contextlib 15 | import random 16 | import sys 17 | 18 | 19 | def main(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('input') 22 | parser.add_argument('--num-shards', type=int) 23 | args = parser.parse_args() 24 | 25 | assert args.num_shards is not None and args.num_shards > 1 26 | 27 | with open(args.input, 'r', encoding='utf-8') as h: 28 | with contextlib.ExitStack() as stack: 29 | outputs = [ 30 | stack.enter_context(open(args.input + ".shard" + str(i), "w", encoding="utf-8")) 31 | for i in range(args.num_shards) 32 | ] 33 | 34 | doc = [] 35 | first_doc = [True]*args.num_shards 36 | def output_doc(i): 37 | if not first_doc[i]: 38 | outputs[i].write("\n") 39 | first_doc[i] = False 40 | for line in doc: 41 | outputs[i].write(line) 42 | doc.clear() 43 | 44 | num_docs = 0 45 | for line in h: 46 | if line.strip() == "": # empty line indicates new document 47 | output_doc(num_docs % args.num_shards) 48 | num_docs += 1 49 | else: 50 | doc.append(line) 51 | output_doc(num_docs % args.num_shards) 52 | 53 | 54 | if __name__ == '__main__': 55 | main() 56 | -------------------------------------------------------------------------------- /scripts/split_train_valid_docs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 8 | """ 9 | Split a large file into a train and valid set while respecting document 10 | boundaries. Documents should be separated by a single empty line. 
11 | """ 12 | 13 | import argparse 14 | import random 15 | import sys 16 | 17 | 18 | def main(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('input') 21 | parser.add_argument('sample_output', help='train output file') 22 | parser.add_argument('remainder_output', help='valid output file') 23 | parser.add_argument('-k', type=int, help="remainder size") 24 | args = parser.parse_args() 25 | 26 | assert args.k is not None 27 | 28 | sample = [] 29 | remainder = [] 30 | num_docs = [0] 31 | 32 | def update_sample(doc): 33 | if len(sample) < args.k: 34 | sample.append(doc.copy()) 35 | else: 36 | i = num_docs[0] 37 | j = random.randrange(i + 1) 38 | if j < args.k: 39 | remainder.append(sample[j]) 40 | sample[j] = doc.copy() 41 | else: 42 | remainder.append(doc.copy()) 43 | num_docs[0] += 1 44 | doc.clear() 45 | 46 | with open(args.input, 'r', encoding='utf-8') as h: 47 | doc = [] 48 | for i, line in enumerate(h): 49 | if line.strip() == "": # empty line indicates new document 50 | update_sample(doc) 51 | else: 52 | doc.append(line) 53 | if i % 1000000 == 0: 54 | print(i, file=sys.stderr, end="", flush=True) 55 | elif i % 100000 == 0: 56 | print(".", file=sys.stderr, end="", flush=True) 57 | if len(doc) > 0: 58 | update_sample(doc) 59 | print(file=sys.stderr, flush=True) 60 | 61 | assert len(sample) == args.k 62 | 63 | with open(args.sample_output, 'w', encoding='utf-8') as out: 64 | first = True 65 | for doc in sample: 66 | if not first: 67 | out.write("\n") 68 | first = False 69 | for line in doc: 70 | out.write(line) 71 | 72 | with open(args.remainder_output, 'w', encoding='utf-8') as out: 73 | first = True 74 | for doc in remainder: 75 | if not first: 76 | out.write("\n") 77 | first = False 78 | for line in doc: 79 | out.write(line) 80 | 81 | 82 | if __name__ == '__main__': 83 | main() 84 | -------------------------------------------------------------------------------- /scripts/spm_decode.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | from __future__ import absolute_import, division, print_function, unicode_literals 9 | 10 | import argparse 11 | 12 | import sentencepiece as spm 13 | 14 | 15 | def main(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--model", required=True, 18 | help="sentencepiece model to use for decoding") 19 | parser.add_argument("--input", required=True, help="input file to decode") 20 | parser.add_argument("--input_format", choices=["piece", "id"], default="piece") 21 | args = parser.parse_args() 22 | 23 | sp = spm.SentencePieceProcessor() 24 | sp.Load(args.model) 25 | 26 | if args.input_format == "piece": 27 | def decode(l): 28 | return "".join(sp.DecodePieces(l)) 29 | elif args.input_format == "id": 30 | def decode(l): 31 | return "".join(sp.DecodeIds(l)) 32 | else: 33 | raise NotImplementedError 34 | 35 | def tok2int(tok): 36 | # remap reference-side <unk> (represented as <<unk>>) to 0 37 | return int(tok) if tok != "<<unk>>" else 0 38 | 39 | with open(args.input, "r", encoding="utf-8") as h: 40 | for line in h: 41 | print(decode(list(map(tok2int, line.rstrip().split())))) 42 | 43 | 44 | if __name__ == "__main__": 45 | main() 46 | -------------------------------------------------------------------------------- /scripts/spm_encode.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | from __future__ import absolute_import, division, print_function, unicode_literals 9 | 10 | import argparse 11 | import contextlib 12 | import sys 13 | 14 | import sentencepiece as spm 15 | 16 | 17 | def main(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--model", required=True, 20 | help="sentencepiece model to use for encoding") 21 | parser.add_argument("--inputs", nargs="+", default=['-'], 22 | help="input files to filter/encode") 23 | parser.add_argument("--outputs", nargs="+", default=['-'], 24 | help="path to save encoded outputs") 25 | parser.add_argument("--output_format", choices=["piece", "id"], default="piece") 26 | parser.add_argument("--min-len", type=int, metavar="N", 27 | help="filter sentence pairs with fewer than N tokens") 28 | parser.add_argument("--max-len", type=int, metavar="N", 29 | help="filter sentence pairs with more than N tokens") 30 | args = parser.parse_args() 31 | 32 | assert len(args.inputs) == len(args.outputs), \ 33 | "number of input and output paths should match" 34 | 35 | sp = spm.SentencePieceProcessor() 36 | sp.Load(args.model) 37 | 38 | if args.output_format == "piece": 39 | def encode(l): 40 | return sp.EncodeAsPieces(l) 41 | elif args.output_format == "id": 42 | def encode(l): 43 | return list(map(str, sp.EncodeAsIds(l))) 44 | else: 45 | raise NotImplementedError 46 | 47 | if args.min_len is not None or args.max_len is not None: 48 | def valid(line): 49 | return ( 50 | (args.min_len is None or len(line) >= args.min_len) 51 | and (args.max_len is None or len(line) <= args.max_len) 52 | ) 53 | else: 54 | def valid(lines): 55 | return True 56 | 57 | with contextlib.ExitStack() as stack: 58 | inputs = [ 59 | stack.enter_context(open(input, "r", encoding="utf-8")) \ 60 | if input != "-" else sys.stdin 61 | for input in args.inputs 62 | ] 63 | outputs = [ 64 | stack.enter_context(open(output, "w", encoding="utf-8")) \ 65 | if output != "-" else sys.stdout 
66 | for output in args.outputs 67 | ] 68 | 69 | stats = { 70 | "num_empty": 0, 71 | "num_filtered": 0, 72 | } 73 | 74 | def encode_line(line): 75 | line = line.strip() 76 | if len(line) > 0: 77 | line = encode(line) 78 | if valid(line): 79 | return line 80 | else: 81 | stats["num_filtered"] += 1 82 | else: 83 | stats["num_empty"] += 1 84 | return None 85 | 86 | for i, lines in enumerate(zip(*inputs), start=1): 87 | enc_lines = list(map(encode_line, lines)) 88 | if not any(enc_line is None for enc_line in enc_lines): 89 | for enc_line, output_h in zip(enc_lines, outputs): 90 | print(" ".join(enc_line), file=output_h) 91 | if i % 10000 == 0: 92 | print("processed {} lines".format(i), file=sys.stderr) 93 | 94 | print("skipped {} empty lines".format(stats["num_empty"]), file=sys.stderr) 95 | print("filtered {} lines".format(stats["num_filtered"]), file=sys.stderr) 96 | 97 | 98 | if __name__ == "__main__": 99 | main() 100 | -------------------------------------------------------------------------------- /scripts/spm_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | from __future__ import absolute_import, division, print_function, unicode_literals 9 | 10 | import shlex 11 | import sys 12 | 13 | import sentencepiece as spm 14 | 15 | 16 | if __name__ == "__main__": 17 | spm.SentencePieceTrainer.Train(" ".join(map(shlex.quote, sys.argv[1:]))) 18 | -------------------------------------------------------------------------------- /scripts/wav2vec_manifest.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 8 | """ 9 | Data pre-processing: build vocabularies and binarize training data. 10 | """ 11 | 12 | import argparse 13 | import glob 14 | import os 15 | import soundfile 16 | import random 17 | 18 | 19 | def get_parser(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('root', metavar='DIR', help='root directory containing flac files to index') 22 | parser.add_argument('--valid-percent', default=0.01, type=float, metavar='D', 23 | help='percentage of data to use as validation set (between 0 and 1)') 24 | parser.add_argument('--dest', default='.', type=str, metavar='DIR', help='output directory') 25 | parser.add_argument('--ext', default='flac', type=str, metavar='EXT', help='extension to look for') 26 | parser.add_argument('--seed', default=42, type=int, metavar='N', help='random seed') 27 | parser.add_argument('--path-must-contain', default=None, type=str, metavar='FRAG', 28 | help='if set, path must contain this substring for a file to be included in the manifest') 29 | return parser 30 | 31 | 32 | def main(args): 33 | assert args.valid_percent >= 0 and args.valid_percent <= 1. 34 | 35 | dir_path = os.path.realpath(args.root) 36 | search_path = os.path.join(dir_path, '**/*.' 
+ args.ext) 37 | rand = random.Random(args.seed) 38 | 39 | with open(os.path.join(args.dest, 'train.tsv'), 'w') as train_f, open( 40 | os.path.join(args.dest, 'valid.tsv'), 'w') as valid_f: 41 | print(dir_path, file=train_f) 42 | print(dir_path, file=valid_f) 43 | 44 | for fname in glob.iglob(search_path, recursive=True): 45 | file_path = os.path.realpath(fname) 46 | 47 | if args.path_must_contain and args.path_must_contain not in file_path: 48 | continue 49 | 50 | frames = soundfile.info(fname).frames 51 | dest = train_f if rand.random() > args.valid_percent else valid_f 52 | print('{}\t{}'.format(os.path.relpath(file_path, dir_path), frames), file=dest) 53 | 54 | 55 | if __name__ == '__main__': 56 | parser = get_parser() 57 | args = parser.parse_args() 58 | main(args) 59 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 8 | 9 | from setuptools import setup, find_packages, Extension 10 | import sys 11 | 12 | 13 | if sys.version_info < (3,): 14 | sys.exit('Sorry, Python3 is required for fairseq.') 15 | 16 | with open('README.md') as f: 17 | readme = f.read() 18 | 19 | if sys.platform == 'darwin': 20 | extra_compile_args = ['-stdlib=libc++'] 21 | else: 22 | extra_compile_args = ['-std=c++11'] 23 | bleu = Extension( 24 | 'fairseq.libbleu', 25 | sources=[ 26 | 'fairseq/clib/libbleu/libbleu.cpp', 27 | 'fairseq/clib/libbleu/module.cpp', 28 | ], 29 | extra_compile_args=extra_compile_args, 30 | ) 31 | 32 | 33 | setup( 34 | name='fairseq', 35 | version='0.7.2', 36 | description='Facebook AI Research Sequence-to-Sequence Toolkit', 37 | url='https://github.com/pytorch/fairseq', 38 | classifiers=[ 39 | 'Intended Audience :: Science/Research', 40 | 'License :: OSI Approved :: BSD License', 41 | 'Programming Language :: Python :: 3.5', 42 | 'Programming Language :: Python :: 3.6', 43 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 44 | ], 45 | long_description=readme, 46 | long_description_content_type='text/markdown', 47 | install_requires=[ 48 | 'cffi', 49 | 'numpy', 50 | 'sacrebleu', 51 | 'torch', 52 | 'tqdm', 53 | ], 54 | packages=find_packages(exclude=['scripts', 'tests']), 55 | ext_modules=[bleu], 56 | test_suite='tests', 57 | entry_points={ 58 | 'console_scripts': [ 59 | 'fairseq-eval-lm = fairseq_cli.eval_lm:cli_main', 60 | 'fairseq-generate = fairseq_cli.generate:cli_main', 61 | 'fairseq-interactive = fairseq_cli.interactive:cli_main', 62 | 'fairseq-preprocess = fairseq_cli.preprocess:cli_main', 63 | 'fairseq-train = fairseq_cli.train:cli_main', 64 | 'fairseq-score = fairseq_cli.score:main', 65 | ], 66 | }, 67 | ) 68 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-tpu/fairseq/7efde2261f78e9e8d20e637e252bdd9977ec9290/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_character_token_embedder.py: -------------------------------------------------------------------------------- 1 | # Copyright 
(c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | import unittest 10 | 11 | from fairseq.data import Dictionary 12 | from fairseq.modules import CharacterTokenEmbedder 13 | 14 | 15 | class TestCharacterTokenEmbedder(unittest.TestCase): 16 | def test_character_token_embedder(self): 17 | vocab = Dictionary() 18 | vocab.add_symbol('hello') 19 | vocab.add_symbol('there') 20 | 21 | embedder = CharacterTokenEmbedder(vocab, [(2, 16), (4, 32), (8, 64), (16, 2)], 64, 5, 2) 22 | 23 | test_sents = [['hello', 'unk', 'there'], ['there'], ['hello', 'there']] 24 | max_len = max(len(s) for s in test_sents) 25 | input = torch.LongTensor(len(test_sents), max_len + 2).fill_(vocab.pad()) 26 | for i in range(len(test_sents)): 27 | input[i][0] = vocab.eos() 28 | for j in range(len(test_sents[i])): 29 | input[i][j + 1] = vocab.index(test_sents[i][j]) 30 | input[i][j + 2] = vocab.eos() 31 | embs = embedder(input) 32 | 33 | assert embs.size() == (len(test_sents), max_len + 2, 5) 34 | self.assertAlmostEqual(embs[0][0], embs[1][0]) 35 | self.assertAlmostEqual(embs[0][0], embs[0][-1]) 36 | self.assertAlmostEqual(embs[0][1], embs[2][1]) 37 | self.assertAlmostEqual(embs[0][3], embs[1][1]) 38 | 39 | embs.sum().backward() 40 | assert embedder.char_embeddings.weight.grad is not None 41 | 42 | def assertAlmostEqual(self, t1, t2): 43 | self.assertEqual(t1.size(), t2.size(), "size mismatch") 44 | self.assertLess((t1 - t2).abs().max(), 1e-6) 45 | 46 | 47 | if __name__ == '__main__': 48 | unittest.main() 49 | -------------------------------------------------------------------------------- /tests/test_concat_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
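# Note: the test below checks that ConcatDataset concatenates its constituent
# datasets and that sample_ratios upsamples them, e.g. sample_ratios=[1, 2]
# repeats the second dataset twice for 1 + 2 = 3 examples in total. It can
# typically be run on its own with:  python -m unittest tests.test_concat_dataset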
7 | 8 | import unittest 9 | 10 | import torch 11 | from fairseq.data import LanguagePairDataset, TokenBlockDataset 12 | from fairseq.data.concat_dataset import ConcatDataset 13 | from tests.test_train import mock_dict 14 | 15 | 16 | class TestConcatDataset(unittest.TestCase): 17 | def setUp(self): 18 | d = mock_dict() 19 | tokens_1 = torch.LongTensor([1]).view(1, -1) 20 | tokens_ds1 = TokenBlockDataset( 21 | tokens_1, 22 | sizes=[tokens_1.size(-1)], 23 | block_size=1, 24 | pad=0, 25 | eos=1, 26 | include_targets=False, 27 | ) 28 | self.dataset_1 = LanguagePairDataset( 29 | tokens_ds1, tokens_ds1.sizes, d, shuffle=False 30 | ) 31 | tokens_2 = torch.LongTensor([2]).view(1, -1) 32 | tokens_ds2 = TokenBlockDataset( 33 | tokens_2, 34 | sizes=[tokens_2.size(-1)], 35 | block_size=1, 36 | pad=0, 37 | eos=1, 38 | include_targets=False, 39 | ) 40 | self.dataset_2 = LanguagePairDataset( 41 | tokens_ds2, tokens_ds2.sizes, d, shuffle=False 42 | ) 43 | 44 | def test_concat_dataset_basics(self): 45 | d = ConcatDataset( 46 | [self.dataset_1, self.dataset_2] 47 | ) 48 | assert(len(d) == 2) 49 | assert(d[0]['source'][0] == 1) 50 | assert(d[1]['source'][0] == 2) 51 | 52 | d = ConcatDataset( 53 | [self.dataset_1, self.dataset_2], sample_ratios=[1, 2] 54 | ) 55 | assert(len(d) == 3) 56 | assert(d[0]['source'][0] == 1) 57 | assert(d[1]['source'][0] == 2) 58 | assert(d[2]['source'][0] == 2) 59 | 60 | d = ConcatDataset( 61 | [self.dataset_1, self.dataset_2], sample_ratios=[2, 1] 62 | ) 63 | assert(len(d) == 3) 64 | assert(d[0]['source'][0] == 1) 65 | assert(d[1]['source'][0] == 1) 66 | assert(d[2]['source'][0] == 2) 67 | -------------------------------------------------------------------------------- /tests/test_convtbc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
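# Note: ConvTBC operates on inputs laid out as (time, batch, channel), whereas
# nn.Conv1d expects (batch, channel, time). The test below copies weights
# between the two modules (transposing dims 0 and 2) and checks that outputs
# and gradients agree to within 1e-4.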
7 | 8 | import torch 9 | import unittest 10 | from fairseq.modules import ConvTBC 11 | import torch.nn as nn 12 | 13 | 14 | class TestConvTBC(unittest.TestCase): 15 | 16 | def test_convtbc(self): 17 | # ksz, in_channels, out_channels 18 | conv_tbc = ConvTBC(4, 5, kernel_size=3, padding=1) 19 | # out_channels, in_channels, ksz 20 | conv1d = nn.Conv1d(4, 5, kernel_size=3, padding=1) 21 | 22 | conv_tbc.weight.data.copy_(conv1d.weight.data.transpose(0, 2)) 23 | conv_tbc.bias.data.copy_(conv1d.bias.data) 24 | 25 | input_tbc = torch.randn(7, 2, 4, requires_grad=True) 26 | input1d = input_tbc.data.transpose(0, 1).transpose(1, 2) 27 | input1d.requires_grad = True 28 | 29 | output_tbc = conv_tbc(input_tbc) 30 | output1d = conv1d(input1d) 31 | 32 | self.assertAlmostEqual(output_tbc.data.transpose(0, 1).transpose(1, 2), output1d.data) 33 | 34 | grad_tbc = torch.randn(output_tbc.size()) 35 | grad1d = grad_tbc.transpose(0, 1).transpose(1, 2).contiguous() 36 | 37 | output_tbc.backward(grad_tbc) 38 | output1d.backward(grad1d) 39 | 40 | self.assertAlmostEqual(conv_tbc.weight.grad.data.transpose(0, 2), conv1d.weight.grad.data) 41 | self.assertAlmostEqual(conv_tbc.bias.grad.data, conv1d.bias.grad.data) 42 | self.assertAlmostEqual(input_tbc.grad.data.transpose(0, 1).transpose(1, 2), input1d.grad.data) 43 | 44 | def assertAlmostEqual(self, t1, t2): 45 | self.assertEqual(t1.size(), t2.size(), "size mismatch") 46 | self.assertLess((t1 - t2).abs().max(), 1e-4) 47 | 48 | 49 | if __name__ == '__main__': 50 | unittest.main() 51 | -------------------------------------------------------------------------------- /tests/test_dictionary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
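# Note: Dictionary.encode_line with add_if_not_exist=True assigns ids in
# first-seen order (A, B, C, D -> 4..7 after the special symbols), while
# finalize() re-sorts symbols by frequency so the most frequent symbol D gets
# the lowest id; this is why ref_ids1 and ref_ids2 below mirror each other.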
7 | 8 | import tempfile 9 | import unittest 10 | 11 | import torch 12 | 13 | from fairseq.data import Dictionary 14 | 15 | 16 | class TestDictionary(unittest.TestCase): 17 | 18 | def test_finalize(self): 19 | txt = [ 20 | 'A B C D', 21 | 'B C D', 22 | 'C D', 23 | 'D', 24 | ] 25 | ref_ids1 = list(map(torch.IntTensor, [ 26 | [4, 5, 6, 7, 2], 27 | [5, 6, 7, 2], 28 | [6, 7, 2], 29 | [7, 2], 30 | ])) 31 | ref_ids2 = list(map(torch.IntTensor, [ 32 | [7, 6, 5, 4, 2], 33 | [6, 5, 4, 2], 34 | [5, 4, 2], 35 | [4, 2], 36 | ])) 37 | 38 | # build dictionary 39 | d = Dictionary() 40 | for line in txt: 41 | d.encode_line(line, add_if_not_exist=True) 42 | 43 | def get_ids(dictionary): 44 | ids = [] 45 | for line in txt: 46 | ids.append(dictionary.encode_line(line, add_if_not_exist=False)) 47 | return ids 48 | 49 | def assertMatch(ids, ref_ids): 50 | for toks, ref_toks in zip(ids, ref_ids): 51 | self.assertEqual(toks.size(), ref_toks.size()) 52 | self.assertEqual(0, (toks != ref_toks).sum().item()) 53 | 54 | ids = get_ids(d) 55 | assertMatch(ids, ref_ids1) 56 | 57 | # check finalized dictionary 58 | d.finalize() 59 | finalized_ids = get_ids(d) 60 | assertMatch(finalized_ids, ref_ids2) 61 | 62 | # write to disk and reload 63 | with tempfile.NamedTemporaryFile(mode='w') as tmp_dict: 64 | d.save(tmp_dict.name) 65 | d = Dictionary.load(tmp_dict.name) 66 | reload_ids = get_ids(d) 67 | assertMatch(reload_ids, ref_ids2) 68 | assertMatch(finalized_ids, reload_ids) 69 | 70 | 71 | if __name__ == '__main__': 72 | unittest.main() 73 | -------------------------------------------------------------------------------- /tests/test_iterators.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import unittest 9 | 10 | from fairseq.data import iterators 11 | 12 | 13 | class TestIterators(unittest.TestCase): 14 | 15 | def test_counting_iterator(self): 16 | x = list(range(10)) 17 | itr = iterators.CountingIterator(x) 18 | self.assertTrue(itr.has_next()) 19 | self.assertEqual(next(itr), 0) 20 | self.assertEqual(next(itr), 1) 21 | itr.skip(3) 22 | self.assertEqual(next(itr), 5) 23 | itr.skip(3) 24 | self.assertEqual(next(itr), 9) 25 | self.assertFalse(itr.has_next()) 26 | 27 | 28 | if __name__ == '__main__': 29 | unittest.main() 30 | -------------------------------------------------------------------------------- /tests/test_memory_efficient_fp16.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
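# Note: this test requires a CUDA device, since the model and inputs are moved
# to .cuda().half(). The assertions at the end check that the optimizer state
# restored by load_state_dict() is keyed by FP16 parameters while its tensor
# values are kept in FP32.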
7 | 8 | import argparse 9 | import unittest 10 | 11 | import torch 12 | 13 | from fairseq.optim.adam import FairseqAdam 14 | from fairseq.optim.fp16_optimizer import MemoryEfficientFP16Optimizer 15 | 16 | 17 | class TestMemoryEfficientFP16(unittest.TestCase): 18 | 19 | def test_load_state_dict(self): 20 | # define simple FP16 model 21 | model = torch.nn.Linear(5, 5).cuda().half() 22 | params = list(model.parameters()) 23 | 24 | # initialize memory efficient FP16 optimizer 25 | optimizer = FairseqAdam( 26 | argparse.Namespace( 27 | lr=[0.00001], 28 | adam_betas='(0.9, 0.999)', 29 | adam_eps=1e-8, 30 | weight_decay=0.0, 31 | ), 32 | params, 33 | ) 34 | me_optimizer = MemoryEfficientFP16Optimizer( 35 | argparse.Namespace( 36 | fp16_init_scale=1, 37 | fp16_scale_window=1, 38 | fp16_scale_tolerance=1, 39 | threshold_loss_scale=1, 40 | ), 41 | params, 42 | optimizer, 43 | ) 44 | 45 | # optimizer state is created in the first step 46 | loss = model(torch.rand(5).cuda().half()).sum() 47 | me_optimizer.backward(loss) 48 | me_optimizer.step() 49 | 50 | # reload state 51 | state = me_optimizer.state_dict() 52 | me_optimizer.load_state_dict(state) 53 | for k, v in me_optimizer.optimizer.state.items(): 54 | self.assertTrue(k.dtype == torch.float16) 55 | for v_i in v.values(): 56 | if torch.is_tensor(v_i): 57 | self.assertTrue(v_i.dtype == torch.float32) 58 | 59 | 60 | if __name__ == '__main__': 61 | unittest.main() 62 | -------------------------------------------------------------------------------- /tests/test_multi_corpus_sampled_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
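# Note: MultiCorpusSampledDataset draws each example from one of its
# constituent datasets. With the default sampler the two corpora below are
# picked roughly 50/50; a custom sampling_func (see naive_weighted_sample)
# skews this, e.g. weights=[0.9, 0.1] yields ~90% of samples from the first.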
7 | 8 | import unittest 9 | from collections import OrderedDict 10 | 11 | import numpy as np 12 | import torch 13 | from fairseq.data import LanguagePairDataset, TokenBlockDataset 14 | from fairseq.data.multi_corpus_sampled_dataset import MultiCorpusSampledDataset 15 | from tests.test_train import mock_dict 16 | 17 | 18 | class TestMultiCorpusSampledDataset(unittest.TestCase): 19 | def setUp(self): 20 | d = mock_dict() 21 | tokens_1 = torch.LongTensor([1]).view(1, -1) 22 | tokens_ds1 = TokenBlockDataset( 23 | tokens_1, 24 | sizes=[tokens_1.size(-1)], 25 | block_size=1, 26 | pad=0, 27 | eos=1, 28 | include_targets=False, 29 | ) 30 | self.dataset_1 = LanguagePairDataset( 31 | tokens_ds1, tokens_ds1.sizes, d, shuffle=False 32 | ) 33 | tokens_2 = torch.LongTensor([2]).view(1, -1) 34 | tokens_ds2 = TokenBlockDataset( 35 | tokens_2, 36 | sizes=[tokens_2.size(-1)], 37 | block_size=1, 38 | pad=0, 39 | eos=1, 40 | include_targets=False, 41 | ) 42 | self.dataset_2 = LanguagePairDataset( 43 | tokens_ds2, tokens_ds2.sizes, d, shuffle=False 44 | ) 45 | 46 | def _test_sample_helper( 47 | self, 48 | expected_sample_from_first_ds_percentage, 49 | num_samples=1000, 50 | sampling_func=None, 51 | ): 52 | # To make sure test is not flaky 53 | np.random.seed(0) 54 | if sampling_func is None: 55 | m = MultiCorpusSampledDataset( 56 | OrderedDict({0: self.dataset_1, 1: self.dataset_2}), 57 | ) 58 | else: 59 | m = MultiCorpusSampledDataset( 60 | OrderedDict({0: self.dataset_1, 1: self.dataset_2}), 61 | sampling_func=sampling_func, 62 | ) 63 | m.ordered_indices() 64 | count_sample_from_first_dataset = 0 65 | for _ in range(num_samples): 66 | if m.collater([m[0], m[1]])["net_input"]["src_tokens"][0] == 1: 67 | count_sample_from_first_dataset += 1 68 | sample_from_first_ds_percentage = ( 69 | 1.0 * count_sample_from_first_dataset / num_samples 70 | ) 71 | self.assertLess( 72 | abs( 73 | sample_from_first_ds_percentage 74 | - expected_sample_from_first_ds_percentage 75 | ), 76 | 0.01, 77 | ) 78 | 79 | def test_multi_corpus_sampled_dataset_uniform_sample(self): 80 | self._test_sample_helper(expected_sample_from_first_ds_percentage=0.5) 81 | 82 | def test_multi_corpus_sampled_dataset_weighted_sample(self): 83 | def naive_weighted_sample(weights): 84 | def f(l): 85 | v = np.random.random() 86 | agg = 0 87 | for i, weight in enumerate(weights): 88 | agg += weight 89 | if agg > v: 90 | return i 91 | 92 | return f 93 | 94 | self._test_sample_helper( 95 | expected_sample_from_first_ds_percentage=0.9, 96 | sampling_func=naive_weighted_sample(weights=[0.9, 0.1]), 97 | ) 98 | -------------------------------------------------------------------------------- /tests/test_reproducibility.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import contextlib 9 | from io import StringIO 10 | import json 11 | import os 12 | import tempfile 13 | import unittest 14 | 15 | from . 
import test_binaries 16 | 17 | 18 | class TestReproducibility(unittest.TestCase): 19 | 20 | def _test_reproducibility(self, name, extra_flags=None): 21 | if extra_flags is None: 22 | extra_flags = [] 23 | 24 | with tempfile.TemporaryDirectory(name) as data_dir: 25 | with contextlib.redirect_stdout(StringIO()): 26 | test_binaries.create_dummy_data(data_dir) 27 | test_binaries.preprocess_translation_data(data_dir) 28 | 29 | # train epochs 1 and 2 together 30 | stdout = StringIO() 31 | with contextlib.redirect_stdout(stdout): 32 | test_binaries.train_translation_model( 33 | data_dir, 'fconv_iwslt_de_en', [ 34 | '--dropout', '0.0', 35 | '--log-format', 'json', 36 | '--log-interval', '1', 37 | '--max-epoch', '3', 38 | ] + extra_flags, 39 | ) 40 | stdout = stdout.getvalue() 41 | train_log, valid_log = map(json.loads, stdout.split('\n')[-5:-3]) 42 | 43 | # train epoch 2, resuming from previous checkpoint 1 44 | os.rename( 45 | os.path.join(data_dir, 'checkpoint1.pt'), 46 | os.path.join(data_dir, 'checkpoint_last.pt'), 47 | ) 48 | stdout = StringIO() 49 | with contextlib.redirect_stdout(stdout): 50 | test_binaries.train_translation_model( 51 | data_dir, 'fconv_iwslt_de_en', [ 52 | '--dropout', '0.0', 53 | '--log-format', 'json', 54 | '--log-interval', '1', 55 | '--max-epoch', '3', 56 | ] + extra_flags, 57 | ) 58 | stdout = stdout.getvalue() 59 | train_res_log, valid_res_log = map(json.loads, stdout.split('\n')[-5:-3]) 60 | 61 | def cast(s): 62 | return round(float(s), 3) 63 | 64 | for k in ['train_loss', 'train_ppl', 'train_num_updates', 'train_gnorm']: 65 | self.assertEqual(cast(train_log[k]), cast(train_res_log[k])) 66 | for k in ['valid_loss', 'valid_ppl', 'valid_num_updates', 'valid_best_loss']: 67 | self.assertEqual(cast(valid_log[k]), cast(valid_res_log[k])) 68 | 69 | def test_reproducibility(self): 70 | self._test_reproducibility('test_reproducibility') 71 | 72 | def test_reproducibility_fp16(self): 73 | self._test_reproducibility('test_reproducibility_fp16', [ 74 | '--fp16', 75 | '--fp16-init-scale', '4096', 76 | ]) 77 | 78 | def test_reproducibility_memory_efficient_fp16(self): 79 | self._test_reproducibility('test_reproducibility_memory_efficient_fp16', [ 80 | '--memory-efficient-fp16', 81 | '--fp16-init-scale', '4096', 82 | ]) 83 | 84 | 85 | if __name__ == '__main__': 86 | unittest.main() 87 | -------------------------------------------------------------------------------- /tests/test_token_block_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
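# Note: the tests below exercise the three break modes of TokenBlockDataset:
# 'eos' keeps one sentence per block, 'none' packs tokens into fixed-size
# blocks regardless of sentence boundaries, and 'complete' breaks only at
# sentence boundaries (so a block may exceed block_size if a single sentence does).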
7 | 8 | import unittest 9 | 10 | import torch 11 | 12 | from fairseq.data import TokenBlockDataset 13 | 14 | import tests.utils as test_utils 15 | 16 | 17 | class TestTokenBlockDataset(unittest.TestCase): 18 | 19 | def _build_dataset(self, data, **kwargs): 20 | sizes = [len(x) for x in data] 21 | underlying_ds = test_utils.TestDataset(data) 22 | return TokenBlockDataset(underlying_ds, sizes, **kwargs) 23 | 24 | def test_eos_break_mode(self): 25 | data = [ 26 | torch.tensor([5, 4, 3, 2, 1], dtype=torch.long), 27 | torch.tensor([1], dtype=torch.long), 28 | torch.tensor([8, 7, 6, 1], dtype=torch.long), 29 | ] 30 | ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode='eos') 31 | self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1]) 32 | self.assertEqual(ds[1].tolist(), [1]) 33 | self.assertEqual(ds[2].tolist(), [8, 7, 6, 1]) 34 | 35 | data = [ 36 | torch.tensor([5, 4, 3, 2, 1], dtype=torch.long), 37 | torch.tensor([8, 7, 6, 1], dtype=torch.long), 38 | torch.tensor([1], dtype=torch.long), 39 | ] 40 | ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode='eos') 41 | self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1]) 42 | self.assertEqual(ds[1].tolist(), [8, 7, 6, 1]) 43 | self.assertEqual(ds[2].tolist(), [1]) 44 | 45 | def test_block_break_mode(self): 46 | data = [ 47 | torch.tensor([5, 4, 3, 2, 1], dtype=torch.long), 48 | torch.tensor([8, 7, 6, 1], dtype=torch.long), 49 | torch.tensor([9, 1], dtype=torch.long), 50 | ] 51 | ds = self._build_dataset(data, block_size=3, pad=0, eos=1, break_mode='none') 52 | self.assertEqual(ds[0].tolist(), [5, 4, 3]) 53 | self.assertEqual(ds[1].tolist(), [2, 1, 8]) 54 | self.assertEqual(ds[2].tolist(), [7, 6, 1]) 55 | self.assertEqual(ds[3].tolist(), [9, 1]) 56 | 57 | def test_complete_break_mode(self): 58 | data = [ 59 | torch.tensor([5, 4, 3, 2, 1], dtype=torch.long), 60 | torch.tensor([8, 7, 6, 1], dtype=torch.long), 61 | torch.tensor([9, 1], dtype=torch.long), 62 | ] 63 | ds = self._build_dataset(data, block_size=6, pad=0, eos=1, break_mode='complete') 64 | self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1]) 65 | self.assertEqual(ds[1].tolist(), [8, 7, 6, 1, 9, 1]) 66 | 67 | data = [ 68 | torch.tensor([4, 3, 2, 1], dtype=torch.long), 69 | torch.tensor([5, 1], dtype=torch.long), 70 | torch.tensor([1], dtype=torch.long), 71 | torch.tensor([6, 1], dtype=torch.long), 72 | ] 73 | ds = self._build_dataset(data, block_size=3, pad=0, eos=1, break_mode='complete') 74 | self.assertEqual(ds[0].tolist(), [4, 3, 2, 1]) 75 | self.assertEqual(ds[1].tolist(), [5, 1, 1]) 76 | self.assertEqual(ds[2].tolist(), [6, 1]) 77 | 78 | 79 | if __name__ == "__main__": 80 | unittest.main() 81 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
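# Note: the expected tensors below show that utils.make_positions numbers real
# tokens starting at pad + 1 (pad is 1, so the first real token gets position 2)
# and leaves padding at the pad index, while utils.convert_padding_direction
# moves padding between the left and right side of each row without reordering tokens.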
7 | 8 | import unittest 9 | 10 | import torch 11 | 12 | from fairseq import utils 13 | 14 | 15 | class TestUtils(unittest.TestCase): 16 | 17 | def test_convert_padding_direction(self): 18 | pad = 1 19 | left_pad = torch.LongTensor([ 20 | [2, 3, 4, 5, 6], 21 | [1, 7, 8, 9, 10], 22 | [1, 1, 1, 11, 12], 23 | ]) 24 | right_pad = torch.LongTensor([ 25 | [2, 3, 4, 5, 6], 26 | [7, 8, 9, 10, 1], 27 | [11, 12, 1, 1, 1], 28 | ]) 29 | 30 | self.assertAlmostEqual( 31 | right_pad, 32 | utils.convert_padding_direction( 33 | left_pad, 34 | pad, 35 | left_to_right=True, 36 | ), 37 | ) 38 | self.assertAlmostEqual( 39 | left_pad, 40 | utils.convert_padding_direction( 41 | right_pad, 42 | pad, 43 | right_to_left=True, 44 | ), 45 | ) 46 | 47 | def test_make_positions(self): 48 | pad = 1 49 | left_pad_input = torch.LongTensor([ 50 | [9, 9, 9, 9, 9], 51 | [1, 9, 9, 9, 9], 52 | [1, 1, 1, 9, 9], 53 | ]) 54 | left_pad_output = torch.LongTensor([ 55 | [2, 3, 4, 5, 6], 56 | [1, 2, 3, 4, 5], 57 | [1, 1, 1, 2, 3], 58 | ]) 59 | right_pad_input = torch.LongTensor([ 60 | [9, 9, 9, 9, 9], 61 | [9, 9, 9, 9, 1], 62 | [9, 9, 1, 1, 1], 63 | ]) 64 | right_pad_output = torch.LongTensor([ 65 | [2, 3, 4, 5, 6], 66 | [2, 3, 4, 5, 1], 67 | [2, 3, 1, 1, 1], 68 | ]) 69 | 70 | self.assertAlmostEqual( 71 | left_pad_output, 72 | utils.make_positions(left_pad_input, pad), 73 | ) 74 | self.assertAlmostEqual( 75 | right_pad_output, 76 | utils.make_positions(right_pad_input, pad), 77 | ) 78 | 79 | def assertAlmostEqual(self, t1, t2): 80 | self.assertEqual(t1.size(), t2.size(), "size mismatch") 81 | self.assertLess(utils.item((t1 - t2).abs().max()), 1e-4) 82 | 83 | 84 | if __name__ == '__main__': 85 | unittest.main() 86 | --------------------------------------------------------------------------------