├── .DS_Store
├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── PATENTS
├── README.md
├── SDT_train.sh
├── docs
│   ├── Makefile
│   ├── _static
│   │   └── theme_overrides.css
│   ├── command_line_tools.rst
│   ├── conf.py
│   ├── criterions.rst
│   ├── data.rst
│   ├── docutils.conf
│   ├── getting_started.rst
│   ├── index.rst
│   ├── lr_scheduler.rst
│   ├── make.bat
│   ├── models.rst
│   ├── modules.rst
│   ├── optim.rst
│   ├── overview.rst
│   ├── requirements.txt
│   ├── tasks.rst
│   ├── tutorial_classifying_names.rst
│   └── tutorial_simple_lstm.rst
├── eval_lm.py
├── examples
│   ├── .gitignore
│   ├── backtranslation
│   │   └── README.md
│   ├── conv_seq2seq
│   │   └── README.md
│   ├── language_model
│   │   ├── README.md
│   │   ├── conv_lm
│   │   │   └── README.md
│   │   ├── prepare-wikitext-103.sh
│   │   └── transformer_lm
│   │       └── README.md
│   ├── pay_less_attention_paper
│   │   └── README.md
│   ├── scaling_nmt
│   │   └── README.md
│   ├── stories
│   │   └── README.md
│   ├── translation
│   │   ├── README.md
│   │   ├── prepare-iwslt14.sh
│   │   ├── prepare-iwslt17-multilingual.sh
│   │   ├── prepare-wmt14en2de.sh
│   │   └── prepare-wmt14en2fr.sh
│   └── translation_moe
│       └── README.md
├── fairseq.gif
├── fairseq
│   ├── .DS_Store
│   ├── __init__.py
│   ├── binarizer.py
│   ├── bleu.py
│   ├── clib
│   │   └── libbleu
│   │       ├── libbleu.cpp
│   │       └── module.cpp
│   ├── criterions
│   │   ├── __init__.py
│   │   ├── adaptive_loss.py
│   │   ├── composite_loss.py
│   │   ├── cross_entropy.py
│   │   ├── fairseq_criterion.py
│   │   ├── label_smoothed_adaptive_loss.py
│   │   ├── label_smoothed_cross_entropy.py
│   │   └── regularization_label_smoothed_cross_entropy.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── backtranslation_dataset.py
│   │   ├── concat_dataset.py
│   │   ├── data_utils.py
│   │   ├── dictionary.py
│   │   ├── fairseq_dataset.py
│   │   ├── indexed_dataset.py
│   │   ├── iterators.py
│   │   ├── language_pair_dataset.py
│   │   ├── lm_context_window_dataset.py
│   │   ├── monolingual_dataset.py
│   │   ├── noising.py
│   │   ├── round_robin_zip_datasets.py
│   │   ├── token_block_dataset.py
│   │   └── transform_eos_dataset.py
│   ├── distributed_utils.py
│   ├── legacy_distributed_data_parallel.py
│   ├── meters.py
│   ├── models
│   │   ├── .DS_Store
│   │   ├── __init__.py
│   │   ├── composite_encoder.py
│   │   ├── distributed_fairseq_model.py
│   │   ├── dlcl_transformer.py
│   │   ├── fairseq_decoder.py
│   │   ├── fairseq_encoder.py
│   │   ├── fairseq_incremental_decoder.py
│   │   ├── fairseq_model.py
│   │   ├── fconv.py
│   │   ├── fconv_self_att.py
│   │   ├── lightconv.py
│   │   ├── lstm.py
│   │   ├── multilingual_transformer.py
│   │   ├── sdt_transformer.py
│   │   └── transformer.py
│   ├── modules
│   │   ├── .DS_Store
│   │   ├── __init__.py
│   │   ├── adaptive_input.py
│   │   ├── adaptive_softmax.py
│   │   ├── bak.py
│   │   ├── beamable_mm.py
│   │   ├── character_token_embedder.py
│   │   ├── conv_tbc.py
│   │   ├── downsampled_multihead_attention.py
│   │   ├── dynamic_convolution.py
│   │   ├── grad_multiply.py
│   │   ├── highway.py
│   │   ├── layer_history.py
│   │   ├── layer_norm.py
│   │   ├── learned_positional_embedding.py
│   │   ├── lightweight_convolution.py
│   │   ├── linearized_convolution.py
│   │   ├── logsumexp_moe.py
│   │   ├── mean_pool_gating_network.py
│   │   ├── multihead_attention.py
│   │   ├── relative_multihead_attention.py
│   │   ├── scalar_bias.py
│   │   ├── sinusoidal_positional_embedding.py
│   │   └── unfold1d.py
│   ├── optim
│   │   ├── __init__.py
│   │   ├── adadelta.py
│   │   ├── adafactor.py
│   │   ├── adagrad.py
│   │   ├── adam.py
│   │   ├── fairseq_optimizer.py
│   │   ├── fp16_optimizer.py
│   │   ├── lr_scheduler
│   │   │   ├── __init__.py
│   │   │   ├── cosine_lr_scheduler.py
│   │   │   ├── fairseq_lr_scheduler.py
│   │   │   ├── fixed_schedule.py
│   │   │   ├── inverse_square_root_schedule.py
│   │   │   ├── reduce_lr_on_plateau.py
│   │   │   └── triangular_lr_scheduler.py
│   │   ├── nag.py
│   │   └── sgd.py
│   ├── options.py
│   ├── pdb.py
│   ├── progress_bar.py
│   ├── search.py
│   ├── sequence_generator.py
│   ├── sequence_scorer.py
│   ├── tasks
│   │   ├── __init__.py
│   │   ├── fairseq_task.py
│   │   ├── language_modeling.py
│   │   ├── multilingual_translation.py
│   │   ├── translation.py
│   │   └── translation_moe.py
│   ├── tokenizer.py
│   ├── trainer.py
│   └── utils.py
├── fairseq_cli
│   ├── __init__.py
│   ├── eval_lm.py
│   ├── generate.py
│   ├── interactive.py
│   ├── preprocess.py
│   ├── score.py
│   ├── setup.py
│   └── train.py
├── fairseq_logo.png
├── generate.py
├── interactive.py
├── preprocess.py
├── preprocess.sh
├── rerank.py
├── score.py
├── scripts
│   ├── __init__.py
│   ├── average_checkpoints.py
│   ├── build_sym_alignment.py
│   ├── compound_split_bleu.sh
│   ├── convert_dictionary.lua
│   ├── convert_model.lua
│   ├── read_binarized.py
│   ├── sacrebleu_pregen.sh
│   ├── score_moe.py
│   ├── spm_decode.py
│   ├── spm_encode.py
│   └── spm_train.py
├── setup.py
├── stack.py
├── tests
│   ├── __init__.py
│   ├── test_average_checkpoints.py
│   ├── test_backtranslation_dataset.py
│   ├── test_binaries.py
│   ├── test_character_token_embedder.py
│   ├── test_convtbc.py
│   ├── test_dictionary.py
│   ├── test_iterators.py
│   ├── test_label_smoothing.py
│   ├── test_noising.py
│   ├── test_reproducibility.py
│   ├── test_sequence_generator.py
│   ├── test_sequence_scorer.py
│   ├── test_token_block_dataset.py
│   ├── test_train.py
│   ├── test_utils.py
│   └── utils.py
├── train.py
├── train.sh
└── translate.sh

/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/libeineu/SDT-Training/d33a836c1b3258748ec11c4d64998e0dcc9792df/.DS_Store
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # JetBrains PyCharm IDE
2 | .idea/
3 | 
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 | 
9 | # C extensions
10 | *.so
11 | 
12 | # Distribution / packaging
13 | .Python
14 | env/
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | 
31 | # Checkpoints
32 | checkpoints
33 | 
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # SageMath parsed files 86 | *.sage.py 87 | 88 | # dotenv 89 | .env 90 | 91 | # virtualenv 92 | .venv 93 | venv/ 94 | ENV/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | 109 | # Generated files 110 | fairseq/temporal_convolution_tbc 111 | 112 | # data 113 | data-bin/ 114 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to FAIR Sequence-to-Sequence Toolkit (PyTorch) 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | ## Coding Style 26 | We try to follow the PEP style guidelines and encourage you to as well. 27 | 28 | ## License 29 | By contributing to FAIR Sequence-to-Sequence Toolkit, you agree that your contributions will be licensed 30 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For fairseq software 4 | 5 | Copyright (c) 2017-present, Facebook, Inc. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 
16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /PATENTS: -------------------------------------------------------------------------------- 1 | Additional Grant of Patent Rights Version 2 2 | 3 | "Software" means the fairseq software distributed by Facebook, Inc. 4 | 5 | Facebook, Inc. ("Facebook") hereby grants to each recipient of the Software 6 | ("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable 7 | (subject to the termination provision below) license under any Necessary 8 | Claims, to make, have made, use, sell, offer to sell, import, and otherwise 9 | transfer the Software. For avoidance of doubt, no license is granted under 10 | Facebook’s rights in any patent claims that are infringed by (i) modifications 11 | to the Software made by you or any third party or (ii) the Software in 12 | combination with any software or other technology. 13 | 14 | The license granted hereunder will terminate, automatically and without notice, 15 | if you (or any of your subsidiaries, corporate affiliates or agents) initiate 16 | directly or indirectly, or take a direct financial interest in, any Patent 17 | Assertion: (i) against Facebook or any of its subsidiaries or corporate 18 | affiliates, (ii) against any party if such Patent Assertion arises in whole or 19 | in part from any software, technology, product or service of Facebook or any of 20 | its subsidiaries or corporate affiliates, or (iii) against any party relating 21 | to the Software. Notwithstanding the foregoing, if Facebook or any of its 22 | subsidiaries or corporate affiliates files a lawsuit alleging patent 23 | infringement against you in the first instance, and you respond by filing a 24 | patent infringement counterclaim in that lawsuit against that party that is 25 | unrelated to the Software, the license granted hereunder will not terminate 26 | under section (i) of this paragraph due to such counterclaim. 27 | 28 | A "Necessary Claim" is a claim of a patent owned by Facebook that is 29 | necessarily infringed by the Software standing alone. 30 | 31 | A "Patent Assertion" is any lawsuit or other action alleging direct, indirect, 32 | or contributory infringement or inducement to infringe any patent, including a 33 | cross-claim or counterclaim. 
34 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = python -msphinx 7 | SPHINXPROJ = fairseq 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/_static/theme_overrides.css: -------------------------------------------------------------------------------- 1 | .wy-table-responsive table td kbd { 2 | white-space: nowrap; 3 | } 4 | .wy-table-responsive table td { 5 | white-space: normal !important; 6 | } 7 | .wy-table-responsive { 8 | overflow: visible !important; 9 | } 10 | -------------------------------------------------------------------------------- /docs/command_line_tools.rst: -------------------------------------------------------------------------------- 1 | .. _Command-line Tools: 2 | 3 | Command-line Tools 4 | ================== 5 | 6 | Fairseq provides several command-line tools for training and evaluating models: 7 | 8 | - :ref:`fairseq-preprocess`: Data pre-processing: build vocabularies and binarize training data 9 | - :ref:`fairseq-train`: Train a new model on one or multiple GPUs 10 | - :ref:`fairseq-generate`: Translate pre-processed data with a trained model 11 | - :ref:`fairseq-interactive`: Translate raw text with a trained model 12 | - :ref:`fairseq-score`: BLEU scoring of generated translations against reference translations 13 | - :ref:`fairseq-eval-lm`: Language model evaluation 14 | 15 | 16 | .. _fairseq-preprocess: 17 | 18 | fairseq-preprocess 19 | ~~~~~~~~~~~~~~~~~~ 20 | .. automodule:: preprocess 21 | 22 | .. argparse:: 23 | :module: fairseq.options 24 | :func: get_preprocessing_parser 25 | :prog: fairseq-preprocess 26 | 27 | 28 | .. _fairseq-train: 29 | 30 | fairseq-train 31 | ~~~~~~~~~~~~~ 32 | .. automodule:: train 33 | 34 | .. argparse:: 35 | :module: fairseq.options 36 | :func: get_training_parser 37 | :prog: fairseq-train 38 | 39 | 40 | .. _fairseq-generate: 41 | 42 | fairseq-generate 43 | ~~~~~~~~~~~~~~~~ 44 | .. automodule:: generate 45 | 46 | .. argparse:: 47 | :module: fairseq.options 48 | :func: get_generation_parser 49 | :prog: fairseq-generate 50 | 51 | 52 | .. _fairseq-interactive: 53 | 54 | fairseq-interactive 55 | ~~~~~~~~~~~~~~~~~~~ 56 | .. automodule:: interactive 57 | 58 | .. argparse:: 59 | :module: fairseq.options 60 | :func: get_interactive_generation_parser 61 | :prog: fairseq-interactive 62 | 63 | 64 | .. _fairseq-score: 65 | 66 | fairseq-score 67 | ~~~~~~~~~~~~~ 68 | .. automodule:: score 69 | 70 | .. argparse:: 71 | :module: fairseq_cli.score 72 | :func: get_parser 73 | :prog: fairseq-score 74 | 75 | 76 | .. _fairseq-eval-lm: 77 | 78 | fairseq-eval-lm 79 | ~~~~~~~~~~~~~~~ 80 | .. automodule:: eval_lm 81 | 82 | .. 
argparse:: 83 | :module: fairseq.options 84 | :func: get_eval_lm_parser 85 | :prog: fairseq-eval-lm 86 | -------------------------------------------------------------------------------- /docs/criterions.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | .. _Criterions: 5 | 6 | Criterions 7 | ========== 8 | 9 | Criterions compute the loss function given the model and batch, roughly:: 10 | 11 | loss = criterion(model, batch) 12 | 13 | .. automodule:: fairseq.criterions 14 | :members: 15 | 16 | .. autoclass:: fairseq.criterions.FairseqCriterion 17 | :members: 18 | :undoc-members: 19 | 20 | .. autoclass:: fairseq.criterions.adaptive_loss.AdaptiveLoss 21 | :members: 22 | :undoc-members: 23 | .. autoclass:: fairseq.criterions.composite_loss.CompositeLoss 24 | :members: 25 | :undoc-members: 26 | .. autoclass:: fairseq.criterions.cross_entropy.CrossEntropyCriterion 27 | :members: 28 | :undoc-members: 29 | .. autoclass:: fairseq.criterions.label_smoothed_cross_entropy.LabelSmoothedCrossEntropyCriterion 30 | :members: 31 | :undoc-members: 32 | -------------------------------------------------------------------------------- /docs/data.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | .. module:: fairseq.data 5 | 6 | Data Loading and Utilities 7 | ========================== 8 | 9 | .. _datasets: 10 | 11 | Datasets 12 | -------- 13 | 14 | **Datasets** define the data format and provide helpers for creating 15 | mini-batches. 16 | 17 | .. autoclass:: fairseq.data.FairseqDataset 18 | :members: 19 | .. autoclass:: fairseq.data.LanguagePairDataset 20 | :members: 21 | .. autoclass:: fairseq.data.MonolingualDataset 22 | :members: 23 | 24 | **Helper Datasets** 25 | 26 | These datasets wrap other :class:`fairseq.data.FairseqDataset` instances and 27 | provide additional functionality: 28 | 29 | .. autoclass:: fairseq.data.BacktranslationDataset 30 | :members: 31 | .. autoclass:: fairseq.data.ConcatDataset 32 | :members: 33 | .. autoclass:: fairseq.data.RoundRobinZipDatasets 34 | :members: 35 | .. autoclass:: fairseq.data.TransformEosDataset 36 | :members: 37 | 38 | 39 | Dictionary 40 | ---------- 41 | 42 | .. autoclass:: fairseq.data.Dictionary 43 | :members: 44 | 45 | 46 | Iterators 47 | --------- 48 | 49 | .. autoclass:: fairseq.data.CountingIterator 50 | :members: 51 | .. autoclass:: fairseq.data.EpochBatchIterator 52 | :members: 53 | .. autoclass:: fairseq.data.GroupedIterator 54 | :members: 55 | .. autoclass:: fairseq.data.ShardedIterator 56 | :members: 57 | -------------------------------------------------------------------------------- /docs/docutils.conf: -------------------------------------------------------------------------------- 1 | [writers] 2 | option-limit=0 3 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. fairseq documentation master file, created by 2 | sphinx-quickstart on Fri Aug 17 21:45:30 2018. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 
5 | 6 | :github_url: https://github.com/pytorch/fairseq 7 | 8 | 9 | fairseq documentation 10 | ===================== 11 | 12 | Fairseq is a sequence modeling toolkit written in `PyTorch 13 | `_ that allows researchers and developers to 14 | train custom models for translation, summarization, language modeling and other 15 | text generation tasks. 16 | 17 | .. toctree:: 18 | :maxdepth: 1 19 | :caption: Getting Started 20 | 21 | getting_started 22 | command_line_tools 23 | 24 | .. toctree:: 25 | :maxdepth: 1 26 | :caption: Extending Fairseq 27 | 28 | overview 29 | tutorial_simple_lstm 30 | tutorial_classifying_names 31 | 32 | .. toctree:: 33 | :maxdepth: 2 34 | :caption: Library Reference 35 | 36 | tasks 37 | models 38 | criterions 39 | optim 40 | lr_scheduler 41 | data 42 | modules 43 | 44 | 45 | Indices and tables 46 | ================== 47 | 48 | * :ref:`genindex` 49 | * :ref:`search` 50 | -------------------------------------------------------------------------------- /docs/lr_scheduler.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | .. _Learning Rate Schedulers: 5 | 6 | Learning Rate Schedulers 7 | ======================== 8 | 9 | Learning Rate Schedulers update the learning rate over the course of training. 10 | Learning rates can be updated after each update via :func:`step_update` or at 11 | epoch boundaries via :func:`step`. 12 | 13 | .. automodule:: fairseq.optim.lr_scheduler 14 | :members: 15 | 16 | .. autoclass:: fairseq.optim.lr_scheduler.FairseqLRScheduler 17 | :members: 18 | :undoc-members: 19 | 20 | .. autoclass:: fairseq.optim.lr_scheduler.cosine_lr_scheduler.CosineSchedule 21 | :members: 22 | :undoc-members: 23 | .. autoclass:: fairseq.optim.lr_scheduler.fixed_schedule.FixedSchedule 24 | :members: 25 | :undoc-members: 26 | .. autoclass:: fairseq.optim.lr_scheduler.inverse_square_root_schedule.InverseSquareRootSchedule 27 | :members: 28 | :undoc-members: 29 | .. autoclass:: fairseq.optim.lr_scheduler.reduce_lr_on_plateau.ReduceLROnPlateau 30 | :members: 31 | :undoc-members: 32 | .. autoclass:: fairseq.optim.lr_scheduler.triangular_lr_scheduler.TriangularSchedule 33 | :members: 34 | :undoc-members: 35 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=python -msphinx 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=fairseq 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed, 20 | echo.then set the SPHINXBUILD environment variable to point to the full 21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the 22 | echo.Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/models.rst: -------------------------------------------------------------------------------- 1 | .. 
role:: hidden 2 | :class: hidden-section 3 | 4 | .. module:: fairseq.models 5 | 6 | .. _Models: 7 | 8 | Models 9 | ====== 10 | 11 | A Model defines the neural network's ``forward()`` method and encapsulates all 12 | of the learnable parameters in the network. Each model also provides a set of 13 | named *architectures* that define the precise network configuration (e.g., 14 | embedding dimension, number of layers, etc.). 15 | 16 | Both the model type and architecture are selected via the ``--arch`` 17 | command-line argument. Once selected, a model may expose additional command-line 18 | arguments for further configuration. 19 | 20 | .. note:: 21 | 22 | All fairseq Models extend :class:`BaseFairseqModel`, which in turn extends 23 | :class:`torch.nn.Module`. Thus any fairseq Model can be used as a 24 | stand-alone Module in other PyTorch code. 25 | 26 | 27 | Convolutional Neural Networks (CNN) 28 | ----------------------------------- 29 | 30 | .. module:: fairseq.models.fconv 31 | .. autoclass:: fairseq.models.fconv.FConvModel 32 | :members: 33 | .. autoclass:: fairseq.models.fconv.FConvEncoder 34 | :members: 35 | :undoc-members: 36 | .. autoclass:: fairseq.models.fconv.FConvDecoder 37 | :members: 38 | 39 | 40 | Long Short-Term Memory (LSTM) networks 41 | -------------------------------------- 42 | 43 | .. module:: fairseq.models.lstm 44 | .. autoclass:: fairseq.models.lstm.LSTMModel 45 | :members: 46 | .. autoclass:: fairseq.models.lstm.LSTMEncoder 47 | :members: 48 | .. autoclass:: fairseq.models.lstm.LSTMDecoder 49 | :members: 50 | 51 | 52 | Transformer (self-attention) networks 53 | ------------------------------------- 54 | 55 | .. module:: fairseq.models.transformer 56 | .. autoclass:: fairseq.models.transformer.TransformerModel 57 | :members: 58 | .. autoclass:: fairseq.models.transformer.TransformerEncoder 59 | :members: 60 | .. autoclass:: fairseq.models.transformer.TransformerEncoderLayer 61 | :members: 62 | .. autoclass:: fairseq.models.transformer.TransformerDecoder 63 | :members: 64 | .. autoclass:: fairseq.models.transformer.TransformerDecoderLayer 65 | :members: 66 | 67 | 68 | Adding new models 69 | ----------------- 70 | 71 | .. currentmodule:: fairseq.models 72 | .. autofunction:: fairseq.models.register_model 73 | .. autofunction:: fairseq.models.register_model_architecture 74 | .. autoclass:: fairseq.models.BaseFairseqModel 75 | :members: 76 | :undoc-members: 77 | .. autoclass:: fairseq.models.FairseqModel 78 | :members: 79 | :undoc-members: 80 | .. autoclass:: fairseq.models.FairseqLanguageModel 81 | :members: 82 | :undoc-members: 83 | .. autoclass:: fairseq.models.FairseqEncoder 84 | :members: 85 | .. autoclass:: fairseq.models.CompositeEncoder 86 | :members: 87 | .. autoclass:: fairseq.models.FairseqDecoder 88 | :members: 89 | 90 | 91 | .. _Incremental decoding: 92 | 93 | Incremental decoding 94 | -------------------- 95 | 96 | .. autoclass:: fairseq.models.FairseqIncrementalDecoder 97 | :members: 98 | :undoc-members: 99 | -------------------------------------------------------------------------------- /docs/modules.rst: -------------------------------------------------------------------------------- 1 | Modules 2 | ======= 3 | 4 | Fairseq provides several stand-alone :class:`torch.nn.Module` classes that may 5 | be helpful when implementing a new :class:`~fairseq.models.FairseqModel`. 6 | 7 | .. 
automodule:: fairseq.modules
8 |     :members:
9 |     :undoc-members:
10 | 
--------------------------------------------------------------------------------
/docs/optim.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 |     :class: hidden-section
3 | 
4 | .. _optimizers:
5 | 
6 | Optimizers
7 | ==========
8 | 
9 | Optimizers update the Model parameters based on the gradients.
10 | 
11 | .. automodule:: fairseq.optim
12 |     :members:
13 | 
14 | .. autoclass:: fairseq.optim.FairseqOptimizer
15 |     :members:
16 |     :undoc-members:
17 | 
18 | .. autoclass:: fairseq.optim.adagrad.Adagrad
19 |     :members:
20 |     :undoc-members:
21 | .. autoclass:: fairseq.optim.adam.FairseqAdam
22 |     :members:
23 |     :undoc-members:
24 | .. autoclass:: fairseq.optim.fp16_optimizer.FP16Optimizer
25 |     :members:
26 |     :undoc-members:
27 | .. autoclass:: fairseq.optim.nag.FairseqNAG
28 |     :members:
29 |     :undoc-members:
30 | .. autoclass:: fairseq.optim.sgd.SGD
31 |     :members:
32 |     :undoc-members:
33 | 
--------------------------------------------------------------------------------
/docs/overview.rst:
--------------------------------------------------------------------------------
1 | Overview
2 | ========
3 | 
4 | Fairseq can be extended through user-supplied `plug-ins
5 | <https://en.wikipedia.org/wiki/Plug-in_(computing)>`_. We support five kinds of
6 | plug-ins:
7 | 
8 | - :ref:`Models` define the neural network architecture and encapsulate all of the
9 |   learnable parameters.
10 | - :ref:`Criterions` compute the loss function given the model outputs and targets.
11 | - :ref:`Tasks` store dictionaries and provide helpers for loading/iterating over
12 |   Datasets, initializing the Model/Criterion and calculating the loss.
13 | - :ref:`Optimizers` update the Model parameters based on the gradients.
14 | - :ref:`Learning Rate Schedulers` update the learning rate over the course of
15 |   training.
16 | 
17 | **Training Flow**
18 | 
19 | Given a ``model``, ``criterion``, ``task``, ``optimizer`` and ``lr_scheduler``,
20 | fairseq implements the following high-level training flow::
21 | 
22 |     for epoch in range(num_epochs):
23 |         itr = task.get_batch_iterator(task.dataset('train'))
24 |         for num_updates, batch in enumerate(itr):
25 |             task.train_step(batch, model, criterion, optimizer)
26 |             average_and_clip_gradients()
27 |             optimizer.step()
28 |             lr_scheduler.step_update(num_updates)
29 |         lr_scheduler.step(epoch)
30 | 
31 | where the default implementation for ``task.train_step`` is roughly::
32 | 
33 |     def train_step(self, batch, model, criterion, optimizer):
34 |         loss = criterion(model, batch)
35 |         optimizer.backward(loss)
36 | 
37 | **Registering new plug-ins**
38 | 
39 | New plug-ins are *registered* through a set of ``@register`` function
40 | decorators, for example::
41 | 
42 |     @register_model('my_lstm')
43 |     class MyLSTM(FairseqModel):
44 |         (...)
45 | 
46 | Once registered, new plug-ins can be used with the existing :ref:`Command-line
47 | Tools`. See the Tutorial sections for more detailed walkthroughs of how to add
48 | new plug-ins.
49 | 
50 | **Loading plug-ins from another directory**
51 | 
52 | New plug-ins can be defined in a custom module stored on the user's system. In
53 | order to import the module, and make the plug-in available to *fairseq*, the
54 | command line supports the ``--user-dir`` flag that can be used to specify a
55 | custom location for additional modules to load into *fairseq*.
56 | 57 | For example, assuming this directory tree:: 58 | 59 | /home/user/my-module/ 60 | └── __init__.py 61 | 62 | with ``__init__.py``:: 63 | 64 | from fairseq.models import register_model_architecture 65 | from fairseq.models.transformer import transformer_vaswani_wmt_en_de_big 66 | 67 | @register_model_architecture('transformer', 'my_transformer') 68 | def transformer_mmt_big(args): 69 | transformer_vaswani_wmt_en_de_big(args) 70 | 71 | it is possible to invoke the :ref:`fairseq-train` script with the new architecture with:: 72 | 73 | fairseq-train ... --user-dir /home/user/my-module -a my_transformer --task translation 74 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx<2.0 2 | sphinx-argparse 3 | -------------------------------------------------------------------------------- /docs/tasks.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | .. module:: fairseq.tasks 5 | 6 | .. _Tasks: 7 | 8 | Tasks 9 | ===== 10 | 11 | Tasks store dictionaries and provide helpers for loading/iterating over 12 | Datasets, initializing the Model/Criterion and calculating the loss. 13 | 14 | Tasks can be selected via the ``--task`` command-line argument. Once selected, a 15 | task may expose additional command-line arguments for further configuration. 16 | 17 | Example usage:: 18 | 19 | # setup the task (e.g., load dictionaries) 20 | task = fairseq.tasks.setup_task(args) 21 | 22 | # build model and criterion 23 | model = task.build_model(args) 24 | criterion = task.build_criterion(args) 25 | 26 | # load datasets 27 | task.load_dataset('train') 28 | task.load_dataset('valid') 29 | 30 | # iterate over mini-batches of data 31 | batch_itr = task.get_batch_iterator( 32 | task.dataset('train'), max_tokens=4096, 33 | ) 34 | for batch in batch_itr: 35 | # compute the loss 36 | loss, sample_size, logging_output = task.get_loss( 37 | model, criterion, batch, 38 | ) 39 | loss.backward() 40 | 41 | 42 | Translation 43 | ----------- 44 | 45 | .. autoclass:: fairseq.tasks.translation.TranslationTask 46 | 47 | .. _language modeling: 48 | 49 | Language Modeling 50 | ----------------- 51 | 52 | .. autoclass:: fairseq.tasks.language_modeling.LanguageModelingTask 53 | 54 | 55 | Adding new tasks 56 | ---------------- 57 | 58 | .. autofunction:: fairseq.tasks.register_task 59 | .. autoclass:: fairseq.tasks.FairseqTask 60 | :members: 61 | :undoc-members: 62 | -------------------------------------------------------------------------------- /examples/.gitignore: -------------------------------------------------------------------------------- 1 | !*/*.sh 2 | !*/*.md 3 | -------------------------------------------------------------------------------- /examples/backtranslation/README.md: -------------------------------------------------------------------------------- 1 | # Understanding Back-Translation at Scale (Edunov et al., 2018) 2 | 3 | This page includes pre-trained models from the paper [Understanding Back-Translation at Scale (Edunov et al., 2018)](https://arxiv.org/abs/1808.09381). 4 | 5 | ## Pre-trained models 6 | 7 | Description | Dataset | Model | Test set(s) 8 | ---|---|---|--- 9 | Transformer
([Edunov et al., 2018](https://arxiv.org/abs/1808.09381); WMT'18 winner) | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.bz2) | See NOTE in the archive 10 | 11 | ## Citation 12 | ```bibtex 13 | @inproceedings{edunov2018backtranslation, 14 | title = {Understanding Back-Translation at Scale}, 15 | author = {Edunov, Sergey and Ott, Myle and Auli, Michael and Grangier, David}, 16 | booktitle = {Conference of the Association for Computational Linguistics (ACL)}, 17 | year = 2018, 18 | } 19 | ``` 20 | -------------------------------------------------------------------------------- /examples/conv_seq2seq/README.md: -------------------------------------------------------------------------------- 1 | # Convolutional Sequence to Sequence Learning (Gehring et al., 2017) 2 | 3 | ## Pre-trained models 4 | 5 | Description | Dataset | Model | Test set(s) 6 | ---|---|---|--- 7 | Convolutional
([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2) | newstest2014:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.newstest2014.tar.bz2)
newstest2012/2013:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.ntst1213.tar.bz2) 8 | Convolutional
([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-German](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2) | newstest2014:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-de.newstest2014.tar.bz2) 9 | Convolutional
([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT17 English-German](http://statmt.org/wmt17/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2) | newstest2014:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt17.v2.en-de.newstest2014.tar.bz2)
10 | 
11 | ## Example usage
12 | 
13 | See the [translation README](../translation/README.md) for instructions on reproducing results for WMT'14 En-De and
14 | WMT'14 En-Fr using the `fconv_wmt_en_de` and `fconv_wmt_en_fr` model architectures.
15 | 
16 | ## Citation
17 | 
18 | ```bibtex
19 | @inproceedings{gehring2017convs2s,
20 |   title = {Convolutional Sequence to Sequence Learning},
21 |   author = {Gehring, Jonas and Auli, Michael and Grangier, David and Yarats, Denis and Dauphin, Yann N},
22 |   booktitle = {Proc. of ICML},
23 |   year = 2017,
24 | }
25 | ```
26 | 
--------------------------------------------------------------------------------
/examples/language_model/README.md:
--------------------------------------------------------------------------------
1 | # Neural Language Modeling
2 | 
3 | ## Pre-trained models
4 | 
5 | Description | Parameters | Dataset | Model and Test set(s)
6 | ---|---:|---|---
7 | Adaptive Inputs <br>
([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 1026M | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.bz2) 8 | Adaptive Inputs
([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 247M | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.bz2)
9 | 
10 | ## Example usage
11 | 
12 | These scripts provide an example of pre-processing data for the Language Modeling task.
13 | 
14 | ### prepare-wikitext-103.sh
15 | 
16 | Provides an example of pre-processing for the [WikiText-103 language modeling task](https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/):
17 | 
18 | Example usage:
19 | 
20 | Prepare data:
21 | ```
22 | $ cd examples/language_model/
23 | $ bash prepare-wikitext-103.sh
24 | $ cd ../..
25 | 
26 | # Binarize the dataset:
27 | $ TEXT=examples/language_model/wikitext-103
28 | 
29 | $ fairseq-preprocess --only-source \
30 |     --trainpref $TEXT/wiki.train.tokens --validpref $TEXT/wiki.valid.tokens --testpref $TEXT/wiki.test.tokens \
31 |     --destdir data-bin/wikitext-103
32 | ```
33 | 
34 | Train a transformer language model with adaptive inputs ([Baevski and Auli (2018): Adaptive Input Representations for Neural Language Modeling](transformer_lm/README.md)):
35 | ```
36 | # If it runs out of memory, try to reduce max-tokens and tokens-per-sample
37 | $ mkdir -p checkpoints/transformer_wikitext-103
38 | $ fairseq-train --task language_modeling data-bin/wikitext-103 \
39 |     --save-dir checkpoints/transformer_wikitext-103 --arch transformer_lm_wiki103 \
40 |     --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
41 |     --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
42 |     --criterion adaptive_loss --max-tokens 3072 --update-freq 4 --tokens-per-sample 3072 --seed 1 \
43 |     --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d
44 | 
45 | # Evaluate:
46 | $ fairseq-eval-lm data-bin/wikitext-103 --path 'checkpoints/transformer_wikitext-103/checkpoint_best.pt' \
47 |     --sample-break-mode complete --max-tokens 3072 --context-window 2560 --softmax-batch 1024
48 | 
49 | ```
50 | 
51 | 
52 | Train a convolutional language model ([Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](conv_lm/README.md)):
53 | ```
54 | # If it runs out of memory, try to reduce max-tokens and tokens-per-sample
55 | $ mkdir -p checkpoints/fconv_wikitext-103
56 | $ fairseq-train --task language_modeling data-bin/wikitext-103 \
57 |     --save-dir checkpoints/fconv_wikitext-103 \
58 |     --max-epoch 35 --arch fconv_lm_dauphin_wikitext103 --optimizer nag \
59 |     --lr 1.0 --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \
60 |     --clip-norm 0.1 --dropout 0.2 --weight-decay 5e-06 --criterion adaptive_loss \
61 |     --adaptive-softmax-cutoff 10000,20000,200000 --max-tokens 1024 --tokens-per-sample 1024 \
62 |     --ddp-backend=no_c10d
63 | 
64 | # Evaluate:
65 | $ fairseq-eval-lm data-bin/wikitext-103 --path 'checkpoints/fconv_wikitext-103/checkpoint_best.pt'
66 | 
67 | ```
68 | 
--------------------------------------------------------------------------------
/examples/language_model/conv_lm/README.md:
--------------------------------------------------------------------------------
1 | # Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)
2 | 
3 | ## Example usage
4 | 
5 | See the [language modeling README](../README.md) for instructions on reproducing results for WikiText-103
6 | using the `fconv_lm_dauphin_wikitext103` model architecture.
7 | 
8 | ## Citation
9 | 
10 | ```bibtex
11 | @inproceedings{dauphin2017language,
12 |   title={Language Modeling with Gated Convolutional Networks},
13 |   author={Dauphin, Yann N and Fan, Angela and Auli, Michael and Grangier, David},
14 |   booktitle={Proceedings of the 34th International Conference on Machine Learning-Volume 70},
15 |   pages={933--941},
16 |   year={2017},
17 |   organization={JMLR}
18 | }
19 | ```
20 | 
--------------------------------------------------------------------------------
/examples/language_model/prepare-wikitext-103.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh
3 | 
4 | URLS=(
5 |     "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip"
6 | )
7 | FILES=(
8 |     "wikitext-103-v1.zip"
9 | )
10 | 
11 | for ((i=0;i<${#URLS[@]};++i)); do
12 |     file=${FILES[i]}
13 |     if [ -f $file ]; then
14 |         echo "$file already exists, skipping download"
15 |     else
16 |         url=${URLS[i]}
17 |         wget "$url"
18 |         if [ -f $file ]; then
19 |             echo "$url successfully downloaded."
20 |         else
21 |             echo "$url not successfully downloaded."
22 |             exit -1
23 |         fi
24 |         if [ ${file: -4} == ".tgz" ]; then
25 |             tar zxvf $file
26 |         elif [ ${file: -4} == ".tar" ]; then
27 |             tar xvf $file
28 |         elif [ ${file: -4} == ".zip" ]; then
29 |             unzip $file
30 |         fi
31 |     fi
32 | done
33 | cd ..
--------------------------------------------------------------------------------
/examples/language_model/transformer_lm/README.md:
--------------------------------------------------------------------------------
1 | # Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)
2 | 
3 | ## Pre-trained models
4 | 
5 | Description | Parameters | Dataset | Model and Test set(s)
6 | ---|---:|---|---
7 | Adaptive Inputs <br>
([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 1026M | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.bz2) 8 | Adaptive Inputs
([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 247M | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.bz2) 9 | 10 | ## Example usage 11 | 12 | See the [language modeling README](../language_model/README.md) for instructions on reproducing results for WikiText-103 13 | using the `transformer_lm_wiki103` model architecture. 14 | 15 | ## Citation 16 | 17 | ```bibtex 18 | @inproceedings{ 19 | baevski2018adaptive, 20 | title={Adaptive Input Representations for Neural Language Modeling}, 21 | author={Alexei Baevski and Michael Auli}, 22 | booktitle={International Conference on Learning Representations}, 23 | year={2019}, 24 | url={https://openreview.net/forum?id=ByxZX20qFQ}, 25 | } 26 | ``` 27 | -------------------------------------------------------------------------------- /examples/scaling_nmt/README.md: -------------------------------------------------------------------------------- 1 | # Scaling Neural Machine Translation (Ott et al., 2018) 2 | 3 | This page includes instructions for reproducing results from the paper [Scaling Neural Machine Translation (Ott et al., 2018)](https://arxiv.org/abs/1806.00187). 4 | 5 | ## Pre-trained models 6 | 7 | Description | Dataset | Model | Test set(s) 8 | ---|---|---|--- 9 | Transformer
([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2) | newstest2014 (shared vocab):
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2) 10 | Transformer
([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2) | newstest2014 (shared vocab):
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2) 11 | 12 | ## Training a new model on WMT'16 En-De 13 | 14 | Please first download the [preprocessed WMT'16 En-De data provided by Google](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8). 15 | Then: 16 | 17 | 1. Extract the WMT'16 En-De data: 18 | ``` 19 | $ TEXT=wmt16_en_de_bpe32k 20 | $ mkdir $TEXT 21 | $ tar -xzvf wmt16_en_de.tar.gz -C $TEXT 22 | ``` 23 | 24 | 2. Preprocess the dataset with a joined dictionary: 25 | ``` 26 | $ fairseq-preprocess --source-lang en --target-lang de \ 27 | --trainpref $TEXT/train.tok.clean.bpe.32000 \ 28 | --validpref $TEXT/newstest2013.tok.bpe.32000 \ 29 | --testpref $TEXT/newstest2014.tok.bpe.32000 \ 30 | --destdir data-bin/wmt16_en_de_bpe32k \ 31 | --nwordssrc 32768 --nwordstgt 32768 \ 32 | --joined-dictionary 33 | ``` 34 | 35 | 3. Train a model: 36 | ``` 37 | $ fairseq-train data-bin/wmt16_en_de_bpe32k \ 38 | --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \ 39 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ 40 | --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \ 41 | --lr 0.0005 --min-lr 1e-09 \ 42 | --dropout 0.3 --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ 43 | --max-tokens 3584 \ 44 | --fp16 45 | ``` 46 | 47 | Note that the `--fp16` flag requires you have CUDA 9.1 or greater and a Volta GPU. 48 | 49 | If you want to train the above model with big batches (assuming your machine has 8 GPUs): 50 | - add `--update-freq 16` to simulate training on 8*16=128 GPUs 51 | - increase the learning rate; 0.001 works well for big batches 52 | 53 | ## Citation 54 | 55 | ```bibtex 56 | @inproceedings{ott2018scaling, 57 | title = {Scaling Neural Machine Translation}, 58 | author = {Ott, Myle and Edunov, Sergey and Grangier, David and Auli, Michael}, 59 | booktitle = {Proceedings of the Third Conference on Machine Translation (WMT)}, 60 | year = 2018, 61 | } 62 | ``` 63 | -------------------------------------------------------------------------------- /examples/stories/README.md: -------------------------------------------------------------------------------- 1 | # Hierarchical Neural Story Generation (Fan et al., 2018) 2 | 3 | The following commands provide an example of pre-processing data, training a model, and generating text for story generation with the WritingPrompts dataset. 4 | 5 | ## Pre-trained models 6 | 7 | Description | Dataset | Model | Test set(s) 8 | ---|---|---|--- 9 | Stories with Convolutional Model
([Fan et al., 2018](https://arxiv.org/abs/1805.04833)) | [WritingPrompts](https://arxiv.org/abs/1805.04833) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.bz2) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/stories_test.tar.bz2) 10 | 11 | 12 | ## Dataset 13 | 14 | The dataset can be downloaded like this: 15 | 16 | ``` 17 | cd examples/stories 18 | curl https://dl.fbaipublicfiles.com/fairseq/data/writingPrompts.tar.gz | tar xvzf - 19 | ``` 20 | 21 | and contains a train, test, and valid split. The dataset is described here: https://arxiv.org/abs/1805.04833. We model only the first 1000 words of each story, including one newLine token. 22 | 23 | 24 | ## Example usage 25 | 26 | ``` 27 | # Preprocess the dataset: 28 | # Note that the dataset release is the full data, but the paper models the first 1000 words of each story 29 | # Here is some example code that can trim the dataset to the first 1000 words of each story 30 | $ python 31 | $ data = ["train", "test", "valid"] 32 | $ for name in data: 33 | $ with open(name + ".wp_target") as f: 34 | $ stories = f.readlines() 35 | $ stories = [" ".join(i.split()[0:1000]) for i in stories] 36 | $ with open(name + ".wp_target", "w") as o: 37 | $ for line in stories: 38 | $ o.write(line.strip() + "\n") 39 | 40 | # Binarize the dataset: 41 | $ export TEXT=examples/stories/writingPrompts 42 | $ fairseq-preprocess --source-lang wp_source --target-lang wp_target \ 43 | --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \ 44 | --destdir data-bin/writingPrompts --padding-factor 1 --thresholdtgt 10 --thresholdsrc 10 45 | 46 | # Train the model: 47 | $ fairseq-train data-bin/writingPrompts -a fconv_self_att_wp --lr 0.25 --clip-norm 0.1 --max-tokens 1500 --lr-scheduler reduce_lr_on_plateau --decoder-attention True --encoder-attention False --criterion label_smoothed_cross_entropy --weight-decay .0000001 --label-smoothing 0 --source-lang wp_source --target-lang wp_target --gated-attention True --self-attention True --project-input True --pretrained False 48 | 49 | # Train a fusion model: 50 | # add the arguments: --pretrained True --pretrained-checkpoint path/to/checkpoint 51 | 52 | # Generate: 53 | # Note: to load the pretrained model at generation time, you need to pass in a model-override argument to communicate to the fusion model at generation time where you have placed the pretrained checkpoint. By default, it will load the exact path of the fusion model's pretrained model from training time. You should use model-override if you have moved the pretrained model (or are using our provided models). If you are generating from a non-fusion model, the model-override argument is not necessary. 
54 | 55 | $ fairseq-generate data-bin/writingPrompts --path /path/to/trained/model/checkpoint_best.pt --batch-size 32 --beam 1 --sampling --sampling-topk 10 --sampling-temperature 0.8 --nbest 1 --model-overrides "{'pretrained_checkpoint':'/path/to/pretrained/model/checkpoint'}" 56 | ``` 57 | 58 | ## Citation 59 | ```bibtex 60 | @inproceedings{fan2018hierarchical, 61 | title = {Hierarchical Neural Story Generation}, 62 | author = {Fan, Angela and Lewis, Mike and Dauphin, Yann}, 63 | booktitle = {Conference of the Association for Computational Linguistics (ACL)}, 64 | year = 2018, 65 | } 66 | ``` 67 | -------------------------------------------------------------------------------- /examples/translation/prepare-iwslt14.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh 4 | 5 | echo 'Cloning Moses github repository (for tokenization scripts)...' 6 | git clone https://github.com/moses-smt/mosesdecoder.git 7 | 8 | echo 'Cloning Subword NMT repository (for BPE pre-processing)...' 9 | git clone https://github.com/rsennrich/subword-nmt.git 10 | 11 | SCRIPTS=mosesdecoder/scripts 12 | TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl 13 | LC=$SCRIPTS/tokenizer/lowercase.perl 14 | CLEAN=$SCRIPTS/training/clean-corpus-n.perl 15 | BPEROOT=subword-nmt 16 | BPE_TOKENS=10000 17 | 18 | URL="https://wit3.fbk.eu/archive/2014-01/texts/de/en/de-en.tgz" 19 | GZ=de-en.tgz 20 | 21 | if [ ! -d "$SCRIPTS" ]; then 22 | echo "Please set SCRIPTS variable correctly to point to Moses scripts." 23 | exit 24 | fi 25 | 26 | src=de 27 | tgt=en 28 | lang=de-en 29 | prep=iwslt14.tokenized.de-en 30 | tmp=$prep/tmp 31 | orig=orig 32 | 33 | mkdir -p $orig $tmp $prep 34 | 35 | echo "Downloading data from ${URL}..." 36 | cd $orig 37 | wget "$URL" 38 | 39 | if [ -f $GZ ]; then 40 | echo "Data successfully downloaded." 41 | else 42 | echo "Data not successfully downloaded." 43 | exit 44 | fi 45 | 46 | tar zxvf $GZ 47 | cd .. 48 | 49 | echo "pre-processing train data..." 50 | for l in $src $tgt; do 51 | f=train.tags.$lang.$l 52 | tok=train.tags.$lang.tok.$l 53 | 54 | cat $orig/$lang/$f | \ 55 | grep -v '' | \ 56 | grep -v '' | \ 57 | grep -v '' | \ 58 | sed -e 's///g' | \ 59 | sed -e 's/<\/title>//g' | \ 60 | sed -e 's/<description>//g' | \ 61 | sed -e 's/<\/description>//g' | \ 62 | perl $TOKENIZER -threads 8 -l $l > $tmp/$tok 63 | echo "" 64 | done 65 | perl $CLEAN -ratio 1.5 $tmp/train.tags.$lang.tok $src $tgt $tmp/train.tags.$lang.clean 1 175 66 | for l in $src $tgt; do 67 | perl $LC < $tmp/train.tags.$lang.clean.$l > $tmp/train.tags.$lang.$l 68 | done 69 | 70 | echo "pre-processing valid/test data..." 71 | for l in $src $tgt; do 72 | for o in `ls $orig/$lang/IWSLT14.TED*.$l.xml`; do 73 | fname=${o##*/} 74 | f=$tmp/${fname%.*} 75 | echo $o $f 76 | grep '<seg id' $o | \ 77 | sed -e 's/<seg id="[0-9]*">\s*//g' | \ 78 | sed -e 's/\s*<\/seg>\s*//g' | \ 79 | sed -e "s/\’/\'/g" | \ 80 | perl $TOKENIZER -threads 8 -l $l | \ 81 | perl $LC > $f 82 | echo "" 83 | done 84 | done 85 | 86 | 87 | echo "creating train, valid, test..." 
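# Hold out every 23rd sentence pair (~4% of the corpus) as the validation set,
# keep the rest for training, and concatenate the IWSLT dev/tst sets into test: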
88 | for l in $src $tgt; do 89 | awk '{if (NR%23 == 0) print $0; }' $tmp/train.tags.de-en.$l > $tmp/valid.$l 90 | awk '{if (NR%23 != 0) print $0; }' $tmp/train.tags.de-en.$l > $tmp/train.$l 91 | 92 | cat $tmp/IWSLT14.TED.dev2010.de-en.$l \ 93 | $tmp/IWSLT14.TEDX.dev2012.de-en.$l \ 94 | $tmp/IWSLT14.TED.tst2010.de-en.$l \ 95 | $tmp/IWSLT14.TED.tst2011.de-en.$l \ 96 | $tmp/IWSLT14.TED.tst2012.de-en.$l \ 97 | > $tmp/test.$l 98 | done 99 | 100 | TRAIN=$tmp/train.en-de 101 | BPE_CODE=$prep/code 102 | rm -f $TRAIN 103 | for l in $src $tgt; do 104 | cat $tmp/train.$l >> $TRAIN 105 | done 106 | 107 | echo "learn_bpe.py on ${TRAIN}..." 108 | python $BPEROOT/learn_bpe.py -s $BPE_TOKENS < $TRAIN > $BPE_CODE 109 | 110 | for L in $src $tgt; do 111 | for f in train.$L valid.$L test.$L; do 112 | echo "apply_bpe.py to ${f}..." 113 | python $BPEROOT/apply_bpe.py -c $BPE_CODE < $tmp/$f > $prep/$f 114 | done 115 | done 116 | -------------------------------------------------------------------------------- /examples/translation/prepare-wmt14en2fr.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh 3 | 4 | echo 'Cloning Moses github repository (for tokenization scripts)...' 5 | git clone https://github.com/moses-smt/mosesdecoder.git 6 | 7 | echo 'Cloning Subword NMT repository (for BPE pre-processing)...' 8 | git clone https://github.com/rsennrich/subword-nmt.git 9 | 10 | SCRIPTS=mosesdecoder/scripts 11 | TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl 12 | CLEAN=$SCRIPTS/training/clean-corpus-n.perl 13 | NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl 14 | REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl 15 | BPEROOT=subword-nmt 16 | BPE_TOKENS=40000 17 | 18 | URLS=( 19 | "http://statmt.org/wmt13/training-parallel-europarl-v7.tgz" 20 | "http://statmt.org/wmt13/training-parallel-commoncrawl.tgz" 21 | "http://statmt.org/wmt13/training-parallel-un.tgz" 22 | "http://statmt.org/wmt14/training-parallel-nc-v9.tgz" 23 | "http://statmt.org/wmt10/training-giga-fren.tar" 24 | "http://statmt.org/wmt14/test-full.tgz" 25 | ) 26 | FILES=( 27 | "training-parallel-europarl-v7.tgz" 28 | "training-parallel-commoncrawl.tgz" 29 | "training-parallel-un.tgz" 30 | "training-parallel-nc-v9.tgz" 31 | "training-giga-fren.tar" 32 | "test-full.tgz" 33 | ) 34 | CORPORA=( 35 | "training/europarl-v7.fr-en" 36 | "commoncrawl.fr-en" 37 | "un/undoc.2000.fr-en" 38 | "training/news-commentary-v9.fr-en" 39 | "giga-fren.release2.fixed" 40 | ) 41 | 42 | if [ ! -d "$SCRIPTS" ]; then 43 | echo "Please set SCRIPTS variable correctly to point to Moses scripts." 44 | exit 45 | fi 46 | 47 | src=en 48 | tgt=fr 49 | lang=en-fr 50 | prep=wmt14_en_fr 51 | tmp=$prep/tmp 52 | orig=orig 53 | 54 | mkdir -p $orig $tmp $prep 55 | 56 | cd $orig 57 | 58 | for ((i=0;i<${#URLS[@]};++i)); do 59 | file=${FILES[i]} 60 | if [ -f $file ]; then 61 | echo "$file already exists, skipping download" 62 | else 63 | url=${URLS[i]} 64 | wget "$url" 65 | if [ -f $file ]; then 66 | echo "$url successfully downloaded." 67 | else 68 | echo "$url not successfully downloaded." 69 | exit -1 70 | fi 71 | if [ ${file: -4} == ".tgz" ]; then 72 | tar zxvf $file 73 | elif [ ${file: -4} == ".tar" ]; then 74 | tar xvf $file 75 | fi 76 | fi 77 | done 78 | 79 | gunzip giga-fren.release2.fixed.*.gz 80 | cd .. 81 | 82 | echo "pre-processing train data..." 
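# For each language: concatenate all training corpora, normalize punctuation,
# strip non-printing characters, and tokenize with the Moses tokenizer: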
83 | for l in $src $tgt; do
84 |     rm $tmp/train.tags.$lang.tok.$l
85 |     for f in "${CORPORA[@]}"; do
86 |         cat $orig/$f.$l | \
87 |             perl $NORM_PUNC $l | \
88 |             perl $REM_NON_PRINT_CHAR | \
89 |             perl $TOKENIZER -threads 8 -a -l $l >> $tmp/train.tags.$lang.tok.$l
90 |     done
91 | done
92 | 
93 | echo "pre-processing test data..."
94 | for l in $src $tgt; do
95 |     if [ "$l" == "$src" ]; then
96 |         t="src"
97 |     else
98 |         t="ref"
99 |     fi
100 |     grep '<seg id' $orig/test-full/newstest2014-fren-$t.$l.sgm | \
101 |         sed -e 's/<seg id="[0-9]*">\s*//g' | \
102 |         sed -e 's/\s*<\/seg>\s*//g' | \
103 |         sed -e "s/\’/\'/g" | \
104 |     perl $TOKENIZER -threads 8 -a -l $l > $tmp/test.$l
105 |     echo ""
106 | done
107 | 
108 | echo "splitting train and valid..."
109 | for l in $src $tgt; do
110 |     awk '{if (NR%1333 == 0)  print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/valid.$l
111 |     awk '{if (NR%1333 != 0)  print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/train.$l
112 | done
113 | 
114 | TRAIN=$tmp/train.fr-en
115 | BPE_CODE=$prep/code
116 | rm -f $TRAIN
117 | for l in $src $tgt; do
118 |     cat $tmp/train.$l >> $TRAIN
119 | done
120 | 
121 | echo "learn_bpe.py on ${TRAIN}..."
122 | python $BPEROOT/learn_bpe.py -s $BPE_TOKENS < $TRAIN > $BPE_CODE
123 | 
124 | for L in $src $tgt; do
125 |     for f in train.$L valid.$L test.$L; do
126 |         echo "apply_bpe.py to ${f}..."
127 |         python $BPEROOT/apply_bpe.py -c $BPE_CODE < $tmp/$f > $tmp/bpe.$f
128 |     done
129 | done
130 | 
131 | perl $CLEAN -ratio 1.5 $tmp/bpe.train $src $tgt $prep/train 1 250
132 | perl $CLEAN -ratio 1.5 $tmp/bpe.valid $src $tgt $prep/valid 1 250
133 | 
134 | for L in $src $tgt; do
135 |     cp $tmp/bpe.test.$L $prep/test.$L
136 | done
137 | 
--------------------------------------------------------------------------------
/examples/translation_moe/README.md:
--------------------------------------------------------------------------------
1 | # Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)
2 | 
3 | This page includes instructions for reproducing results from the paper [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](https://arxiv.org/abs/1902.07816).
4 | 
5 | ## Download data
6 | 
7 | First, follow the [instructions to download and preprocess the WMT'17 En-De dataset](../translation#prepare-wmt14en2desh).
8 | Make sure to learn a joint vocabulary by passing the `--joined-dictionary` option to `fairseq-preprocess`.
9 | 
10 | ## Train a model
11 | 
12 | Then we can train a mixture of experts model using the `translation_moe` task.
13 | Use the `--method` flag to choose the MoE variant; we support hard mixtures with a learned or uniform prior (`--method hMoElp` and `hMoEup`, respectively) and soft mixtures (`--method sMoElp` and `sMoEup`).
14 | The model is trained with online responsibility assignment and shared parameterization.
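To make "online responsibility assignment" concrete: in the hard-mixture variants, each training batch is scored under every expert, responsibility is assigned to the expert with the lowest loss, and only that expert's loss is minimized. The sketch below illustrates this E-step/M-step pattern only; it is not fairseq's `translation_moe` implementation, and the `experts` interface (a list of K models sharing most parameters, each returning a scalar loss on the batch) is a hypothetical stand-in. The paper assigns responsibility per sentence; this sketch does it per batch for brevity.

```python
import torch

def hard_moe_train_step(experts, batch, optimizer):
    # E-step: score the batch under every expert without building graphs.
    with torch.no_grad():
        losses = torch.stack([expert(batch) for expert in experts])
    winner = int(losses.argmin())  # online responsibility assignment

    # M-step: backpropagate only the responsible expert's loss; shared
    # parameters receive gradients through it as well, which is the
    # "shared parameterization" mentioned above.
    optimizer.zero_grad()
    loss = experts[winner](batch)
    loss.backward()
    optimizer.step()
    return winner, loss.item()
```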
15 | 16 | The following command will train a `hMoElp` model with `3` experts: 17 | ``` 18 | $ CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/wmt17_en_de \ 19 | --max-update 100000 \ 20 | --task translation_moe \ 21 | --method hMoElp --mean-pool-gating-network \ 22 | --num-experts 3 \ 23 | --arch transformer_vaswani_wmt_en_de --share-all-embeddings \ 24 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ 25 | --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \ 26 | --lr 0.0007 --min-lr 1e-09 \ 27 | --dropout 0.1 --weight-decay 0.0 --criterion cross_entropy \ 28 | --max-tokens 3584 \ 29 | --update-freq 8 30 | ``` 31 | 32 | **Note**: the above command assumes 1 GPU, but accumulates gradients from 8 fwd/bwd passes to simulate training on 8 GPUs. 33 | You can accelerate training on up to 8 GPUs by adjusting the `CUDA_VISIBLE_DEVICES` and `--update-freq` options accordingly. 34 | 35 | ## Translate 36 | 37 | Once a model is trained, we can generate translations from different experts using the `--gen-expert` option. 38 | For example, to generate from expert 0: 39 | ``` 40 | $ fairseq-generate data-bin/wmt17_en_de \ 41 | --path checkpoints/checkpoint_best.pt \ 42 | --beam 1 --remove-bpe \ 43 | --task translation_moe \ 44 | --method hMoElp --mean-pool-gating-network \ 45 | --num-experts 3 \ 46 | --gen-expert 0 47 | ``` 48 | 49 | ## Evaluate 50 | 51 | First download a tokenized version of the WMT'14 En-De test set with multiple references: 52 | ``` 53 | $ wget dl.fbaipublicfiles.com/fairseq/data/wmt14-en-de.extra_refs.tok 54 | ``` 55 | 56 | Next apply BPE on the fly and run generation for each expert: 57 | ``` 58 | $ BPEROOT=examples/translation/subword-nmt/ 59 | $ BPE_CODE=examples/translation/wmt17_en_de/code 60 | $ for EXPERT in $(seq 0 2); do \ 61 | cat wmt14-en-de.extra_refs.tok | grep ^S | cut -f 2 | \ 62 | python $BPEROOT/apply_bpe.py -c $BPE_CODE | \ 63 | fairseq-interactive data-bin/wmt17_en_de \ 64 | --path checkpoints/checkpoint_best.pt \ 65 | --beam 1 --remove-bpe \ 66 | --buffer-size 500 --max-tokens 6000 \ 67 | --task translation_moe \ 68 | --method hMoElp --mean-pool-gating-network \ 69 | --num-experts 3 \ 70 | --gen-expert $EXPERT ; \ 71 | done > wmt14-en-de.extra_refs.tok.gen.3experts 72 | ``` 73 | 74 | Finally, use `scripts/score_moe.py` to compute pairwise BLEU and average oracle BLEU: 75 | ``` 76 | $ python scripts/score_moe.py --sys wmt14-en-de.extra_refs.tok.gen.3experts --ref wmt14-en-de.extra_refs.tok 77 | pairwise BLEU: 48.26 78 | avg oracle BLEU: 49.50 79 | #refs covered: 2.11 80 | ``` 81 | This matches row 3 from Table 7 in the paper.
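For intuition: pairwise BLEU scores each expert's hypotheses against every other expert's outputs (lower means more diverse experts), while oracle BLEU keeps, for each sentence, the expert hypothesis closest to the reference. Below is a rough sketch of both metrics, assuming `sacrebleu` is installed; it is an illustration with toy data, not `scripts/score_moe.py` itself:
```
import itertools
import sacrebleu

# toy hypotheses from 3 experts, plus one shared reference set
hyps = {
    0: ["the cat sat on the mat", "he went home"],
    1: ["a cat sat on a mat", "he walked home"],
    2: ["the cat is on the mat", "he went to his house"],
}
refs = ["the cat sat on the mat", "he went home"]

# pairwise BLEU: average over ordered expert pairs, treating one
# expert's output as the "reference" for another
pairs = list(itertools.permutations(hyps, 2))
pairwise = sum(
    sacrebleu.corpus_bleu(hyps[a], [hyps[b]]).score for a, b in pairs
) / len(pairs)

# oracle BLEU: per sentence, keep the expert hypothesis with the best
# sentence-level BLEU against the true reference, then score the corpus
oracle_hyps = [
    max((hyps[e][i] for e in hyps),
        key=lambda h: sacrebleu.sentence_bleu(h, [refs[i]]).score)
    for i in range(len(refs))
]
oracle = sacrebleu.corpus_bleu(oracle_hyps, [refs]).score
print("pairwise BLEU: %.2f, avg oracle BLEU: %.2f" % (pairwise, oracle))
```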
82 | 83 | ## Citation 84 | 85 | ```bibtex 86 | @article{shen2019mixture, 87 | title = {Mixture Models for Diverse Machine Translation: Tricks of the Trade}, 88 | author = {Tianxiao Shen and Myle Ott and Michael Auli and Marc'Aurelio Ranzato}, 89 | journal = {arXiv preprint arXiv:1902.07816}, 90 | year = 2019, 91 | } 92 | ``` 93 | -------------------------------------------------------------------------------- /fairseq.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libeineu/SDT-Training/d33a836c1b3258748ec11c4d64998e0dcc9792df/fairseq.gif -------------------------------------------------------------------------------- /fairseq/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libeineu/SDT-Training/d33a836c1b3258748ec11c4d64998e0dcc9792df/fairseq/.DS_Store -------------------------------------------------------------------------------- /fairseq/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | __all__ = ['pdb'] 9 | __version__ = '0.6.2' 10 | 11 | import fairseq.criterions 12 | import fairseq.models 13 | import fairseq.modules 14 | import fairseq.optim 15 | import fairseq.optim.lr_scheduler 16 | import fairseq.pdb 17 | import fairseq.tasks 18 | -------------------------------------------------------------------------------- /fairseq/binarizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
7 | 8 | from collections import Counter 9 | import os 10 | 11 | from fairseq.tokenizer import tokenize_line 12 | 13 | 14 | def safe_readline(f): 15 | pos = f.tell() 16 | while True: 17 | try: 18 | return f.readline() 19 | except UnicodeDecodeError: 20 | pos -= 1 21 | f.seek(pos) # search where this character begins 22 | 23 | 24 | class Binarizer: 25 | 26 | @staticmethod 27 | def binarize(filename, dict, consumer, tokenize=tokenize_line, append_eos=True, reverse_order=False, 28 | offset=0, end=-1): 29 | nseq, ntok = 0, 0 30 | replaced = Counter() 31 | 32 | def replaced_consumer(word, idx): 33 | if idx == dict.unk_index and word != dict.unk_word: 34 | replaced.update([word]) 35 | 36 | with open(filename, 'r', encoding='utf-8') as f: 37 | f.seek(offset) 38 | # next(f) breaks f.tell(), hence readline() must be used 39 | line = safe_readline(f) 40 | while line: 41 | if end > 0 and f.tell() > end: 42 | break 43 | ids = dict.encode_line( 44 | line=line, 45 | line_tokenizer=tokenize, 46 | add_if_not_exist=False, 47 | consumer=replaced_consumer, 48 | append_eos=append_eos, 49 | reverse_order=reverse_order, 50 | ) 51 | nseq += 1 52 | ntok += len(ids) 53 | consumer(ids) 54 | line = f.readline() 55 | return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': replaced} 56 | 57 | @staticmethod 58 | def find_offsets(filename, num_chunks): 59 | with open(filename, 'r', encoding='utf-8') as f: 60 | size = os.fstat(f.fileno()).st_size 61 | chunk_size = size // num_chunks 62 | offsets = [0 for _ in range(num_chunks + 1)] 63 | for i in range(1, num_chunks): 64 | f.seek(chunk_size * i) 65 | safe_readline(f) 66 | offsets[i] = f.tell() 67 | return offsets 68 | -------------------------------------------------------------------------------- /fairseq/clib/libbleu/libbleu.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2017-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include <map> 10 | #include <array> 11 | #include <cstring> 12 | #include <cstdio> 13 | 14 | typedef struct 15 | { 16 | size_t reflen; 17 | size_t predlen; 18 | size_t match1; 19 | size_t count1; 20 | size_t match2; 21 | size_t count2; 22 | size_t match3; 23 | size_t count3; 24 | size_t match4; 25 | size_t count4; 26 | } bleu_stat; 27 | 28 | // left trim (remove pad) 29 | void bleu_ltrim(size_t* len, int** sent, int pad) { 30 | size_t start = 0; 31 | while(start < *len) { 32 | if (*(*sent + start) != pad) { break; } 33 | start++; 34 | } 35 | *sent += start; 36 | *len -= start; 37 | } 38 | 39 | // right trim remove (eos) 40 | void bleu_rtrim(size_t* len, int** sent, int pad, int eos) { 41 | size_t end = *len - 1; 42 | while (end > 0) { 43 | if (*(*sent + end) != eos && *(*sent + end) != pad) { break; } 44 | end--; 45 | } 46 | *len = end + 1; 47 | } 48 | 49 | // left and right trim 50 | void bleu_trim(size_t* len, int** sent, int pad, int eos) { 51 | bleu_ltrim(len, sent, pad); 52 | bleu_rtrim(len, sent, pad, eos); 53 | } 54 | 55 | size_t bleu_hash(int len, int* data) { 56 | size_t h = 14695981039346656037ul; 57 | size_t prime = 0x100000001b3; 58 | char* b = (char*) data; 59 | size_t blen = sizeof(int) * len; 60 | 61 | while (blen-- > 0) { 62 | h ^= *b++; 63 | h *= prime; 64 | } 65 | 66 | return h; 67 | } 68 | 69 | void bleu_addngram( 70 | size_t *ntotal, size_t *nmatch, size_t n, 71 | size_t reflen, int* ref, size_t predlen, int* pred) { 72 | 73 | if (predlen < n) { return; } 74 | 75 | predlen = predlen - n + 1; 76 | (*ntotal) += predlen; 77 | 78 | if (reflen < n) { return; } 79 | 80 | reflen = reflen - n + 1; 81 | 82 | std::map<size_t, size_t> count; 83 | while (predlen > 0) { 84 | size_t w = bleu_hash(n, pred++); 85 | count[w]++; 86 | predlen--; 87 | } 88 | 89 | while (reflen > 0) { 90 | size_t w = bleu_hash(n, ref++); 91 | if (count[w] > 0) { 92 | (*nmatch)++; 93 | count[w] -=1; 94 | } 95 | reflen--; 96 | } 97 | } 98 | 99 | extern "C" { 100 | 101 | void bleu_zero_init(bleu_stat* stat) { 102 | std::memset(stat, 0, sizeof(bleu_stat)); 103 | } 104 | 105 | void bleu_one_init(bleu_stat* stat) { 106 | bleu_zero_init(stat); 107 | stat->count1 = 0; 108 | stat->count2 = 1; 109 | stat->count3 = 1; 110 | stat->count4 = 1; 111 | stat->match1 = 0; 112 | stat->match2 = 1; 113 | stat->match3 = 1; 114 | stat->match4 = 1; 115 | } 116 | 117 | void bleu_add( 118 | bleu_stat* stat, 119 | size_t reflen, int* ref, size_t predlen, int* pred, int pad, int eos) { 120 | 121 | bleu_trim(&reflen, &ref, pad, eos); 122 | bleu_trim(&predlen, &pred, pad, eos); 123 | stat->reflen += reflen; 124 | stat->predlen += predlen; 125 | 126 | bleu_addngram(&stat->count1, &stat->match1, 1, reflen, ref, predlen, pred); 127 | bleu_addngram(&stat->count2, &stat->match2, 2, reflen, ref, predlen, pred); 128 | bleu_addngram(&stat->count3, &stat->match3, 3, reflen, ref, predlen, pred); 129 | bleu_addngram(&stat->count4, &stat->match4, 4, reflen, ref, predlen, pred); 130 | } 131 | 132 | } 133 | -------------------------------------------------------------------------------- /fairseq/clib/libbleu/module.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2017-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include <Python.h> 10 | 11 | 12 | static PyMethodDef method_def[] = { 13 | {NULL, NULL, 0, NULL} 14 | }; 15 | 16 | static struct PyModuleDef module_def = { 17 | PyModuleDef_HEAD_INIT, 18 | "libbleu", /* name of module */ 19 | NULL, /* module documentation, may be NULL */ 20 | -1, /* size of per-interpreter state of the module, 21 | or -1 if the module keeps state in global variables. */ 22 | method_def 23 | }; 24 | 25 | 26 | #if PY_MAJOR_VERSION == 2 27 | PyMODINIT_FUNC init_libbleu() 28 | #else 29 | PyMODINIT_FUNC PyInit_libbleu() 30 | #endif 31 | { 32 | PyObject *m = PyModule_Create(&module_def); 33 | if (!m) { 34 | return NULL; 35 | } 36 | return m; 37 | } 38 | -------------------------------------------------------------------------------- /fairseq/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import importlib 9 | import os 10 | 11 | from .fairseq_criterion import FairseqCriterion 12 | 13 | 14 | CRITERION_REGISTRY = {} 15 | CRITERION_CLASS_NAMES = set() 16 | 17 | 18 | def build_criterion(args, task): 19 | return CRITERION_REGISTRY[args.criterion].build_criterion(args, task) 20 | 21 | 22 | def register_criterion(name): 23 | """Decorator to register a new criterion.""" 24 | 25 | def register_criterion_cls(cls): 26 | if name in CRITERION_REGISTRY: 27 | raise ValueError('Cannot register duplicate criterion ({})'.format(name)) 28 | if not issubclass(cls, FairseqCriterion): 29 | raise ValueError('Criterion ({}: {}) must extend FairseqCriterion'.format(name, cls.__name__)) 30 | if cls.__name__ in CRITERION_CLASS_NAMES: 31 | # We use the criterion class name as a unique identifier in 32 | # checkpoints, so all criterions must have unique class names. 33 | raise ValueError('Cannot register criterion with duplicate class name ({})'.format(cls.__name__)) 34 | CRITERION_REGISTRY[name] = cls 35 | CRITERION_CLASS_NAMES.add(cls.__name__) 36 | return cls 37 | 38 | return register_criterion_cls 39 | 40 | 41 | # automatically import any Python files in the criterions/ directory 42 | for file in os.listdir(os.path.dirname(__file__)): 43 | if file.endswith('.py') and not file.startswith('_'): 44 | module = file[:file.find('.py')] 45 | importlib.import_module('fairseq.criterions.' + module) 46 | -------------------------------------------------------------------------------- /fairseq/criterions/adaptive_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | 9 | import math 10 | import torch.nn.functional as F 11 | 12 | from fairseq import utils 13 | from . 
import FairseqCriterion, register_criterion 14 | 15 | 16 | @register_criterion('adaptive_loss') 17 | class AdaptiveLoss(FairseqCriterion): 18 | """This is an implementation of the loss function accompanying the adaptive softmax approximation for 19 | graphical processing units (GPU), described in the paper "Efficient softmax approximation for GPUs" 20 | (http://arxiv.org/abs/1609.04309).""" 21 | 22 | def __init__(self, args, task): 23 | super().__init__(args, task) 24 | 25 | if args.ddp_backend == 'c10d': 26 | raise Exception( 27 | 'AdaptiveLoss is not compatible with the c10d ' 28 | 'version of DistributedDataParallel. Please use ' 29 | '`--ddp-backend=no_c10d` instead.' 30 | ) 31 | 32 | def forward(self, model, sample, reduce=True): 33 | """Compute the loss for the given sample. 34 | 35 | Returns a tuple with three elements: 36 | 1) the loss 37 | 2) the sample size, which is used as the denominator for the gradient 38 | 3) logging outputs to display while training 39 | """ 40 | 41 | assert hasattr(model.decoder, 'adaptive_softmax') and model.decoder.adaptive_softmax is not None 42 | adaptive_softmax = model.decoder.adaptive_softmax 43 | 44 | net_output = model(**sample['net_input']) 45 | orig_target = model.get_targets(sample, net_output) 46 | 47 | nsentences = orig_target.size(0) 48 | orig_target = orig_target.view(-1) 49 | 50 | bsz = orig_target.size(0) 51 | 52 | logits, target = adaptive_softmax(net_output[0], orig_target) 53 | assert len(target) == len(logits) 54 | 55 | loss = net_output[0].new(1 if reduce else bsz).zero_() 56 | 57 | for i in range(len(target)): 58 | if target[i] is not None: 59 | assert (target[i].min() >= 0 and target[i].max() <= logits[i].size(1)) 60 | loss += F.cross_entropy(logits[i], target[i], size_average=False, ignore_index=self.padding_idx, 61 | reduce=reduce) 62 | 63 | orig = utils.strip_pad(orig_target, self.padding_idx) 64 | ntokens = orig.numel() 65 | sample_size = sample['target'].size(0) if self.args.sentence_avg else ntokens 66 | logging_output = { 67 | 'loss': utils.item(loss.data) if reduce else loss.data, 68 | 'ntokens': ntokens, 69 | 'nsentences': nsentences, 70 | 'sample_size': sample_size, 71 | } 72 | return loss, sample_size, logging_output 73 | 74 | @staticmethod 75 | def aggregate_logging_outputs(logging_outputs): 76 | """Aggregate logging outputs from data parallel training.""" 77 | loss_sum = sum(log.get('loss', 0) for log in logging_outputs) 78 | ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) 79 | nsentences = sum(log.get('nsentences', 0) for log in logging_outputs) 80 | sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) 81 | agg_output = { 82 | 'loss': loss_sum / sample_size / math.log(2), 83 | 'nll_loss': loss_sum / sample_size / math.log(2), 84 | 'ntokens': ntokens, 85 | 'nsentences': nsentences, 86 | 'sample_size': sample_size, 87 | } 88 | if sample_size != ntokens: 89 | agg_output['nll_loss'] = loss_sum / ntokens / math.log(2) 90 | return agg_output 91 | -------------------------------------------------------------------------------- /fairseq/criterions/composite_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
7 | 8 | from torch import nn 9 | 10 | from fairseq import utils 11 | from . import FairseqCriterion, register_criterion 12 | 13 | 14 | @register_criterion('composite_loss') 15 | class CompositeLoss(FairseqCriterion): 16 | """This is a composite loss that, given a list of model outputs and a list of targets, 17 | computes an average of losses for each output-target pair""" 18 | 19 | @staticmethod 20 | def add_args(parser): 21 | """Add criterion-specific arguments to the parser.""" 22 | # fmt: off 23 | parser.add_argument('--underlying-criterion', type=str, metavar='VAL', required=True, 24 | help='underlying criterion to use for the composite loss') 25 | # fmt: on 26 | 27 | @staticmethod 28 | def build_underlying_criterion(args, task): 29 | saved_criterion = args.criterion 30 | args.criterion = args.underlying_criterion 31 | assert saved_criterion != args.underlying_criterion 32 | underlying_criterion = task.build_criterion(args) 33 | args.criterion = saved_criterion 34 | return underlying_criterion 35 | 36 | @classmethod 37 | def build_criterion(cls, args, task): 38 | underlying_criterion = CompositeLoss.build_underlying_criterion(args, task) 39 | 40 | class FakeModel(nn.Module): 41 | 42 | def __init__(self, model, net_out, target): 43 | super().__init__() 44 | self.model = model 45 | self.net_out = net_out 46 | self.target = target 47 | 48 | def forward(self, **unused): 49 | return self.net_out 50 | 51 | def get_normalized_probs(self, net_output, log_probs, sample=None): 52 | return self.model.get_normalized_probs(net_output, log_probs, sample=sample) 53 | 54 | def get_targets(self, *unused): 55 | return self.target 56 | 57 | @property 58 | def decoder(self): 59 | return self.model.decoder 60 | 61 | class _CompositeLoss(FairseqCriterion): 62 | 63 | def __init__(self, args, task, underlying_criterion): 64 | super().__init__(args, task) 65 | self.underlying_criterion = underlying_criterion 66 | 67 | def forward(self, model, sample, reduce=True): 68 | net_outputs = model(**sample['net_input']) 69 | targets = sample['target'] 70 | 71 | bsz = targets[0].size(0) 72 | loss = net_outputs[0][0].new(1 if reduce else bsz).float().zero_() 73 | 74 | sample_size = 0 75 | logging_output = {} 76 | for o, t in zip(net_outputs[0], targets): 77 | m = FakeModel(model, (o, net_outputs[1]), t) 78 | sample['target'] = t 79 | l, ss, logging_output = self.underlying_criterion(m, sample, reduce) 80 | loss += l 81 | sample_size += ss 82 | 83 | loss.div_(len(targets)) 84 | sample_size /= len(targets) 85 | 86 | logging_output['loss'] = utils.item(loss.data) if reduce else loss.data 87 | return loss, sample_size, logging_output 88 | 89 | @staticmethod 90 | def aggregate_logging_outputs(logging_outputs): 91 | return underlying_criterion.__class__.aggregate_logging_outputs(logging_outputs) 92 | 93 | return _CompositeLoss(args, task, underlying_criterion) 94 | -------------------------------------------------------------------------------- /fairseq/criterions/cross_entropy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import math 9 | import torch.nn.functional as F 10 | 11 | from fairseq import utils 12 | 13 | from . 
import FairseqCriterion, register_criterion 14 | 15 | 16 | @register_criterion('cross_entropy') 17 | class CrossEntropyCriterion(FairseqCriterion): 18 | 19 | def __init__(self, args, task): 20 | super().__init__(args, task) 21 | 22 | def forward(self, model, sample, reduce=True): 23 | """Compute the loss for the given sample. 24 | 25 | Returns a tuple with three elements: 26 | 1) the loss 27 | 2) the sample size, which is used as the denominator for the gradient 28 | 3) logging outputs to display while training 29 | """ 30 | net_output = model(**sample['net_input']) 31 | loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce) 32 | sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] 33 | logging_output = { 34 | 'loss': utils.item(loss.data) if reduce else loss.data, 35 | 'ntokens': sample['ntokens'], 36 | 'nsentences': sample['target'].size(0), 37 | 'sample_size': sample_size, 38 | } 39 | return loss, sample_size, logging_output 40 | 41 | def compute_loss(self, model, net_output, sample, reduce=True): 42 | lprobs = model.get_normalized_probs(net_output, log_probs=True) 43 | lprobs = lprobs.view(-1, lprobs.size(-1)) 44 | target = model.get_targets(sample, net_output).view(-1) 45 | loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx, 46 | reduce=reduce) 47 | return loss, loss 48 | 49 | @staticmethod 50 | def aggregate_logging_outputs(logging_outputs): 51 | """Aggregate logging outputs from data parallel training.""" 52 | loss_sum = sum(log.get('loss', 0) for log in logging_outputs) 53 | ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) 54 | nsentences = sum(log.get('nsentences', 0) for log in logging_outputs) 55 | sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) 56 | agg_output = { 57 | 'loss': loss_sum / sample_size / math.log(2), 58 | 'ntokens': ntokens, 59 | 'nsentences': nsentences, 60 | 'sample_size': sample_size, 61 | } 62 | if sample_size != ntokens: 63 | agg_output['nll_loss'] = loss_sum / ntokens / math.log(2) 64 | return agg_output 65 | -------------------------------------------------------------------------------- /fairseq/criterions/fairseq_criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from torch.nn.modules.loss import _Loss 9 | 10 | 11 | class FairseqCriterion(_Loss): 12 | 13 | def __init__(self, args, task): 14 | super().__init__() 15 | self.args = args 16 | self.padding_idx = task.target_dictionary.pad() 17 | 18 | @staticmethod 19 | def add_args(parser): 20 | """Add criterion-specific arguments to the parser.""" 21 | pass 22 | 23 | @classmethod 24 | def build_criterion(cls, args, task): 25 | return cls(args, task) 26 | 27 | def forward(self, model, sample, reduce=True): 28 | """Compute the loss for the given sample. 
29 | 30 | Returns a tuple with three elements: 31 | 1) the loss 32 | 2) the sample size, which is used as the denominator for the gradient 33 | 3) logging outputs to display while training 34 | """ 35 | raise NotImplementedError 36 | 37 | @staticmethod 38 | def aggregate_logging_outputs(logging_outputs): 39 | """Aggregate logging outputs from data parallel training.""" 40 | raise NotImplementedError 41 | 42 | @staticmethod 43 | def grad_denom(sample_sizes): 44 | """Compute the gradient denominator for a set of sample sizes.""" 45 | return sum(sample_sizes) 46 | -------------------------------------------------------------------------------- /fairseq/criterions/label_smoothed_cross_entropy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import math 9 | 10 | from fairseq import utils 11 | 12 | from . import FairseqCriterion, register_criterion 13 | 14 | 15 | @register_criterion('label_smoothed_cross_entropy') 16 | class LabelSmoothedCrossEntropyCriterion(FairseqCriterion): 17 | 18 | def __init__(self, args, task): 19 | super().__init__(args, task) 20 | self.eps = args.label_smoothing 21 | 22 | @staticmethod 23 | def add_args(parser): 24 | """Add criterion-specific arguments to the parser.""" 25 | # fmt: off 26 | parser.add_argument('--label-smoothing', default=0., type=float, metavar='D', 27 | help='epsilon for label smoothing, 0 means no label smoothing') 28 | # fmt: on 29 | 30 | def forward(self, model, sample, reduce=True): 31 | """Compute the loss for the given sample. 32 | 33 | Returns a tuple with three elements: 34 | 1) the loss 35 | 2) the sample size, which is used as the denominator for the gradient 36 | 3) logging outputs to display while training 37 | """ 38 | net_output = model(**sample['net_input']) 39 | loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce) 40 | sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] 41 | logging_output = { 42 | 'loss': utils.item(loss.data) if reduce else loss.data, 43 | 'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data, 44 | 'ntokens': sample['ntokens'], 45 | 'nsentences': sample['target'].size(0), 46 | 'sample_size': sample_size, 47 | } 48 | return loss, sample_size, logging_output 49 | 50 | def compute_loss(self, model, net_output, sample, reduce=True): 51 | lprobs = model.get_normalized_probs(net_output, log_probs=True) 52 | lprobs = lprobs.view(-1, lprobs.size(-1)) 53 | target = model.get_targets(sample, net_output).view(-1, 1) 54 | non_pad_mask = target.ne(self.padding_idx) 55 | nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask] 56 | smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask] 57 | if reduce: 58 | nll_loss = nll_loss.sum() 59 | smooth_loss = smooth_loss.sum() 60 | eps_i = self.eps / lprobs.size(-1) 61 | loss = (1. 
- self.eps) * nll_loss + eps_i * smooth_loss 62 | return loss, nll_loss 63 | 64 | @staticmethod 65 | def aggregate_logging_outputs(logging_outputs): 66 | """Aggregate logging outputs from data parallel training.""" 67 | ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) 68 | nsentences = sum(log.get('nsentences', 0) for log in logging_outputs) 69 | sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) 70 | return { 71 | 'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2), 72 | 'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / ntokens / math.log(2), 73 | 'ntokens': ntokens, 74 | 'nsentences': nsentences, 75 | 'sample_size': sample_size, 76 | } 77 | -------------------------------------------------------------------------------- /fairseq/criterions/regularization_label_smoothed_cross_entropy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import math 9 | 10 | from fairseq import utils 11 | 12 | from . import FairseqCriterion, register_criterion, label_smoothed_cross_entropy 13 | 14 | 15 | @register_criterion('regularization_label_smoothed_cross_entropy') 16 | class RegularizationCrossEntropyCriterion(label_smoothed_cross_entropy.LabelSmoothedCrossEntropyCriterion): 17 | 18 | def __init__(self, args, task): 19 | super().__init__(args, task) 20 | 21 | 22 | 23 | def forward(self, model, sample, reduce=True): 24 | """Compute the loss for the given sample. 25 | 26 | Returns a tuple with three elements: 27 | 1) the loss 28 | 2) the sample size, which is used as the denominator for the gradient 29 | 3) logging outputs to display while training 30 | """ 31 | net_output = model(**sample['net_input']) 32 | loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce) 33 | sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] 34 | logging_output = { 35 | 'loss': utils.item(loss.data) if reduce else loss.data, 36 | 'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data, 37 | 'ntokens': sample['ntokens'], 38 | 'nsentences': sample['target'].size(0), 39 | 'sample_size': sample_size, 40 | } 41 | return loss, sample_size, logging_output 42 | 43 | def compute_loss(self, model, net_output, sample, reduce=True): 44 | lprobs = model.get_normalized_probs(net_output, log_probs=True) 45 | lprobs = lprobs.view(-1, lprobs.size(-1)) 46 | target = model.get_targets(sample, net_output).view(-1, 1) 47 | non_pad_mask = target.ne(self.padding_idx) 48 | nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask] 49 | smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask] 50 | reg_loss = net_output[1]['reg'].type_as(nll_loss) 51 | if reduce: 52 | nll_loss = nll_loss.sum() 53 | smooth_loss = smooth_loss.sum() 54 | eps_i = self.eps / lprobs.size(-1) 55 | loss = (1. - self.eps) * nll_loss + eps_i * smooth_loss 56 | loss = loss + reg_loss 57 | return loss, nll_loss 58 | -------------------------------------------------------------------------------- /fairseq/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 
2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from .dictionary import Dictionary, TruncatedDictionary 9 | from .fairseq_dataset import FairseqDataset 10 | from .backtranslation_dataset import BacktranslationDataset 11 | from .concat_dataset import ConcatDataset 12 | from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset 13 | from .language_pair_dataset import LanguagePairDataset 14 | from .lm_context_window_dataset import LMContextWindowDataset 15 | from .monolingual_dataset import MonolingualDataset 16 | from .round_robin_zip_datasets import RoundRobinZipDatasets 17 | from .token_block_dataset import TokenBlockDataset 18 | from .transform_eos_dataset import TransformEosDataset 19 | 20 | from .iterators import ( 21 | CountingIterator, 22 | EpochBatchIterator, 23 | GroupedIterator, 24 | ShardedIterator, 25 | ) 26 | 27 | __all__ = [ 28 | 'BacktranslationDataset', 29 | 'ConcatDataset', 30 | 'CountingIterator', 31 | 'Dictionary', 32 | 'EpochBatchIterator', 33 | 'FairseqDataset', 34 | 'GroupedIterator', 35 | 'IndexedCachedDataset', 36 | 'IndexedDataset', 37 | 'IndexedRawTextDataset', 38 | 'LanguagePairDataset', 39 | 'LMContextWindowDataset', 40 | 'MonolingualDataset', 41 | 'RoundRobinZipDatasets', 42 | 'ShardedIterator', 43 | 'TokenBlockDataset', 44 | 'TransformEosDataset', 45 | ] 46 | -------------------------------------------------------------------------------- /fairseq/data/concat_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import bisect 9 | 10 | import numpy as np 11 | from . 
import FairseqDataset 12 | 13 | 14 | class ConcatDataset(FairseqDataset): 15 | 16 | @staticmethod 17 | def cumsum(sequence, sample_ratios): 18 | r, s = [], 0 19 | for e, ratio in zip(sequence, sample_ratios): 20 | l = ratio * len(e) 21 | r.append(l + s) 22 | s += l 23 | return r 24 | 25 | def __init__(self, datasets, sample_ratios=1): 26 | super(ConcatDataset, self).__init__() 27 | assert len(datasets) > 0, 'datasets should not be an empty iterable' 28 | self.datasets = list(datasets) 29 | if isinstance(sample_ratios, int): 30 | sample_ratios = [sample_ratios] * len(self.datasets) 31 | self.sample_ratios = sample_ratios 32 | self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios) 33 | self.real_sizes = [len(d) for d in self.datasets] 34 | 35 | def __len__(self): 36 | return self.cumulative_sizes[-1] 37 | 38 | def __getitem__(self, idx): 39 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 40 | if dataset_idx == 0: 41 | sample_idx = idx 42 | else: 43 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 44 | sample_idx = sample_idx % self.real_sizes[dataset_idx] 45 | return self.datasets[dataset_idx][sample_idx] 46 | 47 | @property 48 | def sizes(self): 49 | return np.concatenate([np.tile(ds.sizes, sr) for ds, sr in zip(self.datasets, self.sample_ratios)]) 50 | 51 | @property 52 | def supports_prefetch(self): 53 | return all([d.supports_prefetch for d in self.datasets]) 54 | 55 | def prefetch(self, indices): 56 | frm = 0 57 | for to, ds in zip(self.cumulative_sizes, self.datasets): 58 | real_size = len(ds) 59 | ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to]) 60 | frm = to 61 | -------------------------------------------------------------------------------- /fairseq/data/fairseq_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.utils.data 9 | 10 | 11 | class FairseqDataset(torch.utils.data.Dataset): 12 | """A dataset that provides helpers for batching.""" 13 | 14 | def __getitem__(self, index): 15 | raise NotImplementedError 16 | 17 | def __len__(self): 18 | raise NotImplementedError 19 | 20 | def collater(self, samples): 21 | """Merge a list of samples to form a mini-batch. 22 | 23 | Args: 24 | samples (List[int]): sample indices to collate 25 | 26 | Returns: 27 | dict: a mini-batch suitable for forwarding with a Model 28 | """ 29 | raise NotImplementedError 30 | 31 | def get_dummy_batch(self, num_tokens, max_positions): 32 | """Return a dummy batch with a given number of tokens.""" 33 | raise NotImplementedError 34 | 35 | def num_tokens(self, index): 36 | """Return the number of tokens in a sample. This value is used to 37 | enforce ``--max-tokens`` during batching.""" 38 | raise NotImplementedError 39 | 40 | def size(self, index): 41 | """Return an example's size as a float or tuple. This value is used when 42 | filtering a dataset with ``--max-positions``.""" 43 | raise NotImplementedError 44 | 45 | def ordered_indices(self): 46 | """Return an ordered list of indices. 
Batches will be constructed based 47 | on this order.""" 48 | raise NotImplementedError 49 | 50 | @property 51 | def supports_prefetch(self): 52 | """Whether this dataset supports prefetching.""" 53 | return False 54 | 55 | def prefetch(self, indices): 56 | """Prefetch the data required for this epoch.""" 57 | raise NotImplementedError 58 | -------------------------------------------------------------------------------- /fairseq/data/lm_context_window_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from fairseq.data.monolingual_dataset import MonolingualDataset 12 | 13 | from . import FairseqDataset 14 | 15 | 16 | class LMContextWindowDataset(FairseqDataset): 17 | """Wraps a MonolingualDataset and provides more context for evaluation.""" 18 | 19 | def __init__(self, dataset, tokens_per_sample, context_window, pad_idx): 20 | assert isinstance(dataset, MonolingualDataset) 21 | assert context_window > 0 22 | self.dataset = dataset 23 | self.tokens_per_sample = tokens_per_sample 24 | self.context_window = context_window 25 | self.pad_idx = pad_idx 26 | self.prev_tokens = np.empty([0]) 27 | 28 | def __getitem__(self, index): 29 | return self.dataset[index] 30 | 31 | def __len__(self): 32 | return len(self.dataset) 33 | 34 | def collater(self, samples): 35 | sample = self.dataset.collater(samples) 36 | 37 | pad = self.pad_idx 38 | max_sample_len = self.tokens_per_sample + self.context_window 39 | 40 | bsz, tsz = sample['net_input']['src_tokens'].shape 41 | start_idxs = [0] * bsz 42 | toks = sample['net_input']['src_tokens'] 43 | lengths = sample['net_input']['src_lengths'] 44 | tgt = sample['target'] 45 | new_toks = np.empty([bsz, tsz + self.context_window], dtype=np.int64) 46 | new_tgt = np.full([bsz, tsz + self.context_window], pad, dtype=np.int64) 47 | sample_lens = toks.ne(pad).long().sum(dim=1).cpu() 48 | for i in range(bsz): 49 | sample_len = sample_lens[i] 50 | extra = len(self.prev_tokens) + sample_len - max_sample_len 51 | if extra > 0: 52 | self.prev_tokens = self.prev_tokens[extra:] 53 | pads = np.full(self.context_window - len(self.prev_tokens), pad) 54 | new_toks[i] = np.concatenate([self.prev_tokens, toks[i].numpy(), pads]) 55 | new_tgt[i, len(self.prev_tokens):len(self.prev_tokens) + len(tgt[i])] = tgt[i] 56 | start_idxs[i] = len(self.prev_tokens) 57 | lengths[i] += len(self.prev_tokens) 58 | self.prev_tokens = new_toks[i][new_toks[i] != pad][-self.context_window:] 59 | sample['net_input']['src_tokens'] = torch.from_numpy(new_toks) 60 | sample['target'] = torch.from_numpy(new_tgt) 61 | sample['start_indices'] = start_idxs 62 | 63 | return sample 64 | 65 | def get_dummy_batch(self, *args, **kwargs): 66 | return self.dataset.get_dummy_batch(*args, **kwargs) 67 | 68 | def num_tokens(self, index): 69 | return self.dataset.num_tokens(index) 70 | 71 | def size(self, index): 72 | return self.dataset.size(index) 73 | 74 | def ordered_indices(self): 75 | # NOTE we don't shuffle the data to retain access to the previous dataset elements 76 | return np.arange(len(self.dataset)) 77 | 78 | @property 79 | def supports_prefetch(self): 80 | return getattr(self.dataset, 'supports_prefetch', False) 
81 | 82 | def prefetch(self, indices): 83 | return self.dataset.prefetch(indices) 84 | -------------------------------------------------------------------------------- /fairseq/meters.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import time 9 | 10 | 11 | class AverageMeter(object): 12 | """Computes and stores the average and current value""" 13 | def __init__(self): 14 | self.reset() 15 | 16 | def reset(self): 17 | self.val = 0 18 | self.avg = 0 19 | self.sum = 0 20 | self.count = 0 21 | 22 | def update(self, val, n=1): 23 | self.val = val 24 | self.sum += val * n 25 | self.count += n 26 | self.avg = self.sum / self.count 27 | 28 | 29 | class TimeMeter(object): 30 | """Computes the average occurrence of some event per second""" 31 | def __init__(self, init=0): 32 | self.reset(init) 33 | 34 | def reset(self, init=0): 35 | self.init = init 36 | self.start = time.time() 37 | self.n = 0 38 | 39 | def update(self, val=1): 40 | self.n += val 41 | 42 | @property 43 | def avg(self): 44 | return self.n / self.elapsed_time 45 | 46 | @property 47 | def elapsed_time(self): 48 | return self.init + (time.time() - self.start) 49 | 50 | 51 | class StopwatchMeter(object): 52 | """Computes the sum/avg duration of some event in seconds""" 53 | def __init__(self): 54 | self.reset() 55 | 56 | def start(self): 57 | self.start_time = time.time() 58 | 59 | def stop(self, n=1): 60 | if self.start_time is not None: 61 | delta = time.time() - self.start_time 62 | self.sum += delta 63 | self.n += n 64 | self.start_time = None 65 | 66 | def reset(self): 67 | self.sum = 0 68 | self.n = 0 69 | self.start_time = None 70 | 71 | @property 72 | def avg(self): 73 | return self.sum / self.n 74 | -------------------------------------------------------------------------------- /fairseq/models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libeineu/SDT-Training/d33a836c1b3258748ec11c4d64998e0dcc9792df/fairseq/models/.DS_Store -------------------------------------------------------------------------------- /fairseq/models/composite_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from . import FairseqEncoder 9 | 10 | 11 | class CompositeEncoder(FairseqEncoder): 12 | """ 13 | A wrapper around a dictionary of :class:`FairseqEncoder` objects. 14 | 15 | We run forward on each encoder and return a dictionary of outputs. The first 16 | encoder's dictionary is used for initialization. 17 | 18 | Args: 19 | encoders (dict): a dictionary of :class:`FairseqEncoder` objects. 
20 | """ 21 | 22 | def __init__(self, encoders): 23 | super().__init__(next(iter(encoders.values())).dictionary) 24 | self.encoders = encoders 25 | for key in self.encoders: 26 | self.add_module(key, self.encoders[key]) 27 | 28 | def forward(self, src_tokens, src_lengths): 29 | """ 30 | Args: 31 | src_tokens (LongTensor): tokens in the source language of shape 32 | `(batch, src_len)` 33 | src_lengths (LongTensor): lengths of each source sentence of shape 34 | `(batch)` 35 | 36 | Returns: 37 | dict: 38 | the outputs from each Encoder 39 | """ 40 | encoder_out = {} 41 | for key in self.encoders: 42 | encoder_out[key] = self.encoders[key](src_tokens, src_lengths) 43 | return encoder_out 44 | 45 | def reorder_encoder_out(self, encoder_out, new_order): 46 | """Reorder encoder output according to new_order.""" 47 | for key in self.encoders: 48 | encoder_out[key] = self.encoders[key].reorder_encoder_out(encoder_out[key], new_order) 49 | return encoder_out 50 | 51 | def max_positions(self): 52 | return min([self.encoders[key].max_positions() for key in self.encoders]) 53 | 54 | def upgrade_state_dict(self, state_dict): 55 | for key in self.encoders: 56 | self.encoders[key].upgrade_state_dict(state_dict) 57 | return state_dict 58 | -------------------------------------------------------------------------------- /fairseq/models/distributed_fairseq_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import inspect 9 | from torch.nn import parallel 10 | 11 | from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel 12 | 13 | from . import BaseFairseqModel 14 | 15 | 16 | def DistributedFairseqModel(args, model): 17 | """ 18 | Wrap a *model* to support distributed data parallel training. 19 | 20 | This is similar to the built-in DistributedDataParallel, but allows 21 | additional configuration of the DistributedDataParallel class to 22 | use, and also provides easier access to the wrapped model by 23 | forwarding requests for missing attributes to the wrapped model. 
24 | 25 | Args: 26 | args (argparse.Namespace): fairseq args 27 | model (BaseFairseqModel): model to wrap 28 | """ 29 | 30 | # determine which DDP class to extend 31 | assert isinstance(model, BaseFairseqModel) 32 | if args.ddp_backend == 'c10d': 33 | ddp_class = parallel.DistributedDataParallel 34 | init_kwargs = dict( 35 | module=model, 36 | device_ids=[args.device_id], 37 | output_device=args.device_id, 38 | broadcast_buffers=False, 39 | bucket_cap_mb=args.bucket_cap_mb, 40 | ) 41 | # Maintain backward compatibility for 0.4 or earlier 42 | if 'check_reduction' in inspect.getargspec(ddp_class)[0]: 43 | init_kwargs['check_reduction'] = True 44 | elif args.ddp_backend == 'no_c10d': 45 | ddp_class = LegacyDistributedDataParallel 46 | init_kwargs = dict( 47 | module=model, 48 | world_size=args.distributed_world_size, 49 | buffer_size=2**28, 50 | ) 51 | else: 52 | raise ValueError('Unknown --ddp-backend: ' + args.ddp_backend) 53 | 54 | class _DistributedFairseqModel(ddp_class): 55 | """Extend DistributedDataParallel to check for missing 56 | attributes in the wrapped module.""" 57 | 58 | def __init__(self, *args, **kwargs): 59 | super().__init__(*args, **kwargs) 60 | 61 | def __getattr__(self, name): 62 | wrapped_module = super().__getattr__('module') 63 | if hasattr(wrapped_module, name): 64 | return getattr(wrapped_module, name) 65 | return super().__getattr__(name) 66 | 67 | return _DistributedFairseqModel(**init_kwargs) 68 | -------------------------------------------------------------------------------- /fairseq/models/fairseq_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
7 | 8 | import torch.nn as nn 9 | 10 | from fairseq import utils 11 | 12 | 13 | class FairseqDecoder(nn.Module): 14 | """Base class for decoders.""" 15 | 16 | def __init__(self, dictionary): 17 | super().__init__() 18 | self.dictionary = dictionary 19 | self.onnx_trace = False 20 | 21 | def forward(self, prev_output_tokens, encoder_out): 22 | """ 23 | Args: 24 | prev_output_tokens (LongTensor): previous decoder outputs of shape 25 | `(batch, tgt_len)`, for input feeding/teacher forcing 26 | encoder_out (Tensor, optional): output from the encoder, used for 27 | encoder-side attention 28 | 29 | Returns: 30 | tuple: 31 | - the last decoder layer's output of shape 32 | `(batch, tgt_len, vocab)` 33 | - the last decoder layer's attention weights of shape 34 | `(batch, tgt_len, src_len)` 35 | """ 36 | raise NotImplementedError 37 | 38 | def prepare_for_onnx_export_(self): 39 | self.onnx_trace = True 40 | 41 | def get_normalized_probs(self, net_output, log_probs, sample): 42 | """Get normalized probabilities (or log probs) from a net's output.""" 43 | 44 | if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None: 45 | if sample is not None: 46 | assert 'target' in sample 47 | target = sample['target'] 48 | else: 49 | target = None 50 | out = self.adaptive_softmax.get_log_prob(net_output[0], target=target) 51 | return out.exp_() if not log_probs else out 52 | 53 | logits = net_output[0] 54 | if log_probs: 55 | return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace) 56 | else: 57 | return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace) 58 | 59 | def max_positions(self): 60 | """Maximum input length supported by the decoder.""" 61 | return 1e6 # an arbitrary large number 62 | 63 | def upgrade_state_dict(self, state_dict): 64 | """Upgrade a (possibly old) state dict for new versions of fairseq.""" 65 | return state_dict 66 | -------------------------------------------------------------------------------- /fairseq/models/fairseq_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.nn as nn 9 | 10 | 11 | class FairseqEncoder(nn.Module): 12 | """Base class for encoders.""" 13 | 14 | def __init__(self, dictionary): 15 | super().__init__() 16 | self.dictionary = dictionary 17 | 18 | def forward(self, src_tokens, src_lengths): 19 | """ 20 | Args: 21 | src_tokens (LongTensor): tokens in the source language of shape 22 | `(batch, src_len)` 23 | src_lengths (LongTensor): lengths of each source sentence of shape 24 | `(batch)` 25 | """ 26 | raise NotImplementedError 27 | 28 | def reorder_encoder_out(self, encoder_out, new_order): 29 | """ 30 | Reorder encoder output according to `new_order`. 
31 | 32 | Args: 33 | encoder_out: output from the ``forward()`` method 34 | new_order (LongTensor): desired order 35 | 36 | Returns: 37 | `encoder_out` rearranged according to `new_order` 38 | """ 39 | raise NotImplementedError 40 | 41 | def max_positions(self): 42 | """Maximum input length supported by the encoder.""" 43 | return 1e6 # an arbitrary large number 44 | 45 | def upgrade_state_dict(self, state_dict): 46 | """Upgrade a (possibly old) state dict for new versions of fairseq.""" 47 | return state_dict 48 | -------------------------------------------------------------------------------- /fairseq/models/fairseq_incremental_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from . import FairseqDecoder 9 | 10 | 11 | class FairseqIncrementalDecoder(FairseqDecoder): 12 | """Base class for incremental decoders. 13 | 14 | Incremental decoding is a special mode at inference time where the Model 15 | only receives a single timestep of input corresponding to the immediately 16 | previous output token (for input feeding) and must produce the next output 17 | *incrementally*. Thus the model must cache any long-term state that is 18 | needed about the sequence, e.g., hidden states, convolutional states, etc. 19 | 20 | Compared to the standard :class:`FairseqDecoder` interface, the incremental 21 | decoder interface allows :func:`forward` functions to take an extra keyword 22 | argument (*incremental_state*) that can be used to cache state across 23 | time-steps. 24 | 25 | The :class:`FairseqIncrementalDecoder` interface also defines the 26 | :func:`reorder_incremental_state` method, which is used during beam search 27 | to select and reorder the incremental state based on the selection of beams. 28 | """ 29 | 30 | def __init__(self, dictionary): 31 | super().__init__(dictionary) 32 | 33 | def forward(self, prev_output_tokens, encoder_out, incremental_state=None): 34 | """ 35 | Args: 36 | prev_output_tokens (LongTensor): previous decoder outputs of shape 37 | `(batch, tgt_len)`, for input feeding/teacher forcing 38 | encoder_out (Tensor, optional): output from the encoder, used for 39 | encoder-side attention 40 | incremental_state (dict): dictionary used for storing state during 41 | :ref:`Incremental decoding` 42 | 43 | Returns: 44 | tuple: 45 | - the last decoder layer's output of shape `(batch, tgt_len, 46 | vocab)` 47 | - the last decoder layer's attention weights of shape `(batch, 48 | tgt_len, src_len)` 49 | """ 50 | raise NotImplementedError 51 | 52 | def reorder_incremental_state(self, incremental_state, new_order): 53 | """Reorder incremental state. 54 | 55 | This should be called when the order of the input has changed from the 56 | previous time step. A typical use case is beam search, where the input 57 | order changes between time steps based on the selection of beams. 
58 | """ 59 | seen = set() 60 | 61 | def apply_reorder_incremental_state(module): 62 | if module != self and hasattr(module, 'reorder_incremental_state') \ 63 | and module not in seen: 64 | seen.add(module) 65 | module.reorder_incremental_state(incremental_state, new_order) 66 | 67 | self.apply(apply_reorder_incremental_state) 68 | 69 | def set_beam_size(self, beam_size): 70 | """Sets the beam size in the decoder and all children.""" 71 | if getattr(self, '_beam_size', -1) != beam_size: 72 | seen = set() 73 | 74 | def apply_set_beam_size(module): 75 | if module != self and hasattr(module, 'set_beam_size') \ 76 | and module not in seen: 77 | seen.add(module) 78 | module.set_beam_size(beam_size) 79 | 80 | self.apply(apply_set_beam_size) 81 | self._beam_size = beam_size 82 | -------------------------------------------------------------------------------- /fairseq/modules/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libeineu/SDT-Training/d33a836c1b3258748ec11c4d64998e0dcc9792df/fairseq/modules/.DS_Store -------------------------------------------------------------------------------- /fairseq/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from .adaptive_input import AdaptiveInput 9 | from .adaptive_softmax import AdaptiveSoftmax 10 | from .beamable_mm import BeamableMM 11 | from .character_token_embedder import CharacterTokenEmbedder 12 | from .conv_tbc import ConvTBC 13 | from .downsampled_multihead_attention import DownsampledMultiHeadAttention 14 | from .dynamic_convolution import DynamicConv1dTBC 15 | from .grad_multiply import GradMultiply 16 | from .highway import Highway 17 | from .layer_norm import LayerNorm 18 | from .learned_positional_embedding import LearnedPositionalEmbedding 19 | from .lightweight_convolution import LightweightConv1dTBC 20 | from .linearized_convolution import LinearizedConvolution 21 | from .logsumexp_moe import LogSumExpMoE 22 | from .mean_pool_gating_network import MeanPoolGatingNetwork 23 | from .multihead_attention import MultiheadAttention 24 | from .scalar_bias import ScalarBias 25 | from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding 26 | from .unfold1d import unfold1d 27 | from .relative_multihead_attention import RelativeMultiheadAttention 28 | 29 | __all__ = [ 30 | 'AdaptiveInput', 31 | 'AdaptiveSoftmax', 32 | 'BeamableMM', 33 | 'CharacterTokenEmbedder', 34 | 'ConvTBC', 35 | 'DownsampledMultiHeadAttention', 36 | 'DynamicConv1dTBC', 37 | 'GradMultiply', 38 | 'Highway', 39 | 'LayerNorm', 40 | 'LearnedPositionalEmbedding', 41 | 'LightweightConv1dTBC', 42 | 'LinearizedConvolution', 43 | 'LogSumExpMoE', 44 | 'MeanPoolGatingNetwork', 45 | 'MultiheadAttention', 46 | 'ScalarBias', 47 | 'SinusoidalPositionalEmbedding', 48 | 'unfold1d', 49 | 'RelativeMultiheadAttention', 50 | ] 51 | -------------------------------------------------------------------------------- /fairseq/modules/adaptive_input.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | 9 | import torch 10 | from torch import nn 11 | 12 | from typing import List 13 | 14 | 15 | class AdaptiveInput(nn.Module): 16 | 17 | def __init__( 18 | self, 19 | vocab_size: int, 20 | padding_idx: int, 21 | initial_dim: int, 22 | factor: float, 23 | output_dim: int, 24 | cutoff: List[int], 25 | ): 26 | super().__init__() 27 | 28 | if vocab_size > cutoff[-1]: 29 | cutoff = cutoff + [vocab_size] 30 | else: 31 | assert vocab_size == cutoff[ 32 | -1], 'cannot specify cutoff larger than vocab size' 33 | 34 | self.cutoff = cutoff 35 | self.embedding_dim = output_dim 36 | self.padding_idx = padding_idx 37 | 38 | self.embeddings = nn.ModuleList() 39 | for i in range(len(self.cutoff)): 40 | prev = self.cutoff[i - 1] if i > 0 else 0 41 | size = self.cutoff[i] - prev 42 | dim = int(initial_dim // (factor ** i)) 43 | seq = nn.Sequential( 44 | nn.Embedding(size, dim, padding_idx), 45 | nn.Linear(dim, output_dim, bias=False) 46 | ) 47 | self.embeddings.append(seq) 48 | 49 | def init_weights(m): 50 | if isinstance(m, nn.Embedding): 51 | nn.init.normal_(m.weight, mean=0, std=m.weight.shape[1] ** -0.5) 52 | nn.init.constant_(m.weight[padding_idx], 0) 53 | elif hasattr(m, 'weight'): 54 | nn.init.xavier_uniform_(m.weight) 55 | 56 | self.apply(init_weights) 57 | 58 | self.register_buffer('_float_tensor', torch.FloatTensor(1)) 59 | 60 | def weights_for_band(self, band: int): 61 | return self.embeddings[band][0].weight, self.embeddings[band][1].weight 62 | 63 | def forward(self, input: torch.Tensor): 64 | result = self._float_tensor.new(input.shape + (self.embedding_dim,)) 65 | for i in range(len(self.cutoff)): 66 | mask = input.lt(self.cutoff[i]) 67 | if i > 0: 68 | mask.mul_(input.ge(self.cutoff[i - 1])) 69 | chunk_input = input[mask] - self.cutoff[i - 1] 70 | else: 71 | chunk_input = input[mask] 72 | if mask.any(): 73 | result[mask] = self.embeddings[i](chunk_input) 74 | return result 75 | -------------------------------------------------------------------------------- /fairseq/modules/beamable_mm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | class BeamableMM(nn.Module): 13 | """This module provides an optimized MM for beam decoding with attention. 14 | 15 | It leverages the fact that the source-side of the input is replicated beam 16 | times and the target-side of the input is of width one. This layer speeds up 17 | inference by replacing the inputs {(bsz x 1 x nhu), (bsz x sz2 x nhu)} 18 | with smaller inputs {(bsz/beam x beam x nhu), (bsz/beam x sz2 x nhu)}.
19 | """ 20 | def __init__(self, beam_size=None): 21 | super(BeamableMM, self).__init__() 22 | self.beam_size = beam_size 23 | 24 | def forward(self, input1, input2): 25 | if ( 26 | not self.training and # test mode 27 | self.beam_size is not None and # beam size is set 28 | input1.dim() == 3 and # only support batched input 29 | input1.size(1) == 1 # single time step update 30 | ): 31 | bsz, beam = input1.size(0), self.beam_size 32 | 33 | # bsz x 1 x nhu --> bsz/beam x beam x nhu 34 | input1 = input1[:, 0, :].unfold(0, beam, beam).transpose(2, 1) 35 | 36 | # bsz x sz2 x nhu --> bsz/beam x sz2 x nhu 37 | input2 = input2.unfold(0, beam, beam)[:, :, :, 0] 38 | 39 | # use non batched operation if bsz = beam 40 | if input1.size(0) == 1: 41 | output = torch.mm(input1[0, :, :], input2[0, :, :]) 42 | else: 43 | output = input1.bmm(input2) 44 | return output.view(bsz, 1, -1) 45 | else: 46 | return input1.bmm(input2) 47 | 48 | def set_beam_size(self, beam_size): 49 | self.beam_size = beam_size 50 | -------------------------------------------------------------------------------- /fairseq/modules/conv_tbc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | from torch.nn.modules.utils import _single 10 | 11 | 12 | class ConvTBC(torch.nn.Module): 13 | """1D convolution over an input of shape (time x batch x channel) 14 | 15 | The implementation uses gemm to perform the convolution. This implementation 16 | is faster than cuDNN for small kernel sizes. 17 | """ 18 | def __init__(self, in_channels, out_channels, kernel_size, padding=0): 19 | super(ConvTBC, self).__init__() 20 | self.in_channels = in_channels 21 | self.out_channels = out_channels 22 | self.kernel_size = _single(kernel_size) 23 | self.padding = _single(padding) 24 | 25 | self.weight = torch.nn.Parameter(torch.Tensor( 26 | self.kernel_size[0], in_channels, out_channels)) 27 | self.bias = torch.nn.Parameter(torch.Tensor(out_channels)) 28 | 29 | def forward(self, input): 30 | return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding[0]) 31 | 32 | def __repr__(self): 33 | s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}' 34 | ', padding={padding}') 35 | if self.bias is None: 36 | s += ', bias=False' 37 | s += ')' 38 | return s.format(name=self.__class__.__name__, **self.__dict__) 39 | -------------------------------------------------------------------------------- /fairseq/modules/grad_multiply.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
7 | 
8 | import torch
9 | 
10 | 
11 | class GradMultiply(torch.autograd.Function):
12 |     @staticmethod
13 |     def forward(ctx, x, scale):
14 |         ctx.scale = scale
15 |         res = x.new(x)
16 |         return res
17 | 
18 |     @staticmethod
19 |     def backward(ctx, grad):
20 |         return grad * ctx.scale, None
21 | 
-------------------------------------------------------------------------------- /fairseq/modules/highway.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the LICENSE file in
5 | # the root directory of this source tree. An additional grant of patent rights
6 | # can be found in the PATENTS file in the same directory.
7 | 
8 | import torch
9 | import torch.nn.functional as F
10 | 
11 | from torch import nn
12 | 
13 | 
14 | class Highway(torch.nn.Module):
15 |     """
16 |     A `Highway layer <https://arxiv.org/abs/1505.00387>`_.
17 |     Adopted from the AllenNLP implementation.
18 |     """
19 | 
20 |     def __init__(
21 |         self,
22 |         input_dim: int,
23 |         num_layers: int = 1
24 |     ):
25 |         super(Highway, self).__init__()
26 |         self.input_dim = input_dim
27 |         self.layers = nn.ModuleList([nn.Linear(input_dim, input_dim * 2)
28 |                                      for _ in range(num_layers)])
29 |         self.activation = nn.ReLU()
30 | 
31 |         self.reset_parameters()
32 | 
33 |     def reset_parameters(self):
34 |         for layer in self.layers:
35 |             # As per comment in AllenNLP:
36 |             # We should bias the highway layer to just carry its input forward. We do that by
37 |             # setting the bias on `B(x)` to be positive, because that means `g` will be biased to
38 |             # be high, so we will carry the input forward. The bias on `B(x)` is the second half
39 |             # of the bias vector in each Linear layer.
40 |             nn.init.constant_(layer.bias[self.input_dim:], 1)
41 | 
42 |             nn.init.constant_(layer.bias[:self.input_dim], 0)
43 |             nn.init.xavier_normal_(layer.weight)
44 | 
45 |     def forward(
46 |         self,
47 |         x: torch.Tensor
48 |     ):
49 |         for layer in self.layers:
50 |             projection = layer(x)
51 |             proj_x, gate = projection.chunk(2, dim=-1)
52 |             proj_x = self.activation(proj_x)
53 |             gate = torch.sigmoid(gate)
54 |             x = gate * x + (gate.new_tensor([1]) - gate) * proj_x
55 |         return x
56 | 
-------------------------------------------------------------------------------- /fairseq/modules/layer_history.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from fairseq.models.transformer import LayerNorm
4 | import queue
5 | import fairseq.utils as utils
6 | import torch.nn.functional as F
7 | import numpy as np
8 | 
9 | 
10 | def CreateLayerHistory(args, is_encoder):
11 |     history_type = args.encoder_history_type if is_encoder else args.decoder_history_type
12 |     if history_type is None:
13 |         return None
14 |     elif history_type == "learnable_dense":
15 |         return LearnableDenseLayerHistory(args, is_encoder)
16 |     else:
17 |         raise ValueError('unknown history type: {}'.format(history_type))
18 | 
19 | 
20 | class BaseLayerHistory(nn.Module):
21 | 
22 |     def __init__(self, args, is_encoder):
23 |         super(BaseLayerHistory, self).__init__()
24 |         self.is_encoder = is_encoder
25 |         self.normalize_before = args.encoder_normalize_before if is_encoder else args.decoder_normalize_before
26 | 
27 |         # the first layer (aka. the embedding layer) does not have layer normalization
28 |         # layers = args.encoder_layers if is_encoder else args.decoder_layers
29 |         layers = len(args.k) - 1 if is_encoder else args.decoder_layers
30 |         dim = args.encoder_embed_dim if is_encoder else args.decoder_embed_dim
31 |         self.layer_norms = nn.ModuleList(LayerNorm(dim) for _ in range(layers))
32 | 
33 |     def add(self, layer):
34 |         raise NotImplementedError
35 | 
36 |     def pop(self):
37 |         raise NotImplementedError
38 | 
39 |     def clean(self):
40 |         raise NotImplementedError
41 | 
42 | 
43 | class LearnableDenseLayerHistory(BaseLayerHistory):
44 |     """
45 |     Dense connection with learnable mixing weights, initialized so that
46 |     x_n = (x_1 + y_1 + y_2 + ... + y_{n-1}) / n.
47 |     """
48 | 
49 |     def __init__(self, args, is_encoder):
50 |         super(LearnableDenseLayerHistory, self).__init__(args, is_encoder)
51 |         self.sum = None
52 |         self.count = 0
53 |         self.layers = []  # filled by add(); reset by clean()
54 |         # self.layer_num = 1 + (args.encoder_layers if is_encoder else args.decoder_layers)
55 |         self.layer_num = len(args.k) if is_encoder else args.decoder_layers + 1
56 |         # lower-triangular mixing matrix; each row is normalized to sum to 1,
57 |         # so row n starts out as a uniform average over the first n states
58 |         self.weight = nn.Parameter(torch.Tensor(self.layer_num, self.layer_num).fill_(1.0).tril())
59 |         self.weight.data = self.weight.data / self.weight.data.sum(1, keepdim=True)
60 | 
61 |     def extra_repr(self):
62 |         return 'n_layers={layer_num}'.format(**self.__dict__)
63 | 
64 |     def add(self, layer):
65 |         self.count += 1
66 | 
67 |         # first layer
68 |         if self.sum is None:
69 |             self.sum = layer
70 |             self.layers.append(layer)
71 |             return
72 | 
73 |         # following layers
74 |         if self.normalize_before:
75 |             layer = self.layer_norms[self.count - 2](layer)
76 | 
77 |         self.layers.append(layer)
78 | 
79 |     def pop(self):
80 |         assert len(self.layers) > 0
81 |         # layers_dropout = F.dropout(torch.stack(self.layers, 0), p=self.dense_dropout, training=self.training)
82 |         # ret = (layers_dropout * self.weight[self.count -1, : self.count].view(-1, 1, 1, 1)).sum(0)
83 |         ret = (torch.stack(self.layers, 0) * self.weight[self.count - 1, : self.count].view(-1, 1, 1, 1)).sum(0)
84 |         if self.count == 1 or self.normalize_before:
85 |             return ret
86 |         return self.layer_norms[self.count - 2](ret)
87 | 
88 |     def clean(self):
89 |         self.sum = None
90 |         self.count = 0
91 |         self.layers = []
92 | 
93 |     def get_loss(self):
94 |         return (0.5 * (self.weight.sum(1) - 1.0) ** 2).mean()
95 | 
96 |     def print_weight(self):
97 |         print(self.weight)
98 | 
-------------------------------------------------------------------------------- /fairseq/modules/layer_norm.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the LICENSE file in
5 | # the root directory of this source tree. An additional grant of patent rights
6 | # can be found in the PATENTS file in the same directory.
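# The LayerNorm factory defined below returns apex's FusedLayerNorm when CUDA
# (and apex) is available and falls back to torch.nn.LayerNorm otherwise, so
# call sites can simply write, e.g., `self.layer_norm = LayerNorm(embed_dim)`.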
7 | 8 | import torch 9 | 10 | 11 | def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True): 12 | if torch.cuda.is_available(): 13 | try: 14 | from apex.normalization import FusedLayerNorm 15 | return FusedLayerNorm(normalized_shape, eps, elementwise_affine) 16 | except ImportError: 17 | pass 18 | return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) 19 | -------------------------------------------------------------------------------- /fairseq/modules/learned_positional_embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.nn as nn 9 | 10 | from fairseq import utils 11 | 12 | 13 | class LearnedPositionalEmbedding(nn.Embedding): 14 | """This module learns positional embeddings up to a fixed maximum size. 15 | 16 | Padding symbols are ignored, but it is necessary to specify whether padding 17 | is added on the left side (left_pad=True) or right side (left_pad=False). 18 | """ 19 | 20 | def __init__(self, num_embeddings, embedding_dim, padding_idx, left_pad): 21 | super().__init__(num_embeddings, embedding_dim, padding_idx) 22 | self.left_pad = left_pad 23 | self.onnx_trace = False 24 | 25 | def forward(self, input, incremental_state=None): 26 | """Input is expected to be of size [bsz x seqlen].""" 27 | if incremental_state is not None: 28 | # positions is the same for every token when decoding a single step 29 | positions = input.data.new(1, 1).fill_(self.padding_idx + input.size(1)) 30 | else: 31 | positions = utils.make_positions(input.data, self.padding_idx, self.left_pad, self.onnx_trace) 32 | return super().forward(positions) 33 | 34 | def max_positions(self): 35 | """Maximum number of supported positions.""" 36 | return self.num_embeddings - self.padding_idx - 1 37 | -------------------------------------------------------------------------------- /fairseq/modules/linearized_convolution.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | from fairseq import utils 12 | 13 | from .conv_tbc import ConvTBC 14 | 15 | 16 | class LinearizedConvolution(ConvTBC): 17 | """An optimized version of nn.Conv1d. 18 | 19 | At training time, this module uses ConvTBC, which is an optimized version 20 | of Conv1d. At inference time, it optimizes incremental generation (i.e., 21 | one time step at a time) by replacing the convolutions with linear layers. 22 | Note that the input order changes from training to inference. 
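    Concretely: in training, forward() sees the full (time x batch x channel)
    sequence and runs ConvTBC once; at inference it is called once per step
    with a (batch x 1 x channel) frame, the last kernel_size frames are kept
    in an incremental-state buffer, and the convolution reduces to a single
    F.linear call over that buffer.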
23 | """ 24 | 25 | def __init__(self, in_channels, out_channels, kernel_size, **kwargs): 26 | super().__init__(in_channels, out_channels, kernel_size, **kwargs) 27 | self._linearized_weight = None 28 | self.register_backward_hook(self._clear_linearized_weight) 29 | 30 | def forward(self, input, incremental_state=None): 31 | """ 32 | Args: 33 | incremental_state: Used to buffer signal; if not None, then input is 34 | expected to contain a single frame. If the input order changes 35 | between time steps, call reorder_incremental_state. 36 | Input: 37 | Time x Batch x Channel during training 38 | Batch x Time x Channel during inference 39 | """ 40 | if incremental_state is None: 41 | output = super().forward(input) 42 | if self.kernel_size[0] > 1 and self.padding[0] > 0: 43 | # remove future timesteps added by padding 44 | output = output[:-self.padding[0], :, :] 45 | return output 46 | 47 | # reshape weight 48 | weight = self._get_linearized_weight() 49 | kw = self.kernel_size[0] 50 | 51 | bsz = input.size(0) # input: bsz x len x dim 52 | if kw > 1: 53 | input = input.data 54 | input_buffer = self._get_input_buffer(incremental_state) 55 | if input_buffer is None: 56 | input_buffer = input.new(bsz, kw, input.size(2)).zero_() 57 | self._set_input_buffer(incremental_state, input_buffer) 58 | else: 59 | # shift buffer 60 | input_buffer[:, :-1, :] = input_buffer[:, 1:, :].clone() 61 | # append next input 62 | input_buffer[:, -1, :] = input[:, -1, :] 63 | input = input_buffer 64 | with torch.no_grad(): 65 | output = F.linear(input.view(bsz, -1), weight, self.bias) 66 | return output.view(bsz, 1, -1) 67 | 68 | def reorder_incremental_state(self, incremental_state, new_order): 69 | input_buffer = self._get_input_buffer(incremental_state) 70 | if input_buffer is not None: 71 | input_buffer = input_buffer.index_select(0, new_order) 72 | self._set_input_buffer(incremental_state, input_buffer) 73 | 74 | def _get_input_buffer(self, incremental_state): 75 | return utils.get_incremental_state(self, incremental_state, 'input_buffer') 76 | 77 | def _set_input_buffer(self, incremental_state, new_buffer): 78 | return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer) 79 | 80 | def _get_linearized_weight(self): 81 | if self._linearized_weight is None: 82 | kw = self.kernel_size[0] 83 | weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous() 84 | assert weight.size() == (self.out_channels, kw, self.in_channels) 85 | self._linearized_weight = weight.view(self.out_channels, -1) 86 | return self._linearized_weight 87 | 88 | def _clear_linearized_weight(self, *args): 89 | self._linearized_weight = None 90 | -------------------------------------------------------------------------------- /fairseq/modules/logsumexp_moe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | 10 | 11 | class LogSumExpMoE(torch.autograd.Function): 12 | """Standard LogSumExp forward pass, but use *posterior* for the backward. 13 | 14 | See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade" 15 | (Shen et al., 2019) <https://arxiv.org/abs/1902.07816>`_. 
16 | """ 17 | 18 | @staticmethod 19 | def forward(ctx, logp, posterior, dim=-1): 20 | ctx.save_for_backward(posterior) 21 | ctx.dim = dim 22 | return torch.logsumexp(logp, dim=dim) 23 | 24 | @staticmethod 25 | def backward(ctx, grad_output): 26 | posterior, = ctx.saved_tensors 27 | grad_logp = grad_output.unsqueeze(ctx.dim) * posterior 28 | return grad_logp, None, None 29 | -------------------------------------------------------------------------------- /fairseq/modules/mean_pool_gating_network.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | 12 | class MeanPoolGatingNetwork(torch.nn.Module): 13 | """A simple mean-pooling gating network for selecting experts. 14 | 15 | This module applies mean pooling over an encoder's output and returns 16 | reponsibilities for each expert. The encoder format is expected to match 17 | :class:`fairseq.models.transformer.TransformerEncoder`. 18 | """ 19 | 20 | def __init__(self, embed_dim, num_experts, dropout=None): 21 | super().__init__() 22 | self.embed_dim = embed_dim 23 | self.num_experts = num_experts 24 | 25 | self.fc1 = torch.nn.Linear(embed_dim, embed_dim) 26 | self.dropout = torch.nn.Dropout(dropout) if dropout is not None else None 27 | self.fc2 = torch.nn.Linear(embed_dim, num_experts) 28 | 29 | def forward(self, encoder_out): 30 | if not ( 31 | isinstance(encoder_out, dict) 32 | and 'encoder_out' in encoder_out 33 | and 'encoder_padding_mask' in encoder_out 34 | and encoder_out['encoder_out'].size(2) == self.embed_dim 35 | ): 36 | raise ValueError('Unexpected format for encoder_out') 37 | 38 | # mean pooling over time 39 | encoder_padding_mask = encoder_out['encoder_padding_mask'] # B x T 40 | encoder_out = encoder_out['encoder_out'].transpose(0, 1) # B x T x C 41 | if encoder_padding_mask is not None: 42 | encoder_out = encoder_out.clone() # required because of transpose above 43 | encoder_out[encoder_padding_mask] = 0 44 | ntokens = torch.sum(1 - encoder_padding_mask, dim=1, keepdim=True) 45 | x = torch.sum(encoder_out, dim=1) / ntokens.type_as(encoder_out) 46 | else: 47 | x = torch.mean(encoder_out, dim=1) 48 | 49 | x = torch.tanh(self.fc1(x)) 50 | if self.dropout is not None: 51 | x = self.dropout(x) 52 | x = self.fc2(x) 53 | return F.log_softmax(x, dim=-1, dtype=torch.float32).type_as(x) 54 | -------------------------------------------------------------------------------- /fairseq/modules/scalar_bias.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
7 | # 8 | 9 | import torch 10 | 11 | 12 | class ScalarBias(torch.autograd.Function): 13 | """ 14 | Adds a vector of scalars, used in self-attention mechanism to allow 15 | the model to optionally attend to this vector instead of the past 16 | """ 17 | 18 | @staticmethod 19 | def forward(ctx, input, dim, bias_init): 20 | size = list(input.size()) 21 | size[dim] += 1 22 | output = input.new(*size).fill_(bias_init) 23 | output.narrow(dim, 1, size[dim] - 1).copy_(input) 24 | ctx.dim = dim 25 | return output 26 | 27 | @staticmethod 28 | def backward(ctx, grad): 29 | return grad.narrow(ctx.dim, 1, grad.size(ctx.dim) - 1), None, None 30 | 31 | 32 | def scalar_bias(input, dim, bias_init=0): 33 | return ScalarBias.apply(input, dim, bias_init) 34 | -------------------------------------------------------------------------------- /fairseq/modules/sinusoidal_positional_embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import math 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.onnx.operators 13 | 14 | from fairseq import utils 15 | 16 | 17 | class SinusoidalPositionalEmbedding(nn.Module): 18 | """This module produces sinusoidal positional embeddings of any length. 19 | 20 | Padding symbols are ignored, but it is necessary to specify whether padding 21 | is added on the left side (left_pad=True) or right side (left_pad=False). 22 | """ 23 | 24 | def __init__(self, embedding_dim, padding_idx, left_pad, init_size=1024): 25 | super().__init__() 26 | self.embedding_dim = embedding_dim 27 | self.padding_idx = padding_idx 28 | self.left_pad = left_pad 29 | self.weights = SinusoidalPositionalEmbedding.get_embedding( 30 | init_size, 31 | embedding_dim, 32 | padding_idx, 33 | ) 34 | self.onnx_trace = False 35 | self.register_buffer('_float_tensor', torch.FloatTensor(1)) 36 | 37 | def prepare_for_onnx_export_(self): 38 | self.onnx_trace = True 39 | 40 | @staticmethod 41 | def get_embedding(num_embeddings, embedding_dim, padding_idx=None): 42 | """Build sinusoidal embeddings. 43 | 44 | This matches the implementation in tensor2tensor, but differs slightly 45 | from the description in Section 3.5 of "Attention Is All You Need". 
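    Concretely, for position pos and channel i < half_dim the code computes
    freq_i = 10000 ** (-i / (half_dim - 1)) and emits
    [sin(pos * freq_0), ..., sin(pos * freq_{half_dim-1}),
     cos(pos * freq_0), ..., cos(pos * freq_{half_dim-1})],
    i.e. all sines followed by all cosines, rather than interleaving sin/cos
    pairs as the paper describes.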
46 | """ 47 | half_dim = embedding_dim // 2 48 | emb = math.log(10000) / (half_dim - 1) 49 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) 50 | emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) 51 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) 52 | if embedding_dim % 2 == 1: 53 | # zero pad 54 | emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) 55 | if padding_idx is not None: 56 | emb[padding_idx, :] = 0 57 | return emb 58 | 59 | def forward(self, input, incremental_state=None, timestep=None): 60 | """Input is expected to be of size [bsz x seqlen].""" 61 | bsz, seq_len = torch.onnx.operators.shape_as_tensor(input) 62 | max_pos = self.padding_idx + 1 + seq_len 63 | if self.weights is None or max_pos > self.weights.size(0): 64 | # recompute/expand embeddings if needed 65 | self.weights = SinusoidalPositionalEmbedding.get_embedding( 66 | max_pos, 67 | self.embedding_dim, 68 | self.padding_idx, 69 | ) 70 | self.weights = self.weights.type_as(self._float_tensor) 71 | 72 | if incremental_state is not None: 73 | # positions is the same for every token when decoding a single step 74 | pos = (timestep.int() + 1).long() if timestep is not None else seq_len 75 | if self.onnx_trace: 76 | return self.weights[self.padding_idx + pos, :].unsqueeze(1).repeat(bsz, 1, 1) 77 | return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1) 78 | 79 | positions = utils.make_positions(input, self.padding_idx, self.left_pad, self.onnx_trace) 80 | if self.onnx_trace: 81 | flat_embeddings = self.weights.detach().index_select(0, positions.view(-1)) 82 | embedding_shape = torch.cat((bsz.view(1), seq_len.view(1), torch.LongTensor([-1]))) 83 | embeddings = torch.onnx.operators.reshape_from_tensor_shape(flat_embeddings, embedding_shape) 84 | return embeddings 85 | return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach() 86 | 87 | def max_positions(self): 88 | """Maximum number of supported positions.""" 89 | return int(1e5) # an arbitrary large number 90 | -------------------------------------------------------------------------------- /fairseq/modules/unfold1d.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.nn.functional as F 9 | 10 | 11 | def unfold1d(x, kernel_size, padding_l, pad_value=0): 12 | '''unfold T x B x C to T x B x C x K''' 13 | if kernel_size > 1: 14 | T, B, C = x.size() 15 | x = F.pad(x, (0, 0, 0, 0, padding_l, kernel_size - 1 - padding_l), value=pad_value) 16 | x = x.as_strided((T, B, C, kernel_size), (B*C, C, 1, B*C)) 17 | else: 18 | x = x.unsqueeze(3) 19 | return x 20 | -------------------------------------------------------------------------------- /fairseq/optim/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
7 | 
8 | import importlib
9 | import os
10 | 
11 | from .fairseq_optimizer import FairseqOptimizer
12 | from .fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer
13 | 
14 | 
15 | OPTIMIZER_REGISTRY = {}
16 | OPTIMIZER_CLASS_NAMES = set()
17 | 
18 | 
19 | def build_optimizer(args, params):
20 |     params = list(filter(lambda p: p.requires_grad, params))
21 |     return OPTIMIZER_REGISTRY[args.optimizer](args, params)
22 | 
23 | 
24 | def register_optimizer(name):
25 |     """Decorator to register a new optimizer."""
26 | 
27 |     def register_optimizer_cls(cls):
28 |         if name in OPTIMIZER_REGISTRY:
29 |             raise ValueError('Cannot register duplicate optimizer ({})'.format(name))
30 |         if not issubclass(cls, FairseqOptimizer):
31 |             raise ValueError('Optimizer ({}: {}) must extend FairseqOptimizer'.format(name, cls.__name__))
32 |         if cls.__name__ in OPTIMIZER_CLASS_NAMES:
33 |             # We use the optimizer class name as a unique identifier in
34 |             # checkpoints, so all optimizers must have unique class names.
35 |             raise ValueError('Cannot register optimizer with duplicate class name ({})'.format(cls.__name__))
36 |         OPTIMIZER_REGISTRY[name] = cls
37 |         OPTIMIZER_CLASS_NAMES.add(cls.__name__)
38 |         return cls
39 | 
40 |     return register_optimizer_cls
41 | 
42 | 
43 | # automatically import any Python files in the optim/ directory
44 | for file in os.listdir(os.path.dirname(__file__)):
45 |     if file.endswith('.py') and not file.startswith('_'):
46 |         module = file[:file.find('.py')]
47 |         importlib.import_module('fairseq.optim.' + module)
48 | 
-------------------------------------------------------------------------------- /fairseq/optim/adadelta.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the LICENSE file in
5 | # the root directory of this source tree. An additional grant of patent rights
6 | # can be found in the PATENTS file in the same directory.
7 | 
8 | import torch.optim
9 | 
10 | from . import FairseqOptimizer, register_optimizer
11 | 
12 | 
13 | @register_optimizer('adadelta')
14 | class Adadelta(FairseqOptimizer):
15 |     def __init__(self, args, params):
16 |         super().__init__(args, params)
17 |         self._optimizer = torch.optim.Adadelta(params, **self.optimizer_config)
18 | 
19 |     @staticmethod
20 |     def add_args(parser):
21 |         """Add optimizer-specific arguments to the parser."""
22 |         parser.add_argument('--adadelta-rho', type=float, default=0.9, metavar='RHO',
23 |                             help='coefficient used for computing a running average of squared gradients')
24 |         parser.add_argument('--adadelta-eps', type=float, default=1e-6, metavar='EPS',
25 |                             help='term added to the denominator to improve numerical stability')
26 | 
27 |     @property
28 |     def optimizer_config(self):
29 |         """
30 |         Return a kwarg dictionary that will be used to override optimizer
31 |         args stored in checkpoints. This allows us to load a checkpoint and
32 |         resume training using a different set of optimizer args, e.g., with a
33 |         different learning rate.
34 |         """
35 |         return {
36 |             'lr': self.args.lr[0],
37 |             'rho': self.args.adadelta_rho,
38 |             'eps': self.args.adadelta_eps,
39 |             'weight_decay': self.args.weight_decay,
40 |         }
41 | 
-------------------------------------------------------------------------------- /fairseq/optim/adagrad.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.optim 9 | 10 | from . import FairseqOptimizer, register_optimizer 11 | 12 | 13 | @register_optimizer('adagrad') 14 | class Adagrad(FairseqOptimizer): 15 | def __init__(self, args, params): 16 | super().__init__(args, params) 17 | self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config) 18 | 19 | @property 20 | def optimizer_config(self): 21 | """ 22 | Return a kwarg dictionary that will be used to override optimizer 23 | args stored in checkpoints. This allows us to load a checkpoint and 24 | resume training using a different set of optimizer args, e.g., with a 25 | different learning rate. 26 | """ 27 | return { 28 | 'lr': self.args.lr[0], 29 | 'weight_decay': self.args.weight_decay, 30 | } 31 | -------------------------------------------------------------------------------- /fairseq/optim/fairseq_optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import math 9 | 10 | import torch 11 | 12 | 13 | class FairseqOptimizer(object): 14 | 15 | def __init__(self, args, params): 16 | super().__init__() 17 | self.args = args 18 | self.params = list(params) 19 | 20 | @staticmethod 21 | def add_args(parser): 22 | """Add optimizer-specific arguments to the parser.""" 23 | pass 24 | 25 | @property 26 | def optimizer(self): 27 | """Return a torch.optim.optimizer.Optimizer instance.""" 28 | if not hasattr(self, '_optimizer'): 29 | raise NotImplementedError 30 | if not isinstance(self._optimizer, torch.optim.Optimizer): 31 | raise ValueError('_optimizer must be an instance of torch.optim.Optimizer') 32 | return self._optimizer 33 | 34 | @property 35 | def optimizer_config(self): 36 | """ 37 | Return a kwarg dictionary that will be used to override optimizer 38 | args stored in checkpoints. This allows us to load a checkpoint and 39 | resume training using a different set of optimizer args, e.g., with a 40 | different learning rate. 41 | """ 42 | raise NotImplementedError 43 | 44 | def get_lr(self): 45 | """Return the current learning rate.""" 46 | return self.optimizer.param_groups[0]['lr'] 47 | 48 | def set_lr(self, lr): 49 | """Set the learning rate.""" 50 | for param_group in self.optimizer.param_groups: 51 | param_group['lr'] = lr 52 | 53 | def state_dict(self): 54 | """Return the optimizer's state dict.""" 55 | return self.optimizer.state_dict() 56 | 57 | def load_state_dict(self, state_dict, optimizer_overrides=None): 58 | """Load an optimizer state dict. 59 | 60 | In general we should prefer the configuration of the existing optimizer 61 | instance (e.g., learning rate) over that found in the state_dict. This 62 | allows us to resume training from a checkpoint using a new set of 63 | optimizer args. 64 | """ 65 | self.optimizer.load_state_dict(state_dict) 66 | 67 | if optimizer_overrides is not None and len(optimizer_overrides) > 0: 68 | # override learning rate, momentum, etc. 
with latest values 69 | for group in self.optimizer.param_groups: 70 | group.update(optimizer_overrides) 71 | 72 | def backward(self, loss): 73 | """Computes the sum of gradients of the given tensor w.r.t. graph leaves.""" 74 | loss.backward() 75 | 76 | def multiply_grads(self, c): 77 | """Multiplies grads by a constant *c*.""" 78 | for p in self.params: 79 | if p.grad is not None: 80 | p.grad.data.mul_(c) 81 | 82 | def clip_grad_norm(self, max_norm): 83 | """Clips gradient norm.""" 84 | if max_norm > 0: 85 | return torch.nn.utils.clip_grad_norm_(self.params, max_norm) 86 | else: 87 | return math.sqrt(sum(p.grad.data.norm()**2 for p in self.params if p.grad is not None)) 88 | 89 | def step(self, closure=None): 90 | """Performs a single optimization step.""" 91 | self.optimizer.step(closure) 92 | 93 | def zero_grad(self): 94 | """Clears the gradients of all optimized parameters.""" 95 | for group in self.optimizer.param_groups: 96 | for p in group['params']: 97 | p.grad = None 98 | self.optimizer.zero_grad() 99 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import importlib 9 | import os 10 | 11 | from .fairseq_lr_scheduler import FairseqLRScheduler 12 | 13 | 14 | LR_SCHEDULER_REGISTRY = {} 15 | 16 | 17 | def build_lr_scheduler(args, optimizer): 18 | return LR_SCHEDULER_REGISTRY[args.lr_scheduler](args, optimizer) 19 | 20 | 21 | def register_lr_scheduler(name): 22 | """Decorator to register a new LR scheduler.""" 23 | 24 | def register_lr_scheduler_cls(cls): 25 | if name in LR_SCHEDULER_REGISTRY: 26 | raise ValueError('Cannot register duplicate LR scheduler ({})'.format(name)) 27 | if not issubclass(cls, FairseqLRScheduler): 28 | raise ValueError('LR Scheduler ({}: {}) must extend FairseqLRScheduler'.format(name, cls.__name__)) 29 | LR_SCHEDULER_REGISTRY[name] = cls 30 | return cls 31 | 32 | return register_lr_scheduler_cls 33 | 34 | 35 | # automatically import any Python files in the optim/lr_scheduler/ directory 36 | for file in os.listdir(os.path.dirname(__file__)): 37 | if file.endswith('.py') and not file.startswith('_'): 38 | module = file[:file.find('.py')] 39 | importlib.import_module('fairseq.optim.lr_scheduler.' + module) 40 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from .. 
import FairseqOptimizer 9 | 10 | 11 | class FairseqLRScheduler(object): 12 | 13 | def __init__(self, args, optimizer): 14 | super().__init__() 15 | if not isinstance(optimizer, FairseqOptimizer): 16 | raise ValueError('optimizer must be an instance of FairseqOptimizer') 17 | self.args = args 18 | self.optimizer = optimizer 19 | self.best = None 20 | 21 | @staticmethod 22 | def add_args(parser): 23 | """Add arguments to the parser for this LR scheduler.""" 24 | pass 25 | 26 | def state_dict(self): 27 | """Return the LR scheduler state dict.""" 28 | return {'best': self.best} 29 | 30 | def load_state_dict(self, state_dict): 31 | """Load an LR scheduler state dict.""" 32 | self.best = state_dict['best'] 33 | 34 | def step(self, epoch, val_loss=None): 35 | """Update the learning rate at the end of the given epoch.""" 36 | if val_loss is not None: 37 | if self.best is None: 38 | self.best = val_loss 39 | else: 40 | self.best = min(self.best, val_loss) 41 | 42 | def step_update(self, num_updates): 43 | """Update the learning rate after each update.""" 44 | return self.optimizer.get_lr() 45 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/fixed_schedule.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from . import FairseqLRScheduler, register_lr_scheduler 9 | 10 | 11 | @register_lr_scheduler('fixed') 12 | class FixedSchedule(FairseqLRScheduler): 13 | """Decay the LR on a fixed schedule.""" 14 | 15 | def __init__(self, args, optimizer): 16 | super().__init__(args, optimizer) 17 | 18 | # set defaults 19 | args.warmup_updates = getattr(args, 'warmup_updates', 0) or 0 20 | 21 | self.lr = args.lr[0] 22 | if args.warmup_updates > 0: 23 | self.warmup_factor = 1. 
/ args.warmup_updates
24 |         else:
25 |             self.warmup_factor = 1
26 | 
27 |     @staticmethod
28 |     def add_args(parser):
29 |         """Add arguments to the parser for this LR scheduler."""
30 |         # fmt: off
31 |         parser.add_argument('--force-anneal', '--fa', type=int, metavar='N',
32 |                             help='force annealing at specified epoch')
33 |         parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
34 |                             help='warmup the learning rate linearly for the first N updates')
35 |         # fmt: on
36 | 
37 |     def get_next_lr(self, epoch):
38 |         lrs = self.args.lr
39 |         if self.args.force_anneal is None or epoch < self.args.force_anneal:
40 |             # use fixed LR schedule
41 |             next_lr = lrs[min(epoch, len(lrs) - 1)]
42 |         else:
43 |             # anneal based on lr_shrink
44 |             next_lr = lrs[-1] * self.args.lr_shrink ** (epoch + 1 - self.args.force_anneal)
45 |         return next_lr
46 | 
47 |     def step(self, epoch, val_loss=None):
48 |         """Update the learning rate at the end of the given epoch."""
49 |         super().step(epoch, val_loss)
50 |         self.lr = self.get_next_lr(epoch)
51 |         self.optimizer.set_lr(self.warmup_factor * self.lr)
52 |         return self.optimizer.get_lr()
53 | 
54 |     def step_update(self, num_updates):
55 |         """Update the learning rate after each update."""
56 |         if self.args.warmup_updates > 0 and num_updates <= self.args.warmup_updates:
57 |             self.warmup_factor = num_updates / float(self.args.warmup_updates)
58 |             self.optimizer.set_lr(self.warmup_factor * self.lr)
59 |         return self.optimizer.get_lr()
60 | 
-------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/inverse_square_root_schedule.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the LICENSE file in
5 | # the root directory of this source tree. An additional grant of patent rights
6 | # can be found in the PATENTS file in the same directory.
7 | 
8 | from . import FairseqLRScheduler, register_lr_scheduler
9 | 
10 | 
11 | @register_lr_scheduler('inverse_sqrt')
12 | class InverseSquareRootSchedule(FairseqLRScheduler):
13 |     """Decay the LR based on the inverse square root of the update number.
14 | 
15 |     We also support a warmup phase where we linearly increase the learning rate
16 |     from some initial learning rate (``--warmup-init-lr``) until the configured
17 |     learning rate (``--lr``). Thereafter we decay proportional to the number of
18 |     updates, with a decay factor set to align with the configured learning rate.
19 | 
20 |     During warmup::
21 | 
22 |       lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
23 |       lr = lrs[update_num]
24 | 
25 |     After warmup::
26 | 
27 |       decay_factor = args.lr * sqrt(args.warmup_updates)
28 |       lr = decay_factor / sqrt(update_num)
29 |     """
30 | 
31 |     def __init__(self, args, optimizer):
32 |         super().__init__(args, optimizer)
33 |         if len(args.lr) > 1:
34 |             raise ValueError(
35 |                 'Cannot use a fixed learning rate schedule with inverse_sqrt.'
36 |                 ' Consider --lr-scheduler=fixed instead.'
37 |             )
38 |         warmup_end_lr = args.lr[0]
39 |         if args.warmup_init_lr < 0:
40 |             args.warmup_init_lr = warmup_end_lr
41 | 
42 |         # linearly warmup for the first args.warmup_updates
43 |         self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates
44 | 
45 |         # then, decay prop.
to the inverse square root of the update number 46 | self.decay_factor = warmup_end_lr * args.warmup_updates**0.5 47 | 48 | # initial learning rate 49 | self.lr = args.warmup_init_lr 50 | self.optimizer.set_lr(self.lr) 51 | 52 | @staticmethod 53 | def add_args(parser): 54 | """Add arguments to the parser for this LR scheduler.""" 55 | # fmt: off 56 | parser.add_argument('--warmup-updates', default=4000, type=int, metavar='N', 57 | help='warmup the learning rate linearly for the first N updates') 58 | parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR', 59 | help='initial learning rate during warmup phase; default is args.lr') 60 | # fmt: on 61 | 62 | def step(self, epoch, val_loss=None): 63 | """Update the learning rate at the end of the given epoch.""" 64 | super().step(epoch, val_loss) 65 | # we don't change the learning rate at epoch boundaries 66 | return self.optimizer.get_lr() 67 | 68 | def step_update(self, num_updates): 69 | """Update the learning rate after each update.""" 70 | if num_updates < self.args.warmup_updates: 71 | self.lr = self.args.warmup_init_lr + num_updates*self.lr_step 72 | else: 73 | self.lr = self.decay_factor * num_updates**-0.5 74 | self.optimizer.set_lr(self.lr) 75 | return self.lr 76 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.optim.lr_scheduler 9 | 10 | from . import FairseqLRScheduler, register_lr_scheduler 11 | 12 | 13 | @register_lr_scheduler('reduce_lr_on_plateau') 14 | class ReduceLROnPlateau(FairseqLRScheduler): 15 | """Decay the LR by a factor every time the validation loss plateaus.""" 16 | 17 | def __init__(self, args, optimizer): 18 | super().__init__(args, optimizer) 19 | if len(args.lr) > 1: 20 | raise ValueError( 21 | 'Cannot use a fixed learning rate schedule with reduce_lr_on_plateau.' 22 | ' Consider --lr-scheduler=fixed instead.' 23 | ) 24 | self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 25 | self.optimizer.optimizer, patience=0, factor=args.lr_shrink) 26 | 27 | def state_dict(self): 28 | """Return the LR scheduler state dict.""" 29 | return { 30 | 'best': self.lr_scheduler.best, 31 | 'last_epoch': self.lr_scheduler.last_epoch, 32 | } 33 | 34 | def load_state_dict(self, state_dict): 35 | """Load an LR scheduler state dict.""" 36 | self.lr_scheduler.best = state_dict['best'] 37 | if 'last_epoch' in state_dict: 38 | self.lr_scheduler.last_epoch = state_dict['last_epoch'] 39 | 40 | def step(self, epoch, val_loss=None): 41 | """Update the learning rate at the end of the given epoch.""" 42 | if val_loss is not None: 43 | self.lr_scheduler.step(val_loss, epoch) 44 | else: 45 | self.lr_scheduler.last_epoch = epoch 46 | return self.optimizer.get_lr() 47 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/triangular_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 
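# Worked example for the schedule below (numbers illustrative): with --lr 1e-4,
# --max-lr 1e-3 and --lr-period-updates 5000, stepsize = 2500; at
# num_updates = 1250, x = |1250/2500 - 2 + 1| = 0.5, so
# lr = 1e-4 + (1e-3 - 1e-4) * (1 - 0.5) = 5.5e-4, halfway up the first
# triangle.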
3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import math 9 | 10 | from . import FairseqLRScheduler, register_lr_scheduler 11 | 12 | 13 | @register_lr_scheduler('triangular') 14 | class TriangularSchedule(FairseqLRScheduler): 15 | """Assign LR based on a triangular cyclical schedule. 16 | 17 | See https://arxiv.org/pdf/1506.01186.pdf for details. 18 | """ 19 | 20 | def __init__(self, args, optimizer): 21 | super().__init__(args, optimizer) 22 | if len(args.lr) > 1: 23 | raise ValueError( 24 | 'Cannot use a fixed learning rate schedule with triangular.' 25 | ' Consider --lr-scheduler=fixed instead.' 26 | ) 27 | 28 | lr = args.lr[0] 29 | 30 | assert args.max_lr > lr, 'max_lr must be more than lr' 31 | self.min_lr = lr 32 | self.max_lr = args.max_lr 33 | self.stepsize = args.lr_period_updates // 2 34 | self.lr_shrink = args.lr_shrink 35 | self.shrink_min = args.shrink_min 36 | 37 | # initial learning rate 38 | self.lr = self.min_lr 39 | self.optimizer.set_lr(self.lr) 40 | 41 | @staticmethod 42 | def add_args(parser): 43 | """Add arguments to the parser for this LR scheduler.""" 44 | # fmt: off 45 | parser.add_argument('--max-lr', required=True, type=float, metavar='LR', 46 | help='max learning rate, must be more than args.lr') 47 | parser.add_argument('--lr-period-updates', default=5000, type=float, metavar='LR', 48 | help='initial number of updates per period (cycle length)') 49 | parser.add_argument('--shrink-min', action='store_true', 50 | help='if set, also shrinks min lr') 51 | # fmt: on 52 | 53 | def step(self, epoch, val_loss=None): 54 | """Update the learning rate at the end of the given epoch.""" 55 | super().step(epoch, val_loss) 56 | # we don't change the learning rate at epoch boundaries 57 | return self.optimizer.get_lr() 58 | 59 | def step_update(self, num_updates): 60 | """Update the learning rate after each update.""" 61 | cycle = math.floor(num_updates / (2 * self.stepsize)) 62 | 63 | lr_shrink = self.lr_shrink ** cycle 64 | max_lr = self.max_lr * lr_shrink 65 | if self.shrink_min: 66 | min_lr = self.min_lr * lr_shrink 67 | else: 68 | min_lr = self.min_lr 69 | 70 | x = abs(num_updates / self.stepsize - 2 * (cycle + 1) + 1) 71 | self.lr = min_lr + (max_lr - min_lr) * max(0, (1 - x)) 72 | 73 | self.optimizer.set_lr(self.lr) 74 | return self.lr 75 | -------------------------------------------------------------------------------- /fairseq/optim/nag.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from torch.optim.optimizer import Optimizer, required 9 | 10 | from . import FairseqOptimizer, register_optimizer 11 | 12 | 13 | @register_optimizer('nag') 14 | class FairseqNAG(FairseqOptimizer): 15 | def __init__(self, args, params): 16 | super().__init__(args, params) 17 | self._optimizer = NAG(params, **self.optimizer_config) 18 | 19 | @property 20 | def optimizer_config(self): 21 | """ 22 | Return a kwarg dictionary that will be used to override optimizer 23 | args stored in checkpoints. 
This allows us to load a checkpoint and 24 | resume training using a different set of optimizer args, e.g., with a 25 | different learning rate. 26 | """ 27 | return { 28 | 'lr': self.args.lr[0], 29 | 'momentum': self.args.momentum, 30 | 'weight_decay': self.args.weight_decay, 31 | } 32 | 33 | 34 | class NAG(Optimizer): 35 | def __init__(self, params, lr=required, momentum=0, weight_decay=0): 36 | defaults = dict(lr=lr, lr_old=lr, momentum=momentum, weight_decay=weight_decay) 37 | super(NAG, self).__init__(params, defaults) 38 | 39 | def step(self, closure=None): 40 | """Performs a single optimization step. 41 | 42 | Arguments: 43 | closure (callable, optional): A closure that reevaluates the model 44 | and returns the loss. 45 | """ 46 | loss = None 47 | if closure is not None: 48 | loss = closure() 49 | 50 | for group in self.param_groups: 51 | weight_decay = group['weight_decay'] 52 | momentum = group['momentum'] 53 | lr = group['lr'] 54 | lr_old = group.get('lr_old', lr) 55 | lr_correct = lr / lr_old 56 | 57 | for p in group['params']: 58 | if p.grad is None: 59 | continue 60 | 61 | d_p = p.grad.data 62 | param_state = self.state[p] 63 | if 'momentum_buffer' not in param_state: 64 | param_state['momentum_buffer'] = d_p.clone().zero_() 65 | 66 | buf = param_state['momentum_buffer'] 67 | 68 | if weight_decay != 0: 69 | p.data.mul_(1 - lr * weight_decay) 70 | p.data.add_(momentum * momentum * lr_correct, buf) 71 | p.data.add_(-(1 + momentum) * lr, d_p) 72 | 73 | buf.mul_(momentum * lr_correct).add_(-lr, d_p) 74 | 75 | group['lr_old'] = lr 76 | 77 | return loss 78 | -------------------------------------------------------------------------------- /fairseq/optim/sgd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.optim 9 | 10 | from . import FairseqOptimizer, register_optimizer 11 | 12 | 13 | @register_optimizer('sgd') 14 | class SGD(FairseqOptimizer): 15 | def __init__(self, args, params): 16 | super().__init__(args, params) 17 | self._optimizer = torch.optim.SGD(params, **self.optimizer_config) 18 | 19 | @property 20 | def optimizer_config(self): 21 | """ 22 | Return a kwarg dictionary that will be used to override optimizer 23 | args stored in checkpoints. This allows us to load a checkpoint and 24 | resume training using a different set of optimizer args, e.g., with a 25 | different learning rate. 26 | """ 27 | return { 28 | 'lr': self.args.lr[0], 29 | 'momentum': self.args.momentum, 30 | 'weight_decay': self.args.weight_decay, 31 | } 32 | -------------------------------------------------------------------------------- /fairseq/pdb.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
7 | 
8 | import multiprocessing
9 | import os
10 | import pdb
11 | import sys
12 | 
13 | 
14 | __all__ = ['set_trace']
15 | 
16 | 
17 | _stdin = [None]
18 | _stdin_lock = multiprocessing.Lock()
19 | try:
20 |     _stdin_fd = sys.stdin.fileno()
21 | except Exception:
22 |     _stdin_fd = None
23 | 
24 | 
25 | class MultiprocessingPdb(pdb.Pdb):
26 |     """A Pdb wrapper that works in a multiprocessing environment.
27 | 
28 |     Usage: `from fairseq import pdb; pdb.set_trace()`
29 |     """
30 | 
31 |     def __init__(self):
32 |         pdb.Pdb.__init__(self, nosigint=True)
33 | 
34 |     def _cmdloop(self):
35 |         stdin_bak = sys.stdin
36 |         with _stdin_lock:
37 |             try:
38 |                 if _stdin_fd is not None:
39 |                     if not _stdin[0]:
40 |                         _stdin[0] = os.fdopen(_stdin_fd)
41 |                     sys.stdin = _stdin[0]
42 |                 self.cmdloop()
43 |             finally:
44 |                 sys.stdin = stdin_bak
45 | 
46 | 
47 | def set_trace():
48 |     pdb = MultiprocessingPdb()
49 |     pdb.set_trace(sys._getframe().f_back)
50 | 
-------------------------------------------------------------------------------- /fairseq/tasks/__init__.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the LICENSE file in
5 | # the root directory of this source tree. An additional grant of patent rights
6 | # can be found in the PATENTS file in the same directory.
7 | 
8 | import argparse
9 | import importlib
10 | import os
11 | 
12 | from .fairseq_task import FairseqTask
13 | 
14 | TASK_REGISTRY = {}
15 | TASK_CLASS_NAMES = set()
16 | 
17 | 
18 | def setup_task(args):
19 |     return TASK_REGISTRY[args.task].setup_task(args)
20 | 
21 | 
22 | def register_task(name):
23 |     """
24 |     New tasks can be added to fairseq with the
25 |     :func:`~fairseq.tasks.register_task` function decorator.
26 | 
27 |     For example::
28 | 
29 |         @register_task('classification')
30 |         class ClassificationTask(FairseqTask):
31 |             (...)
32 | 
33 |     .. note::
34 | 
35 |         All Tasks must implement the :class:`~fairseq.tasks.FairseqTask`
36 |         interface.
37 | 
38 | 
39 | 
40 |     Args:
41 |         name (str): the name of the task
42 |     """
43 | 
44 |     def register_task_cls(cls):
45 |         if name in TASK_REGISTRY:
46 |             raise ValueError('Cannot register duplicate task ({})'.format(name))
47 |         if not issubclass(cls, FairseqTask):
48 |             raise ValueError('Task ({}: {}) must extend FairseqTask'.format(name, cls.__name__))
49 |         if cls.__name__ in TASK_CLASS_NAMES:
50 |             raise ValueError('Cannot register task with duplicate class name ({})'.format(cls.__name__))
51 |         TASK_REGISTRY[name] = cls
52 |         TASK_CLASS_NAMES.add(cls.__name__)
53 |         return cls
54 | 
55 |     return register_task_cls
56 | 
57 | 
58 | # automatically import any Python files in the tasks/ directory
59 | for file in os.listdir(os.path.dirname(__file__)):
60 |     if file.endswith('.py') and not file.startswith('_'):
61 |         task_name = file[:file.find('.py')]
62 |         importlib.import_module('fairseq.tasks.'
+ task_name) 63 | 64 | # expose `task_parser` for sphinx 65 | if task_name in TASK_REGISTRY: 66 | parser = argparse.ArgumentParser(add_help=False) 67 | group_task = parser.add_argument_group('Task name') 68 | # fmt: off 69 | group_task.add_argument('--task', metavar=task_name, 70 | help='Enable this task with: ``--task=' + task_name + '``') 71 | # fmt: on 72 | group_args = parser.add_argument_group('Additional command-line arguments') 73 | TASK_REGISTRY[task_name].add_args(group_args) 74 | globals()[task_name + '_parser'] = parser 75 | 76 | 77 | def get_task(name): 78 | return TASK_REGISTRY[name] 79 | -------------------------------------------------------------------------------- /fairseq/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import re 9 | 10 | SPACE_NORMALIZER = re.compile(r"\s+") 11 | 12 | 13 | def tokenize_line(line): 14 | line = SPACE_NORMALIZER.sub(" ", line) 15 | line = line.strip() 16 | return line.split() 17 | -------------------------------------------------------------------------------- /fairseq_cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libeineu/SDT-Training/d33a836c1b3258748ec11c4d64998e0dcc9792df/fairseq_cli/__init__.py -------------------------------------------------------------------------------- /fairseq_cli/eval_lm.py: -------------------------------------------------------------------------------- 1 | ../eval_lm.py -------------------------------------------------------------------------------- /fairseq_cli/generate.py: -------------------------------------------------------------------------------- 1 | ../generate.py -------------------------------------------------------------------------------- /fairseq_cli/interactive.py: -------------------------------------------------------------------------------- 1 | ../interactive.py -------------------------------------------------------------------------------- /fairseq_cli/preprocess.py: -------------------------------------------------------------------------------- 1 | ../preprocess.py -------------------------------------------------------------------------------- /fairseq_cli/score.py: -------------------------------------------------------------------------------- 1 | ../score.py -------------------------------------------------------------------------------- /fairseq_cli/setup.py: -------------------------------------------------------------------------------- 1 | ../setup.py -------------------------------------------------------------------------------- /fairseq_cli/train.py: -------------------------------------------------------------------------------- 1 | ../train.py -------------------------------------------------------------------------------- /fairseq_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libeineu/SDT-Training/d33a836c1b3258748ec11c4d64998e0dcc9792df/fairseq_logo.png -------------------------------------------------------------------------------- /preprocess.sh: -------------------------------------------------------------------------------- 1 | src=zh 2 | tgt=en 3 | 
TEXT=../LDC 4 | tag=conv 5 | output=data-bin/$tag 6 | srcdict=$TEXT/dict.$src.txt 7 | tgtdict=$TEXT/dict.$tgt.txt # defined but not passed below; add --srcdict/--tgtdict to the command to reuse existing vocabularies 8 | 9 | python3 preprocess.py --source-lang $src --target-lang $tgt --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test,$TEXT/test1,$TEXT/test2 --destdir $output --workers 32 10 | -------------------------------------------------------------------------------- /rerank.py: -------------------------------------------------------------------------------- 1 | """Restore the original corpus order of 'index<TAB>text' lines (e.g. fairseq-generate hypotheses).""" 2 | import sys 3 | 4 | with open(sys.argv[1], 'r', encoding="utf-8") as fr, \ 5 | open(sys.argv[2], 'w', encoding="utf-8") as fw: 6 | # read "index<TAB>text" pairs, keyed by the integer sentence index 7 | hyps = {} 8 | for line in fr: 9 | fields = line.rstrip('\n').split('\t') 10 | hyps[int(fields[0])] = fields[1] 11 | # write the texts back out, sorted by sentence index 12 | for _, text in sorted(hyps.items(), key=lambda x: x[0]): 13 | fw.write(text + '\n') 14 | -------------------------------------------------------------------------------- /score.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 8 | """ 9 | BLEU scoring of generated translations against reference translations. 10 | """ 11 | 12 | import argparse 13 | import os 14 | import sys 15 | 16 | from fairseq import bleu 17 | from fairseq.data import dictionary 18 | 19 | 20 | def get_parser(): 21 | parser = argparse.ArgumentParser(description='Command-line script for BLEU scoring.') 22 | # fmt: off 23 | parser.add_argument('-s', '--sys', default='-', help='system output') 24 | parser.add_argument('-r', '--ref', required=True, help='references') 25 | parser.add_argument('-o', '--order', default=4, metavar='N', 26 | type=int, help='consider ngrams up to this order') 27 | parser.add_argument('--ignore-case', action='store_true', 28 | help='case-insensitive scoring') 29 | parser.add_argument('--sacrebleu', action='store_true', 30 | help='score with sacrebleu') 31 | # fmt: on 32 | return parser 33 | 34 | 35 | def main(): 36 | parser = get_parser() 37 | args = parser.parse_args() 38 | print(args) 39 | 40 | assert args.sys == '-' or os.path.exists(args.sys), \ 41 | "System output file {} does not exist".format(args.sys) 42 | assert os.path.exists(args.ref), \ 43 | "Reference file {} does not exist".format(args.ref) 44 | 45 | dict = dictionary.Dictionary() 46 | 47 | def readlines(fd): 48 | for line in fd.readlines(): 49 | if args.ignore_case: 50 | yield line.lower() 51 | else: 52 | yield line 53 | 54 | if args.sacrebleu: 55 | import sacrebleu 56 | 57 | def score(fdsys): 58 | with open(args.ref) as fdref: 59 | print(sacrebleu.corpus_bleu(fdsys, [fdref])) 60 | else: 61 | def score(fdsys): 62 | with open(args.ref) as fdref: 63 | scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk()) 64 | for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)): 65 | sys_tok = dict.encode_line(sys_tok) 66 | ref_tok = dict.encode_line(ref_tok) 67 | scorer.add(ref_tok, sys_tok) 68 | print(scorer.result_string(args.order)) 69 | 70 | if args.sys == '-': 71 | score(sys.stdin) 72 | else: 73 | with open(args.sys, 'r') as f: 74 | score(f) 75 | 76 | 77 | if __name__ == '__main__': 78 | main() 79 |
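A minimal end-to-end sketch tying the two helper scripts above together (gen.out, hyps.tsv, hyps.sorted and ref.txt are hypothetical file names, not files shipped with the repository): fairseq-generate prints hypotheses in batch order as "H-<index><TAB><score><TAB><text>", rerank.py restores the original corpus order, and score.py computes BLEU against a tokenized reference.

    grep ^H gen.out | sed 's/^H-//' | cut -f1,3 > hyps.tsv   # keep "index<TAB>text"
    python3 rerank.py hyps.tsv hyps.sorted
    python3 score.py --sys hyps.sorted --ref ref.txt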
-------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libeineu/SDT-Training/d33a836c1b3258748ec11c4d64998e0dcc9792df/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/build_sym_alignment.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | # 8 | 9 | """ 10 | Use this script in order to build symmetric alignments for your translation 11 | dataset. 12 | This script depends on fast_align and mosesdecoder tools. You will need to 13 | build those before running the script. 14 | fast_align: 15 | github: http://github.com/clab/fast_align 16 | instructions: follow the instructions in README.md 17 | mosesdecoder: 18 | github: http://github.com/moses-smt/mosesdecoder 19 | instructions: http://www.statmt.org/moses/?n=Development.GetStarted 20 | The script produces the following files under --output_dir: 21 | text.joined - concatenation of lines from the source_file and the 22 | target_file. 23 | align.forward - forward pass of fast_align. 24 | align.backward - backward pass of fast_align. 25 | aligned.sym_heuristic - symmetrized alignment. 26 | """ 27 | 28 | import argparse 29 | import os 30 | from itertools import zip_longest 31 | 32 | 33 | def main(): 34 | parser = argparse.ArgumentParser(description='symmetric alignment builder') 35 | # fmt: off 36 | parser.add_argument('--fast_align_dir', 37 | help='path to fast_align build directory') 38 | parser.add_argument('--mosesdecoder_dir', 39 | help='path to mosesdecoder root directory') 40 | parser.add_argument('--sym_heuristic', 41 | help='heuristic to use for symmetrization', 42 | default='grow-diag-final-and') 43 | parser.add_argument('--source_file', 44 | help='path to a file with sentences ' 45 | 'in the source language') 46 | parser.add_argument('--target_file', 47 | help='path to a file with sentences ' 48 | 'in the target language') 49 | parser.add_argument('--output_dir', 50 | help='output directory') 51 | # fmt: on 52 | args = parser.parse_args() 53 | 54 | fast_align_bin = os.path.join(args.fast_align_dir, 'fast_align') 55 | symal_bin = os.path.join(args.mosesdecoder_dir, 'bin', 'symal') 56 | sym_fast_align_bin = os.path.join( 57 | args.mosesdecoder_dir, 'scripts', 'ems', 58 | 'support', 'symmetrize-fast-align.perl') 59 | 60 | # create joined file 61 | joined_file = os.path.join(args.output_dir, 'text.joined') 62 | with open(args.source_file, 'r', encoding='utf-8') as src, open(args.target_file, 'r', encoding='utf-8') as tgt: 63 | with open(joined_file, 'w', encoding='utf-8') as joined: 64 | for s, t in zip_longest(src, tgt): 65 | print('{} ||| {}'.format(s.strip(), t.strip()), file=joined) 66 | 67 | bwd_align_file = os.path.join(args.output_dir, 'align.backward') 68 | 69 | # run forward alignment 70 | fwd_align_file = os.path.join(args.output_dir, 'align.forward') 71 | fwd_fast_align_cmd = '{FASTALIGN} -i {JOINED} -d -o -v > {FWD}'.format( 72 | FASTALIGN=fast_align_bin, 73 | JOINED=joined_file, 74 | FWD=fwd_align_file) 75 | assert os.system(fwd_fast_align_cmd) == 0 76 | 77 | # run
backward alignment 78 | bwd_align_file = os.path.join(args.output_dir, 'align.backward') 79 | bwd_fast_align_cmd = '{FASTALIGN} -i {JOINED} -d -o -v -r > {BWD}'.format( 80 | FASTALIGN=fast_align_bin, 81 | JOINED=joined_file, 82 | BWD=bwd_align_file) 83 | assert os.system(bwd_fast_align_cmd) == 0 84 | 85 | # run symmetrization 86 | sym_out_file = os.path.join(args.output_dir, 'aligned') 87 | sym_cmd = '{SYMFASTALIGN} {FWD} {BWD} {SRC} {TGT} {OUT} {HEURISTIC} {SYMAL}'.format( 88 | SYMFASTALIGN=sym_fast_align_bin, 89 | FWD=fwd_align_file, 90 | BWD=bwd_align_file, 91 | SRC=args.source_file, 92 | TGT=args.target_file, 93 | OUT=sym_out_file, 94 | HEURISTIC=args.sym_heuristic, 95 | SYMAL=symal_bin 96 | ) 97 | assert os.system(sym_cmd) == 0 98 | 99 | 100 | if __name__ == '__main__': 101 | main() 102 | -------------------------------------------------------------------------------- /scripts/compound_split_bleu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ $# -ne 1 ]; then 4 | echo "usage: $0 GENERATE_PY_OUTPUT" 5 | exit 1 6 | fi 7 | 8 | GEN=$1 9 | 10 | SYS=$GEN.sys 11 | REF=$GEN.ref 12 | 13 | if [ $(tail -n 1 $GEN | grep BLEU | wc -l) -ne 1 ]; then 14 | echo "not done generating" 15 | exit 16 | fi 17 | 18 | grep ^H $GEN | cut -f3- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $SYS 19 | grep ^T $GEN | cut -f2- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $REF 20 | fairseq-score --sys $SYS --ref $REF 21 | -------------------------------------------------------------------------------- /scripts/convert_dictionary.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | -- Usage: convert_dictionary.lua <dict.th7> 9 | require 'fairseq' 10 | require 'torch' 11 | require 'paths' 12 | 13 | if #arg < 1 then 14 | print('usage: convert_dictionary.lua <dict.th7>') 15 | os.exit(1) 16 | end 17 | if not paths.filep(arg[1]) then 18 | print('error: file does not exist: ' .. arg[1]) 19 | os.exit(1) 20 | end 21 | 22 | dict = torch.load(arg[1]) 23 | dst = paths.basename(arg[1]):gsub('.th7', '.txt') 24 | assert(dst:match('.txt$')) 25 | 26 | f = io.open(dst, 'w') 27 | for idx, symbol in ipairs(dict.index_to_symbol) do 28 | if idx > dict.cutoff then 29 | break 30 | end 31 | f:write(symbol) 32 | f:write(' ') 33 | f:write(dict.index_to_freq[idx]) 34 | f:write('\n') 35 | end 36 | f:close() 37 | -------------------------------------------------------------------------------- /scripts/convert_model.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory.
7 | -- 8 | -- Usage: convert_model.lua <model_epoch1.th7> 9 | require 'torch' 10 | local fairseq = require 'fairseq' 11 | 12 | model = torch.load(arg[1]) 13 | 14 | function find_weight_norm(container, module) 15 | for _, wn in ipairs(container:listModules()) do 16 | if torch.type(wn) == 'nn.WeightNorm' and wn.modules[1] == module then 17 | return wn 18 | end 19 | end 20 | end 21 | 22 | function push_state(dict, key, module) 23 | if torch.type(module) == 'nn.Linear' then 24 | local wn = find_weight_norm(model.module, module) 25 | assert(wn) 26 | dict[key .. '.weight_v'] = wn.v:float() 27 | dict[key .. '.weight_g'] = wn.g:float() 28 | elseif torch.type(module) == 'nn.TemporalConvolutionTBC' then 29 | local wn = find_weight_norm(model.module, module) 30 | assert(wn) 31 | local v = wn.v:float():view(wn.viewOut):transpose(2, 3) 32 | dict[key .. '.weight_v'] = v 33 | dict[key .. '.weight_g'] = wn.g:float():view(module.weight:size(3), 1, 1) 34 | else 35 | dict[key .. '.weight'] = module.weight:float() 36 | end 37 | if module.bias then 38 | dict[key .. '.bias'] = module.bias:float() 39 | end 40 | end 41 | 42 | encoder_dict = {} 43 | decoder_dict = {} 44 | combined_dict = {} 45 | 46 | function encoder_state(encoder) 47 | luts = encoder:findModules('nn.LookupTable') 48 | push_state(encoder_dict, 'embed_tokens', luts[1]) 49 | push_state(encoder_dict, 'embed_positions', luts[2]) 50 | 51 | fcs = encoder:findModules('nn.Linear') 52 | assert(#fcs >= 2) 53 | local nInputPlane = fcs[1].weight:size(1) 54 | push_state(encoder_dict, 'fc1', table.remove(fcs, 1)) 55 | push_state(encoder_dict, 'fc2', table.remove(fcs, #fcs)) 56 | 57 | for i, module in ipairs(encoder:findModules('nn.TemporalConvolutionTBC')) do 58 | push_state(encoder_dict, 'convolutions.' .. tostring(i - 1), module) 59 | if nInputPlane ~= module.weight:size(3) / 2 then 60 | push_state(encoder_dict, 'projections.' .. tostring(i - 1), table.remove(fcs, 1)) 61 | end 62 | nInputPlane = module.weight:size(3) / 2 63 | end 64 | assert(#fcs == 0) 65 | end 66 | 67 | function decoder_state(decoder) 68 | luts = decoder:findModules('nn.LookupTable') 69 | push_state(decoder_dict, 'embed_tokens', luts[1]) 70 | push_state(decoder_dict, 'embed_positions', luts[2]) 71 | 72 | fcs = decoder:findModules('nn.Linear') 73 | local nInputPlane = fcs[1].weight:size(1) 74 | push_state(decoder_dict, 'fc1', table.remove(fcs, 1)) 75 | push_state(decoder_dict, 'fc2', fcs[#fcs - 1]) 76 | push_state(decoder_dict, 'fc3', fcs[#fcs]) 77 | 78 | table.remove(fcs, #fcs) 79 | table.remove(fcs, #fcs) 80 | 81 | for i, module in ipairs(decoder:findModules('nn.TemporalConvolutionTBC')) do 82 | if nInputPlane ~= module.weight:size(3) / 2 then 83 | push_state(decoder_dict, 'projections.' .. tostring(i - 1), table.remove(fcs, 1)) 84 | end 85 | nInputPlane = module.weight:size(3) / 2 86 | 87 | local prefix = 'attention.' .. tostring(i - 1) 88 | push_state(decoder_dict, prefix .. '.in_projection', table.remove(fcs, 1)) 89 | push_state(decoder_dict, prefix .. '.out_projection', table.remove(fcs, 1)) 90 | push_state(decoder_dict, 'convolutions.' .. tostring(i - 1), module) 91 | end 92 | assert(#fcs == 0) 93 | end 94 | 95 | 96 | _encoder = model.module.modules[2] 97 | _decoder = model.module.modules[3] 98 | 99 | encoder_state(_encoder) 100 | decoder_state(_decoder) 101 | 102 | for k, v in pairs(encoder_dict) do 103 | combined_dict['encoder.' .. k] = v 104 | end 105 | for k, v in pairs(decoder_dict) do 106 | combined_dict['decoder.' .. 
k] = v 107 | end 108 | 109 | 110 | torch.save('state_dict.t7', combined_dict) 111 | -------------------------------------------------------------------------------- /scripts/read_binarized.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 8 | # 9 | 10 | import argparse 11 | 12 | from fairseq.data import dictionary 13 | from fairseq.data import IndexedDataset 14 | 15 | 16 | def get_parser(): 17 | parser = argparse.ArgumentParser( 18 | description='writes text from binarized file to stdout') 19 | # fmt: off 20 | parser.add_argument('--dict', metavar='FP', required=True, help='dictionary containing known words') 21 | parser.add_argument('--input', metavar='FP', required=True, help='binarized file to read') 22 | # fmt: on 23 | 24 | return parser 25 | 26 | 27 | def main(args): 28 | dict = dictionary.Dictionary.load(args.dict) 29 | ds = IndexedDataset(args.input, fix_lua_indexing=True) 30 | for tensor_line in ds: 31 | print(dict.string(tensor_line)) 32 | 33 | 34 | if __name__ == '__main__': 35 | parser = get_parser() 36 | args = parser.parse_args() 37 | main(args) 38 | -------------------------------------------------------------------------------- /scripts/sacrebleu_pregen.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ $# -ne 4 ]; then 4 | echo "usage: $0 TESTSET SRCLANG TGTLANG GEN" 5 | exit 1 6 | fi 7 | 8 | TESTSET=$1 9 | SRCLANG=$2 10 | TGTLANG=$3 11 | 12 | GEN=$4 13 | 14 | echo 'Cloning Moses github repository (for tokenization scripts)...' 15 | git clone https://github.com/moses-smt/mosesdecoder.git 16 | 17 | SCRIPTS=mosesdecoder/scripts 18 | DETOKENIZER=$SCRIPTS/tokenizer/detokenizer.perl 19 | 20 | grep ^H $GEN \ 21 | | sed 's/^H\-//' \ 22 | | sort -n -k 1 \ 23 | | cut -f 3 \ 24 | | perl $DETOKENIZER -l $TGTLANG \ 25 | | sed "s/ - /-/g" \ 26 | > $GEN.sorted.detok 27 | 28 | sacrebleu --test-set $TESTSET --language-pair "${SRCLANG}-${TGTLANG}" < $GEN.sorted.detok 29 | -------------------------------------------------------------------------------- /scripts/spm_decode.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | from __future__ import absolute_import, division, print_function, unicode_literals 9 | 10 | import argparse 11 | 12 | import sentencepiece as spm 13 | 14 | 15 | def main(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--model", required=True, 18 | help="sentencepiece model to use for decoding") 19 | parser.add_argument("--input", required=True, help="input file to decode") 20 | parser.add_argument("--input_format", choices=["piece", "id"], default="piece") 21 | args = parser.parse_args() 22 | 23 | sp = spm.SentencePieceProcessor() 24 | sp.Load(args.model) 25 | 26 | if args.input_format == "piece": 27 | def decode(l): 28 | return "".join(sp.DecodePieces(l)) 29 | elif args.input_format == "id": 30 | def decode(l): 31 | return "".join(sp.DecodeIds(l)) 32 | else: 33 | raise NotImplementedError 34 | 35 | def tok2int(tok): 36 | # remap reference-side <unk> (represented as <<unk>>) to 0 37 | return int(tok) if tok != "<<unk>>" else 0 38 | 39 | with open(args.input, "r", encoding="utf-8") as h: 40 | for line in h: 41 | print(decode(list(map(tok2int, line.rstrip().split())))) 42 | 43 | 44 | if __name__ == "__main__": 45 | main() 46 | -------------------------------------------------------------------------------- /scripts/spm_encode.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | from __future__ import absolute_import, division, print_function, unicode_literals 9 | 10 | import argparse 11 | import contextlib 12 | import sys 13 | 14 | import sentencepiece as spm 15 | 16 | 17 | def main(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--model", required=True, 20 | help="sentencepiece model to use for encoding") 21 | parser.add_argument("--inputs", nargs="+", default=['-'], 22 | help="input files to filter/encode") 23 | parser.add_argument("--outputs", nargs="+", default=['-'], 24 | help="path to save encoded outputs") 25 | parser.add_argument("--output_format", choices=["piece", "id"], default="piece") 26 | parser.add_argument("--min-len", type=int, metavar="N", 27 | help="filter sentence pairs with fewer than N tokens") 28 | parser.add_argument("--max-len", type=int, metavar="N", 29 | help="filter sentence pairs with more than N tokens") 30 | args = parser.parse_args() 31 | 32 | assert len(args.inputs) == len(args.outputs), \ 33 | "number of input and output paths should match" 34 | 35 | sp = spm.SentencePieceProcessor() 36 | sp.Load(args.model) 37 | 38 | if args.output_format == "piece": 39 | def encode(l): 40 | return sp.EncodeAsPieces(l) 41 | elif args.output_format == "id": 42 | def encode(l): 43 | return list(map(str, sp.EncodeAsIds(l))) 44 | else: 45 | raise NotImplementedError 46 | 47 | if args.min_len is not None or args.max_len is not None: 48 | def valid(line): 49 | return ( 50 | (args.min_len is None or len(line) >= args.min_len) 51 | and (args.max_len is None or len(line) <= args.max_len) 52 | ) 53 | else: 54 | def valid(lines): 55 | return True 56 | 57 | with contextlib.ExitStack() as stack: 58 | inputs = [ 59 | stack.enter_context(open(input, "r", encoding="utf-8")) \ 60 | if input != "-" else sys.stdin 61 | for input in args.inputs 62 | ] 63 | outputs = [ 64 | stack.enter_context(open(output, "w", encoding="utf-8")) \ 65 | if output != "-" else sys.stdout 
66 | for output in args.outputs 67 | ] 68 | 69 | stats = { 70 | "num_empty": 0, 71 | "num_filtered": 0, 72 | } 73 | 74 | def encode_line(line): 75 | line = line.strip() 76 | if len(line) > 0: 77 | line = encode(line) 78 | if valid(line): 79 | return line 80 | else: 81 | stats["num_filtered"] += 1 82 | else: 83 | stats["num_empty"] += 1 84 | return None 85 | 86 | for i, lines in enumerate(zip(*inputs), start=1): 87 | enc_lines = list(map(encode_line, lines)) 88 | if not any(enc_line is None for enc_line in enc_lines): 89 | for enc_line, output_h in zip(enc_lines, outputs): 90 | print(" ".join(enc_line), file=output_h) 91 | if i % 10000 == 0: 92 | print("processed {} lines".format(i), file=sys.stderr) 93 | 94 | print("skipped {} empty lines".format(stats["num_empty"]), file=sys.stderr) 95 | print("filtered {} lines".format(stats["num_filtered"]), file=sys.stderr) 96 | 97 | 98 | if __name__ == "__main__": 99 | main() 100 | -------------------------------------------------------------------------------- /scripts/spm_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | from __future__ import absolute_import, division, print_function, unicode_literals 9 | 10 | import shlex 11 | import sys 12 | 13 | import sentencepiece as spm 14 | 15 | 16 | if __name__ == "__main__": 17 | spm.SentencePieceTrainer.Train(" ".join(map(shlex.quote, sys.argv[1:]))) 18 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 
8 | 9 | from setuptools import setup, find_packages, Extension 10 | import sys 11 | 12 | 13 | if sys.version_info < (3,): 14 | sys.exit('Sorry, Python3 is required for fairseq.') 15 | 16 | with open('README.md') as f: 17 | readme = f.read() 18 | 19 | 20 | bleu = Extension( 21 | 'fairseq.libbleu', 22 | sources=[ 23 | 'fairseq/clib/libbleu/libbleu.cpp', 24 | 'fairseq/clib/libbleu/module.cpp', 25 | ], 26 | extra_compile_args=['-std=c++11'], 27 | ) 28 | 29 | 30 | setup( 31 | name='fairseq', 32 | version='0.6.2', 33 | description='Facebook AI Research Sequence-to-Sequence Toolkit', 34 | url='https://github.com/pytorch/fairseq', 35 | classifiers=[ 36 | 'Intended Audience :: Science/Research', 37 | 'License :: OSI Approved :: BSD License', 38 | 'Programming Language :: Python :: 3.6', 39 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 40 | ], 41 | long_description=readme, 42 | install_requires=[ 43 | 'cffi', 44 | 'numpy', 45 | 'sacrebleu', 46 | # don't include torch, to support both release and nightly builds 47 | #'torch', 48 | 'tqdm', 49 | ], 50 | packages=find_packages(exclude=['scripts', 'tests']), 51 | ext_modules=[bleu], 52 | test_suite='tests', 53 | entry_points={ 54 | 'console_scripts': [ 55 | 'fairseq-eval-lm = fairseq_cli.eval_lm:cli_main', 56 | 'fairseq-generate = fairseq_cli.generate:cli_main', 57 | 'fairseq-interactive = fairseq_cli.interactive:cli_main', 58 | 'fairseq-preprocess = fairseq_cli.preprocess:cli_main', 59 | 'fairseq-train = fairseq_cli.train:cli_main', 60 | 'fairseq-score = fairseq_cli.score:main', 61 | ], 62 | }, 63 | ) 64 | -------------------------------------------------------------------------------- /stack.py: -------------------------------------------------------------------------------- 1 | """Grow a trained encoder by copying layers of an existing checkpoint. 2 | 3 | Usage: python3 stack.py <input_checkpoint> <output_checkpoint> <num_layers_to_copy> 4 | """ 5 | import sys 6 | 7 | import torch 8 | 9 | # number of state-dict entries per encoder layer: 10 | # rpr-network = 13, base-network = 12 11 | num_of_layerpara = 13 12 | # growth strategy: 0 = copy all layers (e.g. 6->12->24->48); 1 = SDT, copy the 13 | # top-most N layers; 2 = copy only the top layer N times; 3 = interpolation, 14 | # no sparse connections 15 | strategy = 1 16 | 17 | 18 | def main(): 19 | ckpt = torch.load(sys.argv[1]) 20 | lst = [] 21 | # number of encoder layers to copy 22 | n_copy = int(sys.argv[3]) 23 | counter_layer = n_copy 24 | if strategy == 0: 25 | for k, v in ckpt['model'].items(): 26 | k_split = k.split('.') 27 | if k_split[0] == 'encoder' and k_split[1] == 'layers': 28 | l_id = int(k_split[2]) 29 | k_split[2] = str(l_id + ckpt['args'].encoder_layers) 30 | new_k = '.'.join(k_split) 31 | lst.append([new_k, v.clone()]) 32 | if k_split[0] == 'encoder' and k_split[1] == 'history' and k_split[2] == 'layer_norms': 33 | l_id = int(k_split[3]) 34 | k_split[3] = str(l_id + ckpt['args'].encoder_layers) 35 | new_k = '.'.join(k_split) 36 | lst.append([new_k, v.clone()]) 37 | elif strategy == 1: 38 | current_layers = ckpt['args'].encoder_layers 39 | count_layer = 0 40 | for k, v in ckpt['model'].items(): 41 | k_split = k.split('.') 42 | if k_split[0] == 'encoder' and k_split[1] == 'layers' and int(k_split[2]) == current_layers - counter_layer: 43 | l_id = int(k_split[2]) 44 | k_split[2] = str(l_id + n_copy) 45 | new_k = '.'.join(k_split) 46 | lst.append([new_k, v.clone()]) 47 | count_layer += 1 48 | # advance to the next source layer once all of its entries are copied 49 | if count_layer == num_of_layerpara: 50 | counter_layer -= 1 51 | count_layer = 0 52 | if k_split[0] == 'encoder' and k_split[1] == 'history' and k_split[2] == 'layer_norms': 53 | if int(k_split[3]) == len(ckpt['args'].k) - 2: 54 | l_id = int(k_split[3]) 55 | k_split[3] = str(l_id + 1) 56 | new_k = '.'.join(k_split) 57 | lst.append([new_k, v.clone()]) 58 | elif strategy == 2: 59 | current_layers = ckpt['args'].encoder_layers 60 | for k, v in ckpt['model'].items(): 61 | k_split = k.split('.') 62 | if k_split[0] == 'encoder' and k_split[1] == 'layers' and int(k_split[2]) == current_layers - 1: 63 | l_id = int(k_split[2]) 64 | for i in range(counter_layer): 65 | k_split[2] = str(l_id + i + 1) 66 | new_k = '.'.join(k_split) 67 | lst.append([new_k, v.clone()]) 68 | if k_split[0] == 'encoder' and k_split[1] == 'history' and k_split[2] == 'layer_norms': 69 | if int(k_split[3]) == len(ckpt['args'].k) - 2: 70 | l_id = int(k_split[3]) 71 | k_split[3] = str(l_id + 1) 72 | new_k = '.'.join(k_split) 73 | lst.append([new_k, v.clone()]) 74 | elif strategy == 3: 75 | layer = 0 76 | count = 0 77 | for k, v in ckpt['model'].items(): 78 | k_split = k.split('.') 79 | if k_split[0] == 'encoder' and k_split[1] == 'layers': 80 | l_id = int(k_split[2]) + layer 81 | k_split[2] = str(l_id) 82 | new_k1 = '.'.join(k_split) 83 | k_split[2] = str(l_id + 1) 84 | new_k2 = '.'.join(k_split) 85 | lst.append([new_k1, v]) 86 | lst.append([new_k2, v.clone()]) 87 | count += 1 88 | if count == num_of_layerpara: 89 | layer = layer + 1 90 | count = 0 91 | for k, v in lst: 92 | ckpt['model'][k] = v 93 | 94 | if strategy == 0 or strategy == 3: 95 | ckpt['args'].encoder_layers *= 2 96 | elif strategy == 1 or strategy == 2: 97 | ckpt['args'].encoder_layers += n_copy 98 | torch.save(ckpt, sys.argv[2]) 99 | 100 | 101 | if __name__ == '__main__': 102 | main() 103 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libeineu/SDT-Training/d33a836c1b3258748ec11c4d64998e0dcc9792df/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_average_checkpoints.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory.
7 | 8 | import collections 9 | import os 10 | import tempfile 11 | import unittest 12 | 13 | import numpy as np 14 | import torch 15 | 16 | from scripts.average_checkpoints import average_checkpoints 17 | 18 | 19 | class TestAverageCheckpoints(unittest.TestCase): 20 | def test_average_checkpoints(self): 21 | params_0 = collections.OrderedDict( 22 | [ 23 | ('a', torch.DoubleTensor([100.0])), 24 | ('b', torch.FloatTensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])), 25 | ('c', torch.IntTensor([7, 8, 9])), 26 | ] 27 | ) 28 | params_1 = collections.OrderedDict( 29 | [ 30 | ('a', torch.DoubleTensor([1.0])), 31 | ('b', torch.FloatTensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])), 32 | ('c', torch.IntTensor([2, 2, 2])), 33 | ] 34 | ) 35 | params_avg = collections.OrderedDict( 36 | [ 37 | ('a', torch.DoubleTensor([50.5])), 38 | ('b', torch.FloatTensor([[1.0, 1.5, 2.0], [2.5, 3.0, 3.5]])), 39 | # We expect truncation for integer division 40 | ('c', torch.IntTensor([4, 5, 5])), 41 | ] 42 | ) 43 | 44 | fd_0, path_0 = tempfile.mkstemp() 45 | fd_1, path_1 = tempfile.mkstemp() 46 | torch.save(collections.OrderedDict([('model', params_0)]), path_0) 47 | torch.save(collections.OrderedDict([('model', params_1)]), path_1) 48 | 49 | output = average_checkpoints([path_0, path_1])['model'] 50 | 51 | os.close(fd_0) 52 | os.remove(path_0) 53 | os.close(fd_1) 54 | os.remove(path_1) 55 | 56 | for (k_expected, v_expected), (k_out, v_out) in zip( 57 | params_avg.items(), output.items()): 58 | self.assertEqual( 59 | k_expected, k_out, 'Key mismatch - expected {} but found {}. ' 60 | '(Expected list of keys: {} vs actual list of keys: {})'.format( 61 | k_expected, k_out, params_avg.keys(), output.keys() 62 | ) 63 | ) 64 | np.testing.assert_allclose( 65 | v_expected.numpy(), 66 | v_out.numpy(), 67 | err_msg='Tensor value mismatch for key {}'.format(k_expected) 68 | ) 69 | 70 | 71 | if __name__ == '__main__': 72 | unittest.main() 73 | -------------------------------------------------------------------------------- /tests/test_character_token_embedder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
7 | 8 | import torch 9 | import unittest 10 | 11 | from fairseq.data import Dictionary 12 | from fairseq.modules import CharacterTokenEmbedder 13 | 14 | 15 | class TestCharacterTokenEmbedder(unittest.TestCase): 16 | def test_character_token_embedder(self): 17 | vocab = Dictionary() 18 | vocab.add_symbol('hello') 19 | vocab.add_symbol('there') 20 | 21 | embedder = CharacterTokenEmbedder(vocab, [(2, 16), (4, 32), (8, 64), (16, 2)], 64, 5, 2) 22 | 23 | test_sents = [['hello', 'unk', 'there'], ['there'], ['hello', 'there']] 24 | max_len = max(len(s) for s in test_sents) 25 | input = torch.LongTensor(len(test_sents), max_len + 2).fill_(vocab.pad()) 26 | for i in range(len(test_sents)): 27 | input[i][0] = vocab.eos() 28 | for j in range(len(test_sents[i])): 29 | input[i][j + 1] = vocab.index(test_sents[i][j]) 30 | input[i][j + 2] = vocab.eos() 31 | embs = embedder(input) 32 | 33 | assert embs.size() == (len(test_sents), max_len + 2, 5) 34 | self.assertAlmostEqual(embs[0][0], embs[1][0]) 35 | self.assertAlmostEqual(embs[0][0], embs[0][-1]) 36 | self.assertAlmostEqual(embs[0][1], embs[2][1]) 37 | self.assertAlmostEqual(embs[0][3], embs[1][1]) 38 | 39 | embs.sum().backward() 40 | assert embedder.char_embeddings.weight.grad is not None 41 | 42 | def assertAlmostEqual(self, t1, t2): 43 | self.assertEqual(t1.size(), t2.size(), "size mismatch") 44 | self.assertLess((t1 - t2).abs().max(), 1e-6) 45 | 46 | 47 | if __name__ == '__main__': 48 | unittest.main() 49 | -------------------------------------------------------------------------------- /tests/test_convtbc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
7 | 8 | import torch 9 | import unittest 10 | from fairseq.modules import ConvTBC 11 | import torch.nn as nn 12 | 13 | 14 | class TestConvTBC(unittest.TestCase): 15 | 16 | def test_convtbc(self): 17 | # ksz, in_channels, out_channels 18 | conv_tbc = ConvTBC(4, 5, kernel_size=3, padding=1) 19 | # out_channels, in_channels, ksz 20 | conv1d = nn.Conv1d(4, 5, kernel_size=3, padding=1) 21 | 22 | conv_tbc.weight.data.copy_(conv1d.weight.data.transpose(0, 2)) 23 | conv_tbc.bias.data.copy_(conv1d.bias.data) 24 | 25 | input_tbc = torch.randn(7, 2, 4, requires_grad=True) 26 | input1d = input_tbc.data.transpose(0, 1).transpose(1, 2) 27 | input1d.requires_grad = True 28 | 29 | output_tbc = conv_tbc(input_tbc) 30 | output1d = conv1d(input1d) 31 | 32 | self.assertAlmostEqual(output_tbc.data.transpose(0, 1).transpose(1, 2), output1d.data) 33 | 34 | grad_tbc = torch.randn(output_tbc.size()) 35 | grad1d = grad_tbc.transpose(0, 1).transpose(1, 2).contiguous() 36 | 37 | output_tbc.backward(grad_tbc) 38 | output1d.backward(grad1d) 39 | 40 | self.assertAlmostEqual(conv_tbc.weight.grad.data.transpose(0, 2), conv1d.weight.grad.data) 41 | self.assertAlmostEqual(conv_tbc.bias.grad.data, conv1d.bias.grad.data) 42 | self.assertAlmostEqual(input_tbc.grad.data.transpose(0, 1).transpose(1, 2), input1d.grad.data) 43 | 44 | def assertAlmostEqual(self, t1, t2): 45 | self.assertEqual(t1.size(), t2.size(), "size mismatch") 46 | self.assertLess((t1 - t2).abs().max(), 1e-4) 47 | 48 | 49 | if __name__ == '__main__': 50 | unittest.main() 51 | -------------------------------------------------------------------------------- /tests/test_dictionary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
7 | 8 | import tempfile 9 | import unittest 10 | 11 | import torch 12 | 13 | from fairseq.data import Dictionary 14 | 15 | 16 | class TestDictionary(unittest.TestCase): 17 | 18 | def test_finalize(self): 19 | txt = [ 20 | 'A B C D', 21 | 'B C D', 22 | 'C D', 23 | 'D', 24 | ] 25 | ref_ids1 = list(map(torch.IntTensor, [ 26 | [4, 5, 6, 7, 2], 27 | [5, 6, 7, 2], 28 | [6, 7, 2], 29 | [7, 2], 30 | ])) 31 | ref_ids2 = list(map(torch.IntTensor, [ 32 | [7, 6, 5, 4, 2], 33 | [6, 5, 4, 2], 34 | [5, 4, 2], 35 | [4, 2], 36 | ])) 37 | 38 | # build dictionary 39 | d = Dictionary() 40 | for line in txt: 41 | d.encode_line(line, add_if_not_exist=True) 42 | 43 | def get_ids(dictionary): 44 | ids = [] 45 | for line in txt: 46 | ids.append(dictionary.encode_line(line, add_if_not_exist=False)) 47 | return ids 48 | 49 | def assertMatch(ids, ref_ids): 50 | for toks, ref_toks in zip(ids, ref_ids): 51 | self.assertEqual(toks.size(), ref_toks.size()) 52 | self.assertEqual(0, (toks != ref_toks).sum().item()) 53 | 54 | ids = get_ids(d) 55 | assertMatch(ids, ref_ids1) 56 | 57 | # check finalized dictionary 58 | d.finalize() 59 | finalized_ids = get_ids(d) 60 | assertMatch(finalized_ids, ref_ids2) 61 | 62 | # write to disk and reload 63 | with tempfile.NamedTemporaryFile(mode='w') as tmp_dict: 64 | d.save(tmp_dict.name) 65 | d = Dictionary.load(tmp_dict.name) 66 | reload_ids = get_ids(d) 67 | assertMatch(reload_ids, ref_ids2) 68 | assertMatch(finalized_ids, reload_ids) 69 | 70 | 71 | if __name__ == '__main__': 72 | unittest.main() 73 | -------------------------------------------------------------------------------- /tests/test_iterators.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import unittest 9 | 10 | from fairseq.data import iterators 11 | 12 | 13 | class TestIterators(unittest.TestCase): 14 | 15 | def test_counting_iterator(self): 16 | x = list(range(10)) 17 | itr = iterators.CountingIterator(x) 18 | self.assertTrue(itr.has_next()) 19 | self.assertEqual(next(itr), 0) 20 | self.assertEqual(next(itr), 1) 21 | itr.skip(3) 22 | self.assertEqual(next(itr), 5) 23 | itr.skip(3) 24 | self.assertEqual(next(itr), 9) 25 | self.assertFalse(itr.has_next()) 26 | 27 | 28 | if __name__ == '__main__': 29 | unittest.main() 30 | -------------------------------------------------------------------------------- /tests/test_reproducibility.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import contextlib 9 | from io import StringIO 10 | import json 11 | import os 12 | import tempfile 13 | import unittest 14 | 15 | from . 
import test_binaries 16 | 17 | 18 | class TestReproducibility(unittest.TestCase): 19 | 20 | def _test_reproducibility(self, name, extra_flags=None): 21 | if extra_flags is None: 22 | extra_flags = [] 23 | 24 | with tempfile.TemporaryDirectory(name) as data_dir: 25 | with contextlib.redirect_stdout(StringIO()): 26 | test_binaries.create_dummy_data(data_dir) 27 | test_binaries.preprocess_translation_data(data_dir) 28 | 29 | # train epochs 1 and 2 together 30 | stdout = StringIO() 31 | with contextlib.redirect_stdout(stdout): 32 | test_binaries.train_translation_model( 33 | data_dir, 'fconv_iwslt_de_en', [ 34 | '--dropout', '0.0', 35 | '--log-format', 'json', 36 | '--log-interval', '1', 37 | '--max-epoch', '3', 38 | ] + extra_flags, 39 | ) 40 | stdout = stdout.getvalue() 41 | train_log, valid_log = map(json.loads, stdout.split('\n')[-4:-2]) 42 | 43 | # train epoch 2, resuming from previous checkpoint 1 44 | os.rename( 45 | os.path.join(data_dir, 'checkpoint1.pt'), 46 | os.path.join(data_dir, 'checkpoint_last.pt'), 47 | ) 48 | stdout = StringIO() 49 | with contextlib.redirect_stdout(stdout): 50 | test_binaries.train_translation_model( 51 | data_dir, 'fconv_iwslt_de_en', [ 52 | '--dropout', '0.0', 53 | '--log-format', 'json', 54 | '--log-interval', '1', 55 | '--max-epoch', '3', 56 | ] + extra_flags, 57 | ) 58 | stdout = stdout.getvalue() 59 | train_res_log, valid_res_log = map(json.loads, stdout.split('\n')[-4:-2]) 60 | 61 | def cast(s): 62 | return round(float(s), 3) 63 | 64 | for k in ['train_loss', 'train_ppl', 'train_num_updates', 'train_gnorm']: 65 | self.assertEqual(cast(train_log[k]), cast(train_res_log[k])) 66 | for k in ['valid_loss', 'valid_ppl', 'valid_num_updates', 'valid_best_loss']: 67 | self.assertEqual(cast(valid_log[k]), cast(valid_res_log[k])) 68 | 69 | def test_reproducibility(self): 70 | self._test_reproducibility('test_reproducibility') 71 | 72 | def test_reproducibility_fp16(self): 73 | self._test_reproducibility('test_reproducibility_fp16', [ 74 | '--fp16', 75 | '--fp16-init-scale', '4096', 76 | ]) 77 | 78 | def test_reproducibility_memory_efficient_fp16(self): 79 | self._test_reproducibility('test_reproducibility_memory_efficient_fp16', [ 80 | '--memory-efficient-fp16', 81 | '--fp16-init-scale', '4096', 82 | ]) 83 | 84 | 85 | if __name__ == '__main__': 86 | unittest.main() 87 | -------------------------------------------------------------------------------- /tests/test_token_block_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
7 | 8 | import unittest 9 | 10 | import torch 11 | 12 | from fairseq.data import TokenBlockDataset 13 | 14 | import tests.utils as test_utils 15 | 16 | 17 | class TestTokenBlockDataset(unittest.TestCase): 18 | 19 | def _build_dataset(self, data, **kwargs): 20 | sizes = [len(x) for x in data] 21 | underlying_ds = test_utils.TestDataset(data) 22 | return TokenBlockDataset(underlying_ds, sizes, **kwargs) 23 | 24 | def test_eos_break_mode(self): 25 | data = [ 26 | torch.LongTensor([5, 4, 3, 2, 1]), 27 | torch.LongTensor([1]), # this should be filtered 28 | torch.LongTensor([8, 7, 6, 1]), 29 | ] 30 | ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode='eos') 31 | self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1]) 32 | self.assertEqual(ds[1].tolist(), [1]) 33 | self.assertEqual(ds[2].tolist(), [8, 7, 6, 1]) 34 | 35 | data = [ 36 | torch.LongTensor([5, 4, 3, 2, 1]), 37 | torch.LongTensor([8, 7, 6, 1]), 38 | torch.LongTensor([1]), # this should be filtered 39 | ] 40 | ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode='eos') 41 | self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1]) 42 | self.assertEqual(ds[1].tolist(), [8, 7, 6, 1]) 43 | self.assertEqual(ds[2].tolist(), [1]) 44 | 45 | def test_block_break_mode(self): 46 | data = [ 47 | torch.LongTensor([5, 4, 3, 2, 1]), 48 | torch.LongTensor([8, 7, 6, 1]), 49 | torch.LongTensor([9, 1]), 50 | ] 51 | ds = self._build_dataset(data, block_size=3, pad=0, eos=1, break_mode='none') 52 | self.assertEqual(ds[0].tolist(), [5, 4, 3]) 53 | self.assertEqual(ds[1].tolist(), [2, 1, 8]) 54 | self.assertEqual(ds[2].tolist(), [7, 6, 1]) 55 | self.assertEqual(ds[3].tolist(), [9, 1]) 56 | 57 | def test_complete_break_mode(self): 58 | data = [ 59 | torch.LongTensor([5, 4, 3, 2, 1]), 60 | torch.LongTensor([8, 7, 6, 1]), 61 | torch.LongTensor([9, 1]), 62 | ] 63 | ds = self._build_dataset(data, block_size=6, pad=0, eos=1, break_mode='complete') 64 | self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1]) 65 | self.assertEqual(ds[1].tolist(), [8, 7, 6, 1, 9, 1]) 66 | 67 | data = [ 68 | torch.LongTensor([4, 3, 2, 1]), 69 | torch.LongTensor([5, 1]), 70 | torch.LongTensor([1]), 71 | torch.LongTensor([6, 1]), 72 | ] 73 | ds = self._build_dataset(data, block_size=3, pad=0, eos=1, break_mode='complete') 74 | self.assertEqual(ds[0].tolist(), [4, 3, 2, 1]) 75 | self.assertEqual(ds[1].tolist(), [5, 1, 1]) 76 | self.assertEqual(ds[2].tolist(), [6, 1]) 77 | 78 | 79 | if __name__ == "__main__": 80 | unittest.main() 81 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
7 | 8 | import unittest 9 | 10 | import torch 11 | 12 | from fairseq import utils 13 | 14 | 15 | class TestUtils(unittest.TestCase): 16 | 17 | def test_convert_padding_direction(self): 18 | pad = 1 19 | left_pad = torch.LongTensor([ 20 | [2, 3, 4, 5, 6], 21 | [1, 7, 8, 9, 10], 22 | [1, 1, 1, 11, 12], 23 | ]) 24 | right_pad = torch.LongTensor([ 25 | [2, 3, 4, 5, 6], 26 | [7, 8, 9, 10, 1], 27 | [11, 12, 1, 1, 1], 28 | ]) 29 | 30 | self.assertAlmostEqual( 31 | right_pad, 32 | utils.convert_padding_direction( 33 | left_pad, 34 | pad, 35 | left_to_right=True, 36 | ), 37 | ) 38 | self.assertAlmostEqual( 39 | left_pad, 40 | utils.convert_padding_direction( 41 | right_pad, 42 | pad, 43 | right_to_left=True, 44 | ), 45 | ) 46 | 47 | def test_make_positions(self): 48 | pad = 1 49 | left_pad_input = torch.LongTensor([ 50 | [9, 9, 9, 9, 9], 51 | [1, 9, 9, 9, 9], 52 | [1, 1, 1, 9, 9], 53 | ]) 54 | left_pad_output = torch.LongTensor([ 55 | [2, 3, 4, 5, 6], 56 | [1, 2, 3, 4, 5], 57 | [1, 1, 1, 2, 3], 58 | ]) 59 | right_pad_input = torch.LongTensor([ 60 | [9, 9, 9, 9, 9], 61 | [9, 9, 9, 9, 1], 62 | [9, 9, 1, 1, 1], 63 | ]) 64 | right_pad_output = torch.LongTensor([ 65 | [2, 3, 4, 5, 6], 66 | [2, 3, 4, 5, 1], 67 | [2, 3, 1, 1, 1], 68 | ]) 69 | 70 | self.assertAlmostEqual( 71 | left_pad_output, 72 | utils.make_positions(left_pad_input, pad, left_pad=True), 73 | ) 74 | self.assertAlmostEqual( 75 | right_pad_output, 76 | utils.make_positions(right_pad_input, pad, left_pad=False), 77 | ) 78 | 79 | def assertAlmostEqual(self, t1, t2): 80 | self.assertEqual(t1.size(), t2.size(), "size mismatch") 81 | self.assertLess(utils.item((t1 - t2).abs().max()), 1e-4) 82 | 83 | 84 | if __name__ == '__main__': 85 | unittest.main() 86 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #! /usr/bin/bash 2 | set -e 3 | 4 | device=0,1,2,3,4,5,6,7 5 | #device=0 6 | 7 | #task=iwslt-de2en 8 | task=wmt-en2de 9 | 10 | # must set this tag 11 | tag= 12 | if [ $task == "wmt-en2de" ]; then 13 | arch=transformer_t2t_wmt_en_de 14 | share_embedding=1 15 | share_decoder_input_output_embed=0 16 | criterion=regularization_label_smoothed_cross_entropy 17 | fp16=0 18 | lr=0.002 19 | warmup=16000 20 | max_tokens=2048 21 | update_freq=4 22 | weight_decay=0.0 23 | keep_last_epochs=10 24 | max_epoch=21 25 | max_update= 26 | data_dir=google 27 | src_lang=en 28 | tgt_lang=de 29 | elif [ $task == "ldc" ]; then 30 | arch=transformer_t2t_wmt_en_de 31 | share_embedding=0 32 | share_decoder_input_output_embed=1; criterion=label_smoothed_cross_entropy # assumed default: $criterion is required below 33 | fp16=1 34 | lr=0.002 35 | warmup=8000 36 | max_tokens=2048 37 | update_freq=4 38 | weight_decay=0.0 39 | keep_last_epochs=10 40 | max_epoch=16 41 | max_update= 42 | data_dir=LDC_180W 43 | src_lang=zh 44 | tgt_lang=en 45 | else 46 | echo "unknown task=$task" 47 | exit 48 | fi 49 | 50 | save_dir=checkpoints/$task/$tag 51 | 52 | if [ !
-d $save_dir ]; then 53 | mkdir -p $save_dir 54 | fi 55 | cp ${BASH_SOURCE[0]} $save_dir/train.sh 56 | 57 | gpu_num=`echo "$device" | awk '{split($0,arr,",");print length(arr)}'` 58 | 59 | cmd="python3 -u train.py data-bin/$data_dir 60 | --distributed-world-size $gpu_num -s $src_lang -t $tgt_lang 61 | --arch $arch 62 | --optimizer adam --clip-norm 0.0 63 | --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates $warmup 64 | --lr $lr --min-lr 1e-09 65 | --weight-decay $weight_decay 66 | --criterion $criterion --label-smoothing 0.1 67 | --max-tokens $max_tokens 68 | --update-freq $update_freq 69 | --no-progress-bar 70 | --log-interval 100 71 | --ddp-backend no_c10d 72 | --save-dir $save_dir 73 | --keep-last-epochs $keep_last_epochs 74 | --tensorboard-logdir $save_dir" 75 | 76 | adam_betas="'(0.9, 0.997)'" 77 | cmd=${cmd}" --adam-betas "${adam_betas} 78 | if [ $share_embedding -eq 1 ]; then 79 | cmd=${cmd}" --share-all-embeddings " 80 | fi 81 | if [ $share_decoder_input_output_embed -eq 1 ]; then 82 | cmd=${cmd}" --share-decoder-input-output-embed " 83 | fi 84 | if [ -n "$max_epoch" ]; then 85 | cmd=${cmd}" --max-epoch "${max_epoch} 86 | fi 87 | if [ -n "$max_update" ]; then 88 | cmd=${cmd}" --max-update "${max_update} 89 | fi 90 | if [ -n "$dropout" ]; then 91 | cmd=${cmd}" --dropout "${dropout} 92 | fi 93 | if [ $fp16 -eq 1 ]; then 94 | cmd=${cmd}" --fp16 " 95 | fi 96 | 97 | #echo $cmd 98 | #eval $cmd 99 | #cmd=$(eval $cmd) 100 | #nohup $cmd exec 1> $save_dir/train.log exec 2>&1 & 101 | #tail -f $save_dir/train.log 102 | 103 | export CUDA_VISIBLE_DEVICES=$device 104 | cmd="nohup "${cmd}" > $save_dir/train.log 2>&1 &" 105 | eval $cmd 106 | tail -f $save_dir/train.log 107 | --------------------------------------------------------------------------------
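A sketch of the layer-stacking workflow implemented by stack.py above (the checkpoint paths are hypothetical examples): pick a growth scheme by editing the `strategy` constant at the top of the script, grow a trained checkpoint, then resume training from the deeper model. Training resumes from `checkpoint_last.pt` inside `--save-dir`, so saving the stacked checkpoint under that name lets train.sh continue from it.

    # copy the top-most 6 encoder layers of a trained model (strategy 1)
    python3 stack.py checkpoints/wmt-en2de/base/checkpoint_best.pt \
        checkpoints/wmt-en2de/deep/checkpoint_last.pt 6
    # then rerun train.sh with tag=deep to continue training the deeper encoder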