├── CONTRIBUTING.md ├── LICENSE ├── PATENTS ├── README.md ├── distributed_train.py ├── docs ├── Makefile ├── command_line_tools.rst ├── conf.py ├── criterions.rst ├── data.rst ├── docutils.conf ├── getting_started.rst ├── index.rst ├── lr_scheduler.rst ├── make.bat ├── models.rst ├── modules.rst ├── optim.rst ├── overview.rst ├── requirements.txt ├── tasks.rst ├── tutorial_classifying_names.rst └── tutorial_simple_lstm.rst ├── eval_lm.py ├── fairseq.gif ├── fairseq ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── multiprocessing_pdb.cpython-36.pyc │ ├── options.cpython-36.pyc │ ├── tokenizer.cpython-36.pyc │ └── utils.cpython-36.pyc ├── bleu.py ├── clib │ └── libbleu │ │ ├── libbleu.cpp │ │ └── module.cpp ├── criterions │ ├── __init__.py │ ├── adaptive_loss.py │ ├── cross_entropy.py │ ├── fairseq_criterion.py │ └── label_smoothed_cross_entropy.py ├── data │ ├── __init__.py │ ├── data_utils.py │ ├── dictionary.py │ ├── fairseq_dataset.py │ ├── indexed_dataset.py │ ├── iterators.py │ ├── language_pair_dataset.py │ ├── monolingual_dataset.py │ └── token_block_dataset.py ├── distributed_utils.py ├── fp16_trainer.py ├── meters.py ├── models │ ├── __init__.py │ ├── composite_encoder.py │ ├── fairseq_decoder.py │ ├── fairseq_encoder.py │ ├── fairseq_incremental_decoder.py │ ├── fairseq_model.py │ ├── fconv.py │ ├── fconv_self_att.py │ ├── lstm.py │ └── transformer.py ├── modules │ ├── __init__.py │ ├── adaptive_softmax.py │ ├── beamable_mm.py │ ├── character_token_embedder.py │ ├── conv_tbc.py │ ├── downsampled_multihead_attention.py │ ├── grad_multiply.py │ ├── highway.py │ ├── learned_positional_embedding.py │ ├── linearized_convolution.py │ ├── multihead_attention.py │ ├── scalar_bias.py │ └── sinusoidal_positional_embedding.py ├── multiprocessing_pdb.py ├── optim │ ├── __init__.py │ ├── adagrad.py │ ├── adam.py │ ├── fairseq_optimizer.py │ ├── lr_scheduler │ │ ├── __init__.py │ │ ├── cosine_lr_scheduler.py │ │ ├── fairseq_lr_scheduler.py │ │ ├── fixed_schedule.py │ │ ├── inverse_square_root_schedule.py │ │ ├── reduce_lr_on_plateau.py │ │ └── triangular_lr_scheduler.py │ ├── nag.py │ └── sgd.py ├── options.py ├── progress_bar.py ├── search.py ├── sequence_generator.py ├── sequence_scorer.py ├── tasks │ ├── __init__.py │ ├── fairseq_task.py │ ├── language_modeling.py │ └── translation.py ├── tokenizer.py ├── trainer.py └── utils.py ├── generate.py ├── interactive.py ├── multiprocessing_train.py ├── preprocess.py ├── requirements.txt ├── score.py ├── scripts ├── average_checkpoints.py ├── build_sym_alignment.py ├── convert_dictionary.lua ├── convert_model.lua └── read_binarized.py ├── setup.py ├── tests ├── test_average_checkpoints.py ├── test_binaries.py ├── test_character_token_embedder.py ├── test_convtbc.py ├── test_dictionary.py ├── test_iterators.py ├── test_label_smoothing.py ├── test_sequence_generator.py ├── test_sequence_scorer.py ├── test_train.py ├── test_utils.py └── utils.py └── train.py /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to FAIR Sequence-to-Sequence Toolkit (PyTorch) 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 
13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | ## Coding Style 26 | We try to follow the PEP style guidelines and encourage you to as well. 27 | 28 | ## License 29 | By contributing to FAIR Sequence-to-Sequence Toolkit, you agree that your contributions will be licensed 30 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For fairseq software 4 | 5 | Copyright (c) 2017-present, Facebook, Inc. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /PATENTS: -------------------------------------------------------------------------------- 1 | Additional Grant of Patent Rights Version 2 2 | 3 | "Software" means the fairseq software distributed by Facebook, Inc. 4 | 5 | Facebook, Inc. ("Facebook") hereby grants to each recipient of the Software 6 | ("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable 7 | (subject to the termination provision below) license under any Necessary 8 | Claims, to make, have made, use, sell, offer to sell, import, and otherwise 9 | transfer the Software. For avoidance of doubt, no license is granted under 10 | Facebook’s rights in any patent claims that are infringed by (i) modifications 11 | to the Software made by you or any third party or (ii) the Software in 12 | combination with any software or other technology. 
13 | 14 | The license granted hereunder will terminate, automatically and without notice, 15 | if you (or any of your subsidiaries, corporate affiliates or agents) initiate 16 | directly or indirectly, or take a direct financial interest in, any Patent 17 | Assertion: (i) against Facebook or any of its subsidiaries or corporate 18 | affiliates, (ii) against any party if such Patent Assertion arises in whole or 19 | in part from any software, technology, product or service of Facebook or any of 20 | its subsidiaries or corporate affiliates, or (iii) against any party relating 21 | to the Software. Notwithstanding the foregoing, if Facebook or any of its 22 | subsidiaries or corporate affiliates files a lawsuit alleging patent 23 | infringement against you in the first instance, and you respond by filing a 24 | patent infringement counterclaim in that lawsuit against that party that is 25 | unrelated to the Software, the license granted hereunder will not terminate 26 | under section (i) of this paragraph due to such counterclaim. 27 | 28 | A "Necessary Claim" is a claim of a patent owned by Facebook that is 29 | necessarily infringed by the Software standing alone. 30 | 31 | A "Patent Assertion" is any lawsuit or other action alleging direct, indirect, 32 | or contributory infringement or inducement to infringe any patent, including a 33 | cross-claim or counterclaim. 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | Contrastive Attention Mechanism for Abstractive Text Summarization 3 | 4 | fairseq version=0.5.0 5 | 6 | pytorch version=0.4.0 7 | 8 | 9 | # License 10 | fairseq(-py) is BSD-licensed. 11 | The license applies to the pre-trained models as well. 12 | We also provide an additional patent grant. 13 | 14 | # Credits 15 | This is a PyTorch version of 16 | [fairseq](https://github.com/facebookresearch/fairseq), a sequence-to-sequence 17 | learning toolkit from Facebook AI Research. The original authors of this 18 | reimplementation are (in no particular order) Sergey Edunov, Myle Ott, and Sam 19 | Gross. 20 | -------------------------------------------------------------------------------- /distributed_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 8 | 9 | import os 10 | import socket 11 | import subprocess 12 | 13 | from train import main as single_process_main 14 | from fairseq import distributed_utils, options 15 | 16 | 17 | def main(args): 18 | if args.distributed_init_method is None and args.distributed_port > 0: 19 | # We can determine the init method automatically for Slurm.
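# How the automatic Slurm initialization below works: SLURM_JOB_NODELIST is expanded with
# `scontrol show hostnames`, the first host plus --distributed-port become the tcp:// init
# method, and SLURM_PROCID / SLURM_LOCALID supply the global rank and the local device id.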
20 | node_list = os.environ.get('SLURM_JOB_NODELIST') 21 | if node_list is not None: 22 | try: 23 | hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', node_list]) 24 | args.distributed_init_method = 'tcp://{host}:{port}'.format( 25 | host=hostnames.split()[0].decode('utf-8'), 26 | port=args.distributed_port) 27 | args.distributed_rank = int(os.environ.get('SLURM_PROCID')) 28 | args.device_id = int(os.environ.get('SLURM_LOCALID')) 29 | except subprocess.CalledProcessError as e: # scontrol failed 30 | raise e 31 | except FileNotFoundError as e: # Slurm is not installed 32 | pass 33 | if args.distributed_init_method is None: 34 | raise ValueError('--distributed-init-method or --distributed-port ' 35 | 'must be specified for distributed training') 36 | 37 | args.distributed_rank = distributed_utils.distributed_init(args) 38 | print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank)) 39 | single_process_main(args) 40 | 41 | 42 | if __name__ == '__main__': 43 | parser = options.get_training_parser() 44 | args = options.parse_args_and_arch(parser) 45 | main(args) 46 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = python -msphinx 7 | SPHINXPROJ = fairseq 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/command_line_tools.rst: -------------------------------------------------------------------------------- 1 | .. _Command-line Tools: 2 | 3 | Command-line Tools 4 | ================== 5 | 6 | Fairseq provides several command-line tools for training and evaluating models: 7 | 8 | - :ref:`preprocess.py`: Data pre-processing: build vocabularies and binarize training data 9 | - :ref:`train.py`: Train a new model on one or multiple GPUs 10 | - :ref:`generate.py`: Translate pre-processed data with a trained model 11 | - :ref:`interactive.py`: Translate raw text with a trained model 12 | - :ref:`score.py`: BLEU scoring of generated translations against reference translations 13 | - :ref:`eval_lm.py`: Language model evaluation 14 | 15 | 16 | .. _preprocess.py: 17 | 18 | preprocess.py 19 | ~~~~~~~~~~~~~ 20 | .. automodule:: preprocess 21 | 22 | .. argparse:: 23 | :module: preprocess 24 | :func: get_parser 25 | :prog: preprocess.py 26 | 27 | 28 | .. _train.py: 29 | 30 | train.py 31 | ~~~~~~~~ 32 | .. automodule:: train 33 | 34 | .. argparse:: 35 | :module: fairseq.options 36 | :func: get_training_parser 37 | :prog: train.py 38 | 39 | 40 | .. _generate.py: 41 | 42 | generate.py 43 | ~~~~~~~~~~~ 44 | .. automodule:: generate 45 | 46 | .. argparse:: 47 | :module: fairseq.options 48 | :func: get_generation_parser 49 | :prog: generate.py 50 | 51 | 52 | .. _interactive.py: 53 | 54 | interactive.py 55 | ~~~~~~~~~~~~~~ 56 | .. automodule:: interactive 57 | 58 | .. 
argparse:: 59 | :module: fairseq.options 60 | :func: get_interactive_generation_parser 61 | :prog: interactive.py 62 | 63 | 64 | .. _score.py: 65 | 66 | score.py 67 | ~~~~~~~~ 68 | .. automodule:: score 69 | 70 | .. argparse:: 71 | :module: score 72 | :func: get_parser 73 | :prog: score.py 74 | 75 | 76 | .. _eval_lm.py: 77 | 78 | eval_lm.py 79 | ~~~~~~~~~~ 80 | .. automodule:: eval_lm 81 | 82 | .. argparse:: 83 | :module: fairseq.options 84 | :func: get_eval_lm_parser 85 | :prog: eval_lm.py 86 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | # fairseq documentation build configuration file, created by 5 | # sphinx-quickstart on Fri Aug 17 21:45:30 2018. 6 | # 7 | # This file is execfile()d with the current directory set to its 8 | # containing dir. 9 | # 10 | # Note that not all possible configuration values are present in this 11 | # autogenerated file. 12 | # 13 | # All configuration values have a default; values that are commented out 14 | # serve to show the default. 15 | 16 | # If extensions (or modules to document with autodoc) are in another directory, 17 | # add these directories to sys.path here. If the directory is relative to the 18 | # documentation root, use os.path.abspath to make it absolute, like shown here. 19 | 20 | import os 21 | import sys 22 | 23 | # source code directory, relative to this file, for sphinx-autobuild 24 | sys.path.insert(0, os.path.abspath('..')) 25 | 26 | source_suffix = ['.rst'] 27 | 28 | # -- General configuration ------------------------------------------------ 29 | 30 | # If your documentation needs a minimal Sphinx version, state it here. 31 | # 32 | # needs_sphinx = '1.0' 33 | 34 | # Add any Sphinx extension module names here, as strings. They can be 35 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 36 | # ones. 37 | extensions = [ 38 | 'sphinx.ext.autodoc', 39 | 'sphinx.ext.intersphinx', 40 | 'sphinx.ext.viewcode', 41 | 'sphinx.ext.napoleon', 42 | 'sphinxarg.ext', 43 | ] 44 | 45 | # Add any paths that contain templates here, relative to this directory. 46 | templates_path = ['_templates'] 47 | 48 | # The master toctree document. 49 | master_doc = 'index' 50 | 51 | # General information about the project. 52 | project = 'fairseq' 53 | copyright = '2018, Facebook AI Research (FAIR)' 54 | author = 'Facebook AI Research (FAIR)' 55 | 56 | github_doc_root = 'https://github.com/pytorch/fairseq/tree/master/docs/' 57 | 58 | # The version info for the project you're documenting, acts as replacement for 59 | # |version| and |release|, also used in various other places throughout the 60 | # built documents. 61 | # 62 | # The short X.Y version. 63 | version = '0.5.0' 64 | # The full version, including alpha/beta/rc tags. 65 | release = '0.5.0' 66 | 67 | # The language for content autogenerated by Sphinx. Refer to documentation 68 | # for a list of supported languages. 69 | # 70 | # This is also used if you do content translation via gettext catalogs. 71 | # Usually you set "language" from the command line for these cases. 72 | language = None 73 | 74 | # List of patterns, relative to source directory, that match files and 75 | # directories to ignore when looking for source files. 
76 | # This patterns also effect to html_static_path and html_extra_path 77 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 78 | 79 | # The name of the Pygments (syntax highlighting) style to use. 80 | pygments_style = 'sphinx' 81 | highlight_language = 'python' 82 | 83 | # If true, `todo` and `todoList` produce output, else they produce nothing. 84 | todo_include_todos = False 85 | 86 | 87 | # -- Options for HTML output ---------------------------------------------- 88 | 89 | # The theme to use for HTML and HTML Help pages. See the documentation for 90 | # a list of builtin themes. 91 | # 92 | html_theme = 'sphinx_rtd_theme' 93 | 94 | # Theme options are theme-specific and customize the look and feel of a theme 95 | # further. For a list of options available for each theme, see the 96 | # documentation. 97 | # 98 | # html_theme_options = {} 99 | 100 | # Add any paths that contain custom static files (such as style sheets) here, 101 | # relative to this directory. They are copied after the builtin static files, 102 | # so a file named "default.css" will overwrite the builtin "default.css". 103 | html_static_path = ['_static'] 104 | 105 | html_context = { 106 | 'css_files': [ 107 | '_static/theme_overrides.css', # override wide tables in RTD theme 108 | ], 109 | } 110 | 111 | # Custom sidebar templates, must be a dictionary that maps document names 112 | # to template names. 113 | # 114 | # This is required for the alabaster theme 115 | # refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars 116 | #html_sidebars = { 117 | # '**': [ 118 | # 'about.html', 119 | # 'navigation.html', 120 | # 'relations.html', # needs 'show_related': True theme option to display 121 | # 'searchbox.html', 122 | # 'donate.html', 123 | # ] 124 | #} 125 | 126 | 127 | # Example configuration for intersphinx: refer to the Python standard library. 128 | intersphinx_mapping = { 129 | 'numpy': ('http://docs.scipy.org/doc/numpy/', None), 130 | 'python': ('https://docs.python.org/', None), 131 | 'torch': ('https://pytorch.org/docs/master/', None), 132 | } 133 | -------------------------------------------------------------------------------- /docs/criterions.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | .. _Criterions: 5 | 6 | Criterions 7 | ========== 8 | 9 | .. automodule:: fairseq.criterions 10 | :members: 11 | .. autoclass:: fairseq.criterions.FairseqCriterion 12 | :members: 13 | :undoc-members: 14 | -------------------------------------------------------------------------------- /docs/data.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | .. module:: fairseq.data 5 | 6 | Data Loading and Utilities 7 | ========================== 8 | 9 | .. _datasets: 10 | 11 | Datasets 12 | -------- 13 | 14 | **Datasets** define the data format and provide helpers for creating 15 | mini-batches. 16 | 17 | .. autoclass:: fairseq.data.FairseqDataset 18 | :members: 19 | .. autoclass:: fairseq.data.LanguagePairDataset 20 | :members: 21 | .. autoclass:: fairseq.data.MonolingualDataset 22 | :members: 23 | 24 | 25 | Dictionary 26 | ---------- 27 | 28 | .. autoclass:: fairseq.data.Dictionary 29 | :members: 30 | 31 | 32 | Iterators 33 | --------- 34 | 35 | .. autoclass:: fairseq.data.CountingIterator 36 | :members: 37 | .. autoclass:: fairseq.data.EpochBatchIterator 38 | :members: 39 | .. 
autoclass:: fairseq.data.ShardedIterator 40 | :members: 41 | -------------------------------------------------------------------------------- /docs/docutils.conf: -------------------------------------------------------------------------------- 1 | [writers] 2 | option-limit=0 3 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. fairseq documentation master file, created by 2 | sphinx-quickstart on Fri Aug 17 21:45:30 2018. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | :github_url: https://github.com/pytorch/fairseq 7 | 8 | 9 | fairseq documentation 10 | ===================== 11 | 12 | Fairseq is a sequence modeling toolkit written in `PyTorch 13 | `_ that allows researchers and developers to 14 | train custom models for translation, summarization, language modeling and other 15 | text generation tasks. 16 | 17 | .. toctree:: 18 | :maxdepth: 1 19 | :caption: Getting Started 20 | 21 | getting_started 22 | command_line_tools 23 | 24 | .. toctree:: 25 | :maxdepth: 1 26 | :caption: Extending Fairseq 27 | 28 | overview 29 | tutorial_simple_lstm 30 | tutorial_classifying_names 31 | 32 | .. toctree:: 33 | :maxdepth: 2 34 | :caption: Library Reference 35 | 36 | tasks 37 | models 38 | criterions 39 | optim 40 | lr_scheduler 41 | data 42 | modules 43 | 44 | 45 | Indices and tables 46 | ================== 47 | 48 | * :ref:`genindex` 49 | * :ref:`search` 50 | -------------------------------------------------------------------------------- /docs/lr_scheduler.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | .. _Learning Rate Schedulers: 5 | 6 | Learning Rate Schedulers 7 | ======================== 8 | 9 | TODO 10 | 11 | .. automodule:: fairseq.optim.lr_scheduler 12 | :members: 13 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=python -msphinx 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=fairseq 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed, 20 | echo.then set the SPHINXBUILD environment variable to point to the full 21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the 22 | echo.Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/models.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | .. module:: fairseq.models 5 | 6 | .. 
_Models: 7 | 8 | Models 9 | ====== 10 | 11 | A Model defines the neural network's ``forward()`` method and encapsulates all 12 | of the learnable parameters in the network. Each model also provides a set of 13 | named *architectures* that define the precise network configuration (e.g., 14 | embedding dimension, number of layers, etc.). 15 | 16 | Both the model type and architecture are selected via the ``--arch`` 17 | command-line argument. Once selected, a model may expose additional command-line 18 | arguments for further configuration. 19 | 20 | .. note:: 21 | 22 | All fairseq Models extend :class:`BaseFairseqModel`, which in turn extends 23 | :class:`torch.nn.Module`. Thus any fairseq Model can be used as a 24 | stand-alone Module in other PyTorch code. 25 | 26 | 27 | Convolutional Neural Networks (CNN) 28 | ----------------------------------- 29 | 30 | .. module:: fairseq.models.fconv 31 | .. autoclass:: fairseq.models.fconv.FConvModel 32 | :members: 33 | .. autoclass:: fairseq.models.fconv.FConvEncoder 34 | :members: 35 | :undoc-members: 36 | .. autoclass:: fairseq.models.fconv.FConvDecoder 37 | :members: 38 | 39 | 40 | Long Short-Term Memory (LSTM) networks 41 | -------------------------------------- 42 | 43 | .. module:: fairseq.models.lstm 44 | .. autoclass:: fairseq.models.lstm.LSTMModel 45 | :members: 46 | .. autoclass:: fairseq.models.lstm.LSTMEncoder 47 | :members: 48 | .. autoclass:: fairseq.models.lstm.LSTMDecoder 49 | :members: 50 | 51 | 52 | Transformer (self-attention) networks 53 | ------------------------------------- 54 | 55 | .. module:: fairseq.models.transformer 56 | .. autoclass:: fairseq.models.transformer.TransformerModel 57 | :members: 58 | .. autoclass:: fairseq.models.transformer.TransformerEncoder 59 | :members: 60 | .. autoclass:: fairseq.models.transformer.TransformerEncoderLayer 61 | :members: 62 | .. autoclass:: fairseq.models.transformer.TransformerDecoder 63 | :members: 64 | .. autoclass:: fairseq.models.transformer.TransformerDecoderLayer 65 | :members: 66 | 67 | 68 | Adding new models 69 | ----------------- 70 | 71 | .. currentmodule:: fairseq.models 72 | .. autofunction:: fairseq.models.register_model 73 | .. autofunction:: fairseq.models.register_model_architecture 74 | .. autoclass:: fairseq.models.BaseFairseqModel 75 | :members: 76 | :undoc-members: 77 | .. autoclass:: fairseq.models.FairseqModel 78 | :members: 79 | :undoc-members: 80 | .. autoclass:: fairseq.models.FairseqLanguageModel 81 | :members: 82 | :undoc-members: 83 | .. autoclass:: fairseq.models.FairseqEncoder 84 | :members: 85 | .. autoclass:: fairseq.models.CompositeEncoder 86 | :members: 87 | .. autoclass:: fairseq.models.FairseqDecoder 88 | :members: 89 | 90 | 91 | .. _Incremental decoding: 92 | 93 | Incremental decoding 94 | -------------------- 95 | 96 | .. autoclass:: fairseq.models.FairseqIncrementalDecoder 97 | :members: 98 | :undoc-members: 99 | -------------------------------------------------------------------------------- /docs/modules.rst: -------------------------------------------------------------------------------- 1 | Modules 2 | ======= 3 | 4 | Fairseq provides several stand-alone :class:`torch.nn.Module` s that may be 5 | helpful when implementing a new :class:`FairseqModel`. 6 | 7 | .. automodule:: fairseq.modules 8 | :members: 9 | :undoc-members: 10 | -------------------------------------------------------------------------------- /docs/optim.rst: -------------------------------------------------------------------------------- 1 | .. 
role:: hidden 2 | :class: hidden-section 3 | 4 | .. _optimizers: 5 | 6 | Optimizers 7 | ========== 8 | 9 | .. automodule:: fairseq.optim 10 | :members: 11 | -------------------------------------------------------------------------------- /docs/overview.rst: -------------------------------------------------------------------------------- 1 | Overview 2 | ======== 3 | 4 | Fairseq can be extended through user-supplied `plug-ins 5 | `_. We support five kinds of 6 | plug-ins: 7 | 8 | - :ref:`Models` define the neural network architecture and encapsulate all of the 9 | learnable parameters. 10 | - :ref:`Criterions` compute the loss function given the model outputs and targets. 11 | - :ref:`Tasks` store dictionaries and provide helpers for loading/iterating over 12 | Datasets, initializing the Model/Criterion and calculating the loss. 13 | - :ref:`Optimizers` update the Model parameters based on the gradients. 14 | - :ref:`Learning Rate Schedulers` update the learning rate over the course of 15 | training. 16 | 17 | **Training Flow** 18 | 19 | Given a ``model``, ``criterion``, ``task``, ``optimizer`` and ``lr_scheduler``, 20 | fairseq implements the following high-level training flow:: 21 | 22 | for epoch in range(num_epochs): 23 | itr = task.get_batch_iterator(task.dataset('train')) 24 | for num_updates, batch in enumerate(itr): 25 | loss = criterion(model, batch) 26 | optimizer.backward(loss) 27 | optimizer.step() 28 | lr_scheduler.step_update(num_updates) 29 | lr_scheduler.step(epoch) 30 | 31 | **Registering new plug-ins** 32 | 33 | New plug-ins are *registered* through a set of ``@register`` function 34 | decorators, for example:: 35 | 36 | @register_model('my_lstm') 37 | class MyLSTM(FairseqModel): 38 | (...) 39 | 40 | Once registered, new plug-ins can be used with the existing :ref:`Command-line 41 | Tools`. See the Tutorial sections for more detailed walkthroughs of how to add 42 | new plug-ins. 43 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx<2.0 2 | sphinx-argparse 3 | -------------------------------------------------------------------------------- /docs/tasks.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | .. module:: fairseq.tasks 5 | 6 | .. _Tasks: 7 | 8 | Tasks 9 | ===== 10 | 11 | Tasks store dictionaries and provide helpers for loading/iterating over 12 | Datasets, initializing the Model/Criterion and calculating the loss. 13 | 14 | Tasks can be selected via the ``--task`` command-line argument. Once selected, a 15 | task may expose additional command-line arguments for further configuration. 16 | 17 | Example usage:: 18 | 19 | # setup the task (e.g., load dictionaries) 20 | task = fairseq.tasks.setup_task(args) 21 | 22 | # build model and criterion 23 | model = task.build_model(args) 24 | criterion = task.build_criterion(args) 25 | 26 | # load datasets 27 | task.load_dataset('train') 28 | task.load_dataset('valid') 29 | 30 | # iterate over mini-batches of data 31 | batch_itr = task.get_batch_iterator( 32 | task.dataset('train'), max_tokens=4096, 33 | ) 34 | for batch in batch_itr: 35 | # compute the loss 36 | loss, sample_size, logging_output = task.get_loss( 37 | model, criterion, batch, 38 | ) 39 | loss.backward() 40 | 41 | 42 | Translation 43 | ----------- 44 | 45 | .. autoclass:: fairseq.tasks.translation.TranslationTask 46 | 47 | .. 
_language modeling: 48 | 49 | Language Modeling 50 | ----------------- 51 | 52 | .. autoclass:: fairseq.tasks.language_modeling.LanguageModelingTask 53 | 54 | 55 | Adding new tasks 56 | ---------------- 57 | 58 | .. autofunction:: fairseq.tasks.register_task 59 | .. autoclass:: fairseq.tasks.FairseqTask 60 | :members: 61 | :undoc-members: 62 | -------------------------------------------------------------------------------- /eval_lm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 8 | """ 9 | Evaluate the perplexity of a trained language model. 10 | """ 11 | 12 | import numpy as np 13 | import torch 14 | 15 | from fairseq import data, options, progress_bar, tasks, utils 16 | from fairseq.meters import StopwatchMeter, TimeMeter 17 | from fairseq.sequence_scorer import SequenceScorer 18 | 19 | 20 | class WordStat(object): 21 | def __init__(self, word, is_bpe): 22 | self.word = word 23 | self.is_bpe = is_bpe 24 | self.log_prob = 0 25 | self.count = 0 26 | 27 | def add(self, log_prob): 28 | self.log_prob += log_prob 29 | self.count += 1 30 | 31 | def __str__(self): 32 | return '{}\t{}\t{}\t{}'.format(self.word, self.count, self.log_prob / self.count, self.is_bpe) 33 | 34 | 35 | def main(parsed_args): 36 | assert parsed_args.path is not None, '--path required for evaluation!' 37 | 38 | print(parsed_args) 39 | 40 | use_cuda = torch.cuda.is_available() and not parsed_args.cpu 41 | 42 | task = tasks.setup_task(parsed_args) 43 | 44 | # Load ensemble 45 | print('| loading model(s) from {}'.format(parsed_args.path)) 46 | models, args = utils.load_ensemble_for_inference(parsed_args.path.split(':'), task) 47 | 48 | args.__dict__.update(parsed_args.__dict__) 49 | print(args) 50 | 51 | task.args = args 52 | 53 | # Load dataset splits 54 | task.load_dataset(args.gen_subset) 55 | print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset)))) 56 | 57 | # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer) 58 | for model in models: 59 | model.make_generation_fast_() 60 | if args.fp16: 61 | model.half() 62 | 63 | assert len(models) > 0 64 | 65 | itr = task.get_batch_iterator( 66 | dataset=task.dataset(args.gen_subset), 67 | max_tokens=args.max_tokens or 36000, 68 | max_sentences=args.max_sentences, 69 | max_positions=utils.resolve_max_positions(*[ 70 | model.max_positions() for model in models 71 | ]), 72 | num_shards=args.num_shards, 73 | shard_id=args.shard_id, 74 | ignore_invalid_inputs=True, 75 | ).next_epoch_itr(shuffle=False) 76 | 77 | gen_timer = StopwatchMeter() 78 | scorer = SequenceScorer(models, task.target_dictionary) 79 | if use_cuda: 80 | scorer.cuda() 81 | 82 | score_sum = 0. 
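# score_sum and count accumulate the total log-probability and the number of scored tokens;
# they are combined at the end into avg_nll_loss and the reported perplexity, np.exp(avg_nll_loss).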
83 | count = 0 84 | 85 | if args.remove_bpe is not None: 86 | bpe_cont = args.remove_bpe.rstrip() 87 | bpe_toks = set(i for i in range(len(task.dictionary)) if task.dictionary[i].endswith(bpe_cont)) 88 | bpe_len = len(bpe_cont) 89 | else: 90 | bpe_toks = None 91 | bpe_len = 0 92 | 93 | word_stats = dict() 94 | 95 | with progress_bar.build_progress_bar(args, itr) as t: 96 | results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer) 97 | wps_meter = TimeMeter() 98 | for _, src_tokens, __, hypos in results: 99 | for hypo in hypos: 100 | pos_scores = hypo['positional_scores'] 101 | 102 | skipped_toks = 0 103 | if bpe_toks is not None: 104 | for i in range(len(hypo['tokens']) - 1): 105 | if hypo['tokens'][i].item() in bpe_toks: 106 | skipped_toks += 1 107 | pos_scores[i + 1] += pos_scores[i] 108 | pos_scores[i] = 0 109 | 110 | inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf')) 111 | if inf_scores.any(): 112 | print('| Skipping tokens with inf scores:', 113 | task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()])) 114 | pos_scores = pos_scores[(~inf_scores).nonzero()] 115 | score_sum += utils.item(pos_scores.sum()) 116 | count += pos_scores.numel() - skipped_toks 117 | 118 | if args.output_word_probs or args.output_word_stats: 119 | w = '' 120 | word_prob = [] 121 | is_bpe = False 122 | for i in range(len(hypo['tokens'])): 123 | w_ind = hypo['tokens'][i].item() 124 | w += task.dictionary[w_ind] 125 | if bpe_toks is not None and w_ind in bpe_toks: 126 | w = w[:-bpe_len] 127 | is_bpe = True 128 | else: 129 | word_prob.append((w, pos_scores[i].item())) 130 | word_stats.setdefault(w, WordStat(w, is_bpe)).add(pos_scores[i].item()) 131 | is_bpe = False 132 | w = '' 133 | if args.output_word_probs: 134 | print('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob)) 135 | 136 | wps_meter.update(src_tokens.size(0)) 137 | t.log({'wps': round(wps_meter.avg)}) 138 | 139 | avg_nll_loss = -score_sum / count 140 | print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg)) 141 | print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss))) 142 | 143 | if args.output_word_stats: 144 | for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True): 145 | print(ws) 146 | 147 | 148 | if __name__ == '__main__': 149 | parser = options.get_eval_lm_parser() 150 | args = options.parse_args_and_arch(parser) 151 | main(args) 152 | -------------------------------------------------------------------------------- /fairseq.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/travel-go/Abstractive-Text-Summarization/915621368f9dfd84bf4eb517665026d0a31cb408/fairseq.gif -------------------------------------------------------------------------------- /fairseq/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
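# The subpackage imports below trigger each package's __init__; at least for criterions (see
# fairseq/criterions/__init__.py) this auto-imports every submodule so that its @register_*
# decorators populate the corresponding registry.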
7 | 8 | from .multiprocessing_pdb import pdb 9 | 10 | __all__ = ['pdb'] 11 | 12 | import fairseq.criterions 13 | import fairseq.models 14 | import fairseq.modules 15 | import fairseq.optim 16 | import fairseq.optim.lr_scheduler 17 | import fairseq.tasks 18 | -------------------------------------------------------------------------------- /fairseq/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/travel-go/Abstractive-Text-Summarization/915621368f9dfd84bf4eb517665026d0a31cb408/fairseq/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/__pycache__/multiprocessing_pdb.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/travel-go/Abstractive-Text-Summarization/915621368f9dfd84bf4eb517665026d0a31cb408/fairseq/__pycache__/multiprocessing_pdb.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/__pycache__/options.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/travel-go/Abstractive-Text-Summarization/915621368f9dfd84bf4eb517665026d0a31cb408/fairseq/__pycache__/options.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/__pycache__/tokenizer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/travel-go/Abstractive-Text-Summarization/915621368f9dfd84bf4eb517665026d0a31cb408/fairseq/__pycache__/tokenizer.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/travel-go/Abstractive-Text-Summarization/915621368f9dfd84bf4eb517665026d0a31cb408/fairseq/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/bleu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import ctypes 9 | import math 10 | import torch 11 | 12 | try: 13 | from fairseq import libbleu 14 | except ImportError as e: 15 | import sys 16 | sys.stderr.write('ERROR: missing libbleu.so. 
run `python setup.py install`\n') 17 | raise e 18 | 19 | 20 | C = ctypes.cdll.LoadLibrary(libbleu.__file__) 21 | 22 | 23 | class BleuStat(ctypes.Structure): 24 | _fields_ = [ 25 | ('reflen', ctypes.c_size_t), 26 | ('predlen', ctypes.c_size_t), 27 | ('match1', ctypes.c_size_t), 28 | ('count1', ctypes.c_size_t), 29 | ('match2', ctypes.c_size_t), 30 | ('count2', ctypes.c_size_t), 31 | ('match3', ctypes.c_size_t), 32 | ('count3', ctypes.c_size_t), 33 | ('match4', ctypes.c_size_t), 34 | ('count4', ctypes.c_size_t), 35 | ] 36 | 37 | 38 | class Scorer(object): 39 | def __init__(self, pad, eos, unk): 40 | self.stat = BleuStat() 41 | self.pad = pad 42 | self.eos = eos 43 | self.unk = unk 44 | self.reset() 45 | 46 | def reset(self, one_init=False): 47 | if one_init: 48 | C.bleu_one_init(ctypes.byref(self.stat)) 49 | else: 50 | C.bleu_zero_init(ctypes.byref(self.stat)) 51 | 52 | def add(self, ref, pred): 53 | if not isinstance(ref, torch.IntTensor): 54 | raise TypeError('ref must be a torch.IntTensor (got {})' 55 | .format(type(ref))) 56 | if not isinstance(pred, torch.IntTensor): 57 | raise TypeError('pred must be a torch.IntTensor(got {})' 58 | .format(type(pred))) 59 | 60 | # don't match unknown words 61 | rref = ref.clone() 62 | assert not rref.lt(0).any() 63 | rref[rref.eq(self.unk)] = -999 64 | 65 | rref = rref.contiguous().view(-1) 66 | pred = pred.contiguous().view(-1) 67 | 68 | C.bleu_add( 69 | ctypes.byref(self.stat), 70 | ctypes.c_size_t(rref.size(0)), 71 | ctypes.c_void_p(rref.data_ptr()), 72 | ctypes.c_size_t(pred.size(0)), 73 | ctypes.c_void_p(pred.data_ptr()), 74 | ctypes.c_int(self.pad), 75 | ctypes.c_int(self.eos)) 76 | 77 | def score(self, order=4): 78 | psum = sum(math.log(p) if p > 0 else float('-Inf') 79 | for p in self.precision()[:order]) 80 | return self.brevity() * math.exp(psum / order) * 100 81 | 82 | def precision(self): 83 | def ratio(a, b): 84 | return a / b if b > 0 else 0 85 | 86 | return [ 87 | ratio(self.stat.match1, self.stat.count1), 88 | ratio(self.stat.match2, self.stat.count2), 89 | ratio(self.stat.match3, self.stat.count3), 90 | ratio(self.stat.match4, self.stat.count4), 91 | ] 92 | 93 | def brevity(self): 94 | r = self.stat.reflen / self.stat.predlen 95 | return min(1, math.exp(1 - r)) 96 | 97 | def result_string(self, order=4): 98 | assert order <= 4, "BLEU scores for order > 4 aren't supported" 99 | fmt = 'BLEU{} = {:2.2f}, {:2.1f}' 100 | for _ in range(1, order): 101 | fmt += '/{:2.1f}' 102 | fmt += ' (BP={:.3f}, ratio={:.3f}, syslen={}, reflen={})' 103 | bleup = [p * 100 for p in self.precision()[:order]] 104 | return fmt.format(order, self.score(order=order), *bleup, 105 | self.brevity(), self.stat.predlen/self.stat.reflen, 106 | self.stat.predlen, self.stat.reflen) 107 | -------------------------------------------------------------------------------- /fairseq/clib/libbleu/libbleu.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2017-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 
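 *
 * This library backs fairseq/bleu.py, which loads it via ctypes: bleu_add() trims pad/eos
 * tokens, hashes 1- to 4-grams with an FNV-style hash, and accumulates match/count
 * statistics in the bleu_stat struct that the Python Scorer reads back.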
7 | */ 8 | 9 | #include <map> 10 | #include <array> 11 | #include <cstring> 12 | #include <cstdio> 13 | 14 | typedef struct 15 | { 16 | size_t reflen; 17 | size_t predlen; 18 | size_t match1; 19 | size_t count1; 20 | size_t match2; 21 | size_t count2; 22 | size_t match3; 23 | size_t count3; 24 | size_t match4; 25 | size_t count4; 26 | } bleu_stat; 27 | 28 | // left trim (remove pad) 29 | void bleu_ltrim(size_t* len, int** sent, int pad) { 30 | size_t start = 0; 31 | while(start < *len) { 32 | if (*(*sent + start) != pad) { break; } 33 | start++; 34 | } 35 | *sent += start; 36 | *len -= start; 37 | } 38 | 39 | // right trim remove (eos) 40 | void bleu_rtrim(size_t* len, int** sent, int pad, int eos) { 41 | size_t end = *len - 1; 42 | while (end > 0) { 43 | if (*(*sent + end) != eos && *(*sent + end) != pad) { break; } 44 | end--; 45 | } 46 | *len = end + 1; 47 | } 48 | 49 | // left and right trim 50 | void bleu_trim(size_t* len, int** sent, int pad, int eos) { 51 | bleu_ltrim(len, sent, pad); 52 | bleu_rtrim(len, sent, pad, eos); 53 | } 54 | 55 | size_t bleu_hash(int len, int* data) { 56 | size_t h = 14695981039346656037ul; 57 | size_t prime = 0x100000001b3; 58 | char* b = (char*) data; 59 | size_t blen = sizeof(int) * len; 60 | 61 | while (blen-- > 0) { 62 | h ^= *b++; 63 | h *= prime; 64 | } 65 | 66 | return h; 67 | } 68 | 69 | void bleu_addngram( 70 | size_t *ntotal, size_t *nmatch, size_t n, 71 | size_t reflen, int* ref, size_t predlen, int* pred) { 72 | 73 | if (predlen < n) { return; } 74 | 75 | predlen = predlen - n + 1; 76 | (*ntotal) += predlen; 77 | 78 | if (reflen < n) { return; } 79 | 80 | reflen = reflen - n + 1; 81 | 82 | std::map<size_t, size_t> count; 83 | while (predlen > 0) { 84 | size_t w = bleu_hash(n, pred++); 85 | count[w]++; 86 | predlen--; 87 | } 88 | 89 | while (reflen > 0) { 90 | size_t w = bleu_hash(n, ref++); 91 | if (count[w] > 0) { 92 | (*nmatch)++; 93 | count[w] -=1; 94 | } 95 | reflen--; 96 | } 97 | } 98 | 99 | extern "C" { 100 | 101 | void bleu_zero_init(bleu_stat* stat) { 102 | std::memset(stat, 0, sizeof(bleu_stat)); 103 | } 104 | 105 | void bleu_one_init(bleu_stat* stat) { 106 | bleu_zero_init(stat); 107 | stat->count1 = 0; 108 | stat->count2 = 1; 109 | stat->count3 = 1; 110 | stat->count4 = 1; 111 | stat->match1 = 0; 112 | stat->match2 = 1; 113 | stat->match3 = 1; 114 | stat->match4 = 1; 115 | } 116 | 117 | void bleu_add( 118 | bleu_stat* stat, 119 | size_t reflen, int* ref, size_t predlen, int* pred, int pad, int eos) { 120 | 121 | bleu_trim(&reflen, &ref, pad, eos); 122 | bleu_trim(&predlen, &pred, pad, eos); 123 | stat->reflen += reflen; 124 | stat->predlen += predlen; 125 | 126 | bleu_addngram(&stat->count1, &stat->match1, 1, reflen, ref, predlen, pred); 127 | bleu_addngram(&stat->count2, &stat->match2, 2, reflen, ref, predlen, pred); 128 | bleu_addngram(&stat->count3, &stat->match3, 3, reflen, ref, predlen, pred); 129 | bleu_addngram(&stat->count4, &stat->match4, 4, reflen, ref, predlen, pred); 130 | } 131 | 132 | } 133 | -------------------------------------------------------------------------------- /fairseq/clib/libbleu/module.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2017-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree.
7 | */ 8 | 9 | #include <Python.h> 10 | 11 | 12 | static PyMethodDef method_def[] = { 13 | {NULL, NULL, 0, NULL} 14 | }; 15 | 16 | static struct PyModuleDef module_def = { 17 | PyModuleDef_HEAD_INIT, 18 | "libbleu", /* name of module */ 19 | NULL, /* module documentation, may be NULL */ 20 | -1, /* size of per-interpreter state of the module, 21 | or -1 if the module keeps state in global variables. */ 22 | method_def 23 | }; 24 | 25 | 26 | #if PY_MAJOR_VERSION == 2 27 | PyMODINIT_FUNC init_libbleu() 28 | #else 29 | PyMODINIT_FUNC PyInit_libbleu() 30 | #endif 31 | { 32 | PyObject *m = PyModule_Create(&module_def); 33 | if (!m) { 34 | return NULL; 35 | } 36 | return m; 37 | } 38 | -------------------------------------------------------------------------------- /fairseq/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import importlib 9 | import os 10 | 11 | from .fairseq_criterion import FairseqCriterion 12 | 13 | 14 | CRITERION_REGISTRY = {} 15 | CRITERION_CLASS_NAMES = set() 16 | 17 | 18 | def build_criterion(args, task): 19 | return CRITERION_REGISTRY[args.criterion](args, task) 20 | 21 | 22 | def register_criterion(name): 23 | """Decorator to register a new criterion.""" 24 | 25 | def register_criterion_cls(cls): 26 | if name in CRITERION_REGISTRY: 27 | raise ValueError('Cannot register duplicate criterion ({})'.format(name)) 28 | if not issubclass(cls, FairseqCriterion): 29 | raise ValueError('Criterion ({}: {}) must extend FairseqCriterion'.format(name, cls.__name__)) 30 | if cls.__name__ in CRITERION_CLASS_NAMES: 31 | # We use the criterion class name as a unique identifier in 32 | # checkpoints, so all criterions must have unique class names. 33 | raise ValueError('Cannot register criterion with duplicate class name ({})'.format(cls.__name__)) 34 | CRITERION_REGISTRY[name] = cls 35 | CRITERION_CLASS_NAMES.add(cls.__name__) 36 | return cls 37 | 38 | return register_criterion_cls 39 | 40 | 41 | # automatically import any Python files in the criterions/ directory 42 | for file in os.listdir(os.path.dirname(__file__)): 43 | if file.endswith('.py') and not file.startswith('_'): 44 | module = file[:file.find('.py')] 45 | importlib.import_module('fairseq.criterions.' + module) 46 | -------------------------------------------------------------------------------- /fairseq/criterions/adaptive_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | 9 | import math 10 | import torch.nn.functional as F 11 | 12 | from fairseq import utils 13 | from .
import FairseqCriterion, register_criterion 14 | 15 | 16 | @register_criterion('adaptive_loss') 17 | class AdaptiveLoss(FairseqCriterion): 18 | """This is an implementation of the loss function accompanying the adaptive softmax approximation for 19 | graphical processing units (GPU), described in the paper "Efficient softmax approximation for GPUs" 20 | (http://arxiv.org/abs/1609.04309).""" 21 | 22 | def __init__(self, args, task): 23 | super().__init__(args, task) 24 | 25 | def forward(self, model, sample, reduce=True): 26 | """Compute the loss for the given sample. 27 | 28 | Returns a tuple with three elements: 29 | 1) the loss 30 | 2) the sample size, which is used as the denominator for the gradient 31 | 3) logging outputs to display while training 32 | """ 33 | 34 | assert hasattr(model.decoder, 'adaptive_softmax') and model.decoder.adaptive_softmax is not None 35 | adaptive_softmax = model.decoder.adaptive_softmax 36 | 37 | net_output = model(**sample['net_input']) 38 | target = model.get_targets(sample, net_output).view(-1) 39 | 40 | bsz = target.size(0) 41 | 42 | logits, target = adaptive_softmax(net_output[0], target) 43 | assert len(target) == len(logits) 44 | 45 | loss = net_output[0].new(1 if reduce else bsz).zero_() 46 | 47 | for i in range(len(target)): 48 | if target[i] is not None: 49 | assert (target[i].min() >= 0 and target[i].max() <= logits[i].size(1)) 50 | loss += F.cross_entropy(logits[i], target[i], size_average=False, ignore_index=self.padding_idx, 51 | reduce=reduce) 52 | 53 | sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] 54 | logging_output = { 55 | 'loss': utils.item(loss.data) if reduce else loss.data, 56 | 'ntokens': sample['ntokens'], 57 | 'sample_size': sample_size, 58 | } 59 | return loss, sample_size, logging_output 60 | 61 | @staticmethod 62 | def aggregate_logging_outputs(logging_outputs): 63 | """Aggregate logging outputs from data parallel training.""" 64 | loss_sum = sum(log.get('loss', 0) for log in logging_outputs) 65 | ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) 66 | sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) 67 | agg_output = { 68 | 'loss': loss_sum / sample_size / math.log(2), 69 | 'sample_size': sample_size, 70 | } 71 | if sample_size != ntokens: 72 | agg_output['nll_loss'] = loss_sum / ntokens / math.log(2) 73 | return agg_output 74 | -------------------------------------------------------------------------------- /fairseq/criterions/cross_entropy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import math 9 | import torch.nn.functional as F 10 | 11 | from fairseq import utils 12 | 13 | from . import FairseqCriterion, register_criterion 14 | 15 | 16 | @register_criterion('cross_entropy') 17 | class CrossEntropyCriterion(FairseqCriterion): 18 | 19 | def __init__(self, args, task): 20 | super().__init__(args, task) 21 | 22 | def forward(self, model, sample, reduce=True): 23 | """Compute the loss for the given sample. 
24 | 25 | Returns a tuple with three elements: 26 | 1) the loss 27 | 2) the sample size, which is used as the denominator for the gradient 28 | 3) logging outputs to display while training 29 | """ 30 | net_output = model(**sample['net_input']) 31 | lprobs = model.get_normalized_probs(net_output, log_probs=True) 32 | lprobs = lprobs.view(-1, lprobs.size(-1)) 33 | target = model.get_targets(sample, net_output).view(-1) 34 | loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx, 35 | reduce=reduce) 36 | 37 | re_lprobs = model.get_re_normalized_probs(net_output, log_probs=True) 38 | re_lprobs = re_lprobs.view(-1, re_lprobs.size(-1)) 39 | target = model.get_targets(sample, net_output).view(-1, 1) 40 | non_pad_mask = target.ne(self.padding_idx) 41 | re_nll_loss = -re_lprobs.gather(dim=-1, index=target)[non_pad_mask] 42 | re_nll_loss = re_nll_loss.sum() 43 | 44 | loss = loss+0.15*re_nll_loss 45 | 46 | 47 | 48 | sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] 49 | logging_output = { 50 | 'loss': utils.item(loss.data) if reduce else loss.data, 51 | 'ntokens': sample['ntokens'], 52 | 'sample_size': sample_size, 53 | } 54 | 55 | 56 | 57 | return loss, sample_size, logging_output 58 | 59 | @staticmethod 60 | def aggregate_logging_outputs(logging_outputs): 61 | """Aggregate logging outputs from data parallel training.""" 62 | loss_sum = sum(log.get('loss', 0) for log in logging_outputs) 63 | ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) 64 | sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) 65 | agg_output = { 66 | 'loss': loss_sum / sample_size / math.log(2), 67 | 'sample_size': sample_size, 68 | } 69 | if sample_size != ntokens: 70 | agg_output['nll_loss'] = loss_sum / ntokens / math.log(2) 71 | return agg_output 72 | -------------------------------------------------------------------------------- /fairseq/criterions/fairseq_criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from torch.nn.modules.loss import _Loss 9 | 10 | 11 | class FairseqCriterion(_Loss): 12 | 13 | def __init__(self, args, task): 14 | super().__init__() 15 | self.args = args 16 | self.padding_idx = task.target_dictionary.pad() 17 | 18 | @staticmethod 19 | def add_args(parser): 20 | """Add criterion-specific arguments to the parser.""" 21 | pass 22 | 23 | def forward(self, model, sample, reduce=True): 24 | """Compute the loss for the given sample. 
25 | 26 | Returns a tuple with three elements: 27 | 1) the loss 28 | 2) the sample size, which is used as the denominator for the gradient 29 | 3) logging outputs to display while training 30 | """ 31 | raise NotImplementedError 32 | 33 | @staticmethod 34 | def aggregate_logging_outputs(logging_outputs): 35 | """Aggregate logging outputs from data parallel training.""" 36 | raise NotImplementedError 37 | 38 | @staticmethod 39 | def grad_denom(sample_sizes): 40 | """Compute the gradient denominator for a set of sample sizes.""" 41 | return sum(sample_sizes) 42 | -------------------------------------------------------------------------------- /fairseq/criterions/label_smoothed_cross_entropy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import math 9 | import torch 10 | from fairseq import utils 11 | 12 | from . import FairseqCriterion, register_criterion 13 | 14 | 15 | @register_criterion('label_smoothed_cross_entropy') 16 | class LabelSmoothedCrossEntropyCriterion(FairseqCriterion): 17 | 18 | def __init__(self, args, task): 19 | super().__init__(args, task) 20 | self.eps = args.label_smoothing 21 | 22 | @staticmethod 23 | def add_args(parser): 24 | """Add criterion-specific arguments to the parser.""" 25 | parser.add_argument('--label-smoothing', default=0., type=float, metavar='D', 26 | help='epsilon for label smoothing, 0 means no label smoothing') 27 | 28 | def forward(self, model, sample, reduce=True): 29 | """Compute the loss for the given sample. 30 | 31 | Returns a tuple with three elements: 32 | 1) the loss 33 | 2) the sample size, which is used as the denominator for the gradient 34 | 3) logging outputs to display while training 35 | """ 36 | net_output = model(**sample['net_input']) 37 | loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce) 38 | sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] 39 | logging_output = { 40 | 'loss': utils.item(loss.data) if reduce else loss.data, 41 | 'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data, 42 | 'ntokens': sample['ntokens'], 43 | 'sample_size': sample_size, 44 | } 45 | return loss, sample_size, logging_output 46 | 47 | def compute_loss(self, model, net_output, sample, reduce=True): 48 | lprobs = model.get_normalized_probs(net_output, log_probs=True) 49 | 50 | lprobs = lprobs.view(-1, lprobs.size(-1)) 51 | target = model.get_targets(sample, net_output).view(-1, 1) 52 | non_pad_mask = target.ne(self.padding_idx) 53 | nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask] 54 | smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask] 55 | if reduce: 56 | nll_loss = nll_loss.sum() 57 | smooth_loss = smooth_loss.sum() 58 | eps_i = self.eps / lprobs.size(-1) 59 | loss = (1. 
- self.eps) * nll_loss + eps_i * smooth_loss 60 | 61 | re_lprobs = model.get_re_normalized_probs(net_output, log_probs=True) 62 | re_lprobs = re_lprobs.view(-1, re_lprobs.size(-1)) 63 | re_nll_loss = -re_lprobs.gather(dim=-1, index=target)[non_pad_mask] 64 | re_smooth_loss = -re_lprobs.sum(dim=-1, keepdim=True)[non_pad_mask] 65 | 66 | 67 | if reduce: 68 | re_nll_loss = re_nll_loss.sum() 69 | re_smooth_loss = re_smooth_loss.sum() 70 | eps_i = self.eps / re_lprobs.size(-1) 71 | re_loss = (1. - self.eps) * re_nll_loss + eps_i * re_smooth_loss 72 | 73 | loss = loss+ 0.2*re_loss 74 | nll_loss = nll_loss+0.2*re_nll_loss 75 | 76 | return loss, nll_loss 77 | 78 | 79 | @staticmethod 80 | def aggregate_logging_outputs(logging_outputs): 81 | """Aggregate logging outputs from data parallel training.""" 82 | ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) 83 | sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) 84 | return { 85 | 'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2), 86 | 'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / ntokens / math.log(2), 87 | 'sample_size': sample_size, 88 | } 89 | -------------------------------------------------------------------------------- /fairseq/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from .dictionary import Dictionary 9 | from .fairseq_dataset import FairseqDataset 10 | from .indexed_dataset import IndexedDataset, IndexedInMemoryDataset, IndexedRawTextDataset 11 | from .language_pair_dataset import LanguagePairDataset 12 | from .monolingual_dataset import MonolingualDataset 13 | from .token_block_dataset import TokenBlockDataset 14 | 15 | from .iterators import CountingIterator, EpochBatchIterator, ShardedIterator 16 | 17 | __all__ = [ 18 | 'CountingIterator', 19 | 'Dictionary', 20 | 'EpochBatchIterator', 21 | 'FairseqDataset', 22 | 'IndexedDataset', 23 | 'IndexedInMemoryDataset', 24 | 'IndexedRawTextDataset', 25 | 'LanguagePairDataset', 26 | 'MonolingualDataset', 27 | 'TokenBlockDataset', 28 | 'ShardedIterator', 29 | ] 30 | -------------------------------------------------------------------------------- /fairseq/data/fairseq_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.utils.data 9 | 10 | from fairseq.data import data_utils 11 | 12 | 13 | class FairseqDataset(torch.utils.data.Dataset): 14 | """A dataset that provides helpers for batching.""" 15 | 16 | def __getitem__(self, index): 17 | raise NotImplementedError 18 | 19 | def __len__(self): 20 | raise NotImplementedError 21 | 22 | def collater(self, samples): 23 | """Merge a list of samples to form a mini-batch. 
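        Implementations receive the items produced by :func:`__getitem__` and
        must pad and stack them into model-ready tensors; see
        :class:`MonolingualDataset` further below for a complete example. A
        minimal sketch built on :func:`data_utils.collate_tokens` (assuming each
        sample is a dict holding an ``'id'`` and a 1D ``'source'`` token tensor,
        and that the dataset keeps its vocabulary in ``self.vocab``)::

            def collater(self, samples):
                return {
                    'id': torch.LongTensor([s['id'] for s in samples]),
                    'net_input': {
                        'src_tokens': data_utils.collate_tokens(
                            [s['source'] for s in samples],
                            self.vocab.pad(), self.vocab.eos(), left_pad=False,
                        ),
                    },
                }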
24 | 25 | Args: 26 | samples (List[int]): sample indices to collate 27 | 28 | Returns: 29 | dict: a mini-batch suitable for forwarding with a Model 30 | """ 31 | raise NotImplementedError 32 | 33 | def get_dummy_batch(self, num_tokens, max_positions): 34 | """Return a dummy batch with a given number of tokens.""" 35 | raise NotImplementedError 36 | 37 | def num_tokens(self, index): 38 | """Return the number of tokens in a sample. This value is used to 39 | enforce ``--max-tokens`` during batching.""" 40 | raise NotImplementedError 41 | 42 | def size(self, index): 43 | """Return an example's size as a float or tuple. This value is used when 44 | filtering a dataset with ``--max-positions``.""" 45 | raise NotImplementedError 46 | 47 | def ordered_indices(self): 48 | """Return an ordered list of indices. Batches will be constructed based 49 | on this order.""" 50 | raise NotImplementedError 51 | -------------------------------------------------------------------------------- /fairseq/data/monolingual_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from . import data_utils, FairseqDataset 12 | 13 | 14 | def collate(samples, pad_idx, eos_idx): 15 | if len(samples) == 0: 16 | return {} 17 | 18 | def merge(key): 19 | return data_utils.collate_tokens( 20 | [s[key] for s in samples], pad_idx, eos_idx, left_pad=False, 21 | ) 22 | 23 | return { 24 | 'id': torch.LongTensor([s['id'] for s in samples]), 25 | 'ntokens': sum(len(s['target']) for s in samples), 26 | 'net_input': { 27 | 'src_tokens': merge('source'), 28 | 'src_lengths': torch.LongTensor([ 29 | s['source'].numel() for s in samples 30 | ]), 31 | }, 32 | 'target': merge('target'), 33 | } 34 | 35 | 36 | class MonolingualDataset(FairseqDataset): 37 | """ 38 | A wrapper around torch.utils.data.Dataset for monolingual data. 39 | 40 | Args: 41 | dataset (torch.utils.data.Dataset): dataset to wrap 42 | sizes (List[int]): sentence lengths 43 | vocab (~fairseq.data.Dictionary): vocabulary 44 | shuffle (bool, optional): shuffle the elements before batching. 45 | Default: ``True`` 46 | """ 47 | 48 | def __init__(self, dataset, sizes, vocab, shuffle=True): 49 | self.dataset = dataset 50 | self.sizes = np.array(sizes) 51 | self.vocab = vocab 52 | self.shuffle = shuffle 53 | 54 | def __getitem__(self, index): 55 | source, target = self.dataset[index] 56 | return {'id': index, 'source': source, 'target': target} 57 | 58 | def __len__(self): 59 | return len(self.dataset) 60 | 61 | def collater(self, samples): 62 | """Merge a list of samples to form a mini-batch. 63 | 64 | Args: 65 | samples (List[dict]): samples to collate 66 | 67 | Returns: 68 | dict: a mini-batch with the following keys: 69 | 70 | - `id` (LongTensor): example IDs in the original input order 71 | - `ntokens` (int): total number of tokens in the batch 72 | - `net_input` (dict): the input to the Model, containing keys: 73 | 74 | - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in 75 | the source sentence of shape `(bsz, src_len)`. Padding will 76 | appear on the right. 
77 | 78 | - `target` (LongTensor): a padded 2D Tensor of tokens in the 79 | target sentence of shape `(bsz, tgt_len)`. Padding will appear 80 | on the right. 81 | """ 82 | return collate(samples, self.vocab.pad(), self.vocab.eos()) 83 | 84 | def get_dummy_batch(self, num_tokens, max_positions, tgt_len=128): 85 | """Return a dummy batch with a given number of tokens.""" 86 | if isinstance(max_positions, float) or isinstance(max_positions, int): 87 | tgt_len = min(tgt_len, max_positions) 88 | bsz = num_tokens // tgt_len 89 | target = self.vocab.dummy_sentence(tgt_len + 1) 90 | source, target = target[:-1], target[1:] 91 | return self.collater([ 92 | {'id': i, 'source': source, 'target': target} 93 | for i in range(bsz) 94 | ]) 95 | 96 | def num_tokens(self, index): 97 | """Return the number of tokens in a sample. This value is used to 98 | enforce ``--max-tokens`` during batching.""" 99 | return self.sizes[index] 100 | 101 | def size(self, index): 102 | """Return an example's size as a float or tuple. This value is used when 103 | filtering a dataset with ``--max-positions``.""" 104 | return self.sizes[index] 105 | 106 | def ordered_indices(self): 107 | """Return an ordered list of indices. Batches will be constructed based 108 | on this order.""" 109 | if self.shuffle: 110 | order = [np.random.permutation(len(self))] 111 | else: 112 | order = [np.arange(len(self))] 113 | order.append(np.flip(self.sizes, 0)) 114 | return np.lexsort(order) 115 | -------------------------------------------------------------------------------- /fairseq/data/token_block_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import math 9 | 10 | import numpy as np 11 | import torch 12 | 13 | 14 | class TokenBlockDataset(torch.utils.data.Dataset): 15 | """Break a 1d tensor of tokens into blocks. 16 | 17 | The blocks are fetched from the original tensor so no additional memory is allocated. 18 | 19 | Args: 20 | tokens: 1d tensor of tokens to break into blocks 21 | sizes: sentence lengths (required for 'complete' and 'eos') 22 | block_size: maximum block size (ignored in 'eos' break mode) 23 | break_mode: Mode used for breaking tokens. 
Values can be one of: 24 | - 'none': break tokens into equally sized blocks (up to block_size) 25 | - 'complete': break tokens into blocks (up to block_size) such that 26 | blocks contains complete sentences, although block_size may be 27 | exceeded if some sentences exceed block_size 28 | - 'eos': each block contains one sentence (block_size is ignored) 29 | include_targets: return next tokens as targets 30 | """ 31 | 32 | def __init__(self, tokens, sizes, block_size, break_mode=None, include_targets=False): 33 | super().__init__() 34 | 35 | self.tokens = tokens 36 | self.total_size = len(tokens) 37 | self.include_targets = include_targets 38 | self.slice_indices = [] 39 | 40 | if break_mode is None or break_mode == 'none': 41 | length = math.ceil(len(tokens) / block_size) 42 | 43 | def block_at(i): 44 | start = i * block_size 45 | end = min(start + block_size, len(tokens)) 46 | return (start, end) 47 | 48 | self.slice_indices = [block_at(i) for i in range(length)] 49 | elif break_mode == 'complete': 50 | assert sizes is not None and sum(sizes) == len(tokens), '{} != {}'.format(sum(sizes), len(tokens)) 51 | tok_idx = 0 52 | sz_idx = 0 53 | curr_size = 0 54 | while sz_idx < len(sizes): 55 | if curr_size + sizes[sz_idx] <= block_size or curr_size == 0: 56 | curr_size += sizes[sz_idx] 57 | sz_idx += 1 58 | else: 59 | self.slice_indices.append((tok_idx, tok_idx + curr_size)) 60 | tok_idx += curr_size 61 | curr_size = 0 62 | if curr_size > 0: 63 | self.slice_indices.append((tok_idx, tok_idx + curr_size)) 64 | elif break_mode == 'eos': 65 | assert sizes is not None and sum(sizes) == len(tokens), '{} != {}'.format(sum(sizes), len(tokens)) 66 | curr = 0 67 | for sz in sizes: 68 | # skip samples with just 1 example (which would be just the eos token) 69 | if sz > 1: 70 | self.slice_indices.append((curr, curr + sz)) 71 | curr += sz 72 | else: 73 | raise ValueError('Invalid break_mode: ' + break_mode) 74 | 75 | self.sizes = np.array([e - s for s, e in self.slice_indices]) 76 | 77 | def __getitem__(self, index): 78 | s, e = self.slice_indices[index] 79 | 80 | item = torch.LongTensor(self.tokens[s:e]) 81 | 82 | if self.include_targets: 83 | # target is the sentence, for source, rotate item one token to the left (would start with eos) 84 | if s == 0: 85 | source = np.concatenate([self.tokens[-1:], self.tokens[0:e - 1]]) 86 | else: 87 | source = self.tokens[s - 1:e - 1] 88 | 89 | return torch.LongTensor(source), item 90 | return item 91 | 92 | def __len__(self): 93 | return len(self.slice_indices) 94 | -------------------------------------------------------------------------------- /fairseq/distributed_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
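# Helpers for distributed training: is_master() checks for rank 0,
# distributed_init() sets up the torch.distributed process group,
# suppress_output() silences print() on non-master ranks, and all_gather_list()
# pickles an arbitrary Python object into a fixed-size byte buffer (with a
# 2-byte length header) and gathers one such object from every worker.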
7 | 8 | import pickle 9 | 10 | import torch.distributed 11 | 12 | from fairseq import utils 13 | 14 | 15 | def is_master(args): 16 | return args.distributed_rank == 0 17 | 18 | 19 | def distributed_init(args): 20 | if args.distributed_world_size == 1: 21 | raise ValueError('Cannot initialize distributed with distributed_world_size=1') 22 | 23 | print('| distributed init (rank {}): {}'.format( 24 | args.distributed_rank, args.distributed_init_method), flush=True) 25 | if args.distributed_init_method.startswith('tcp://'): 26 | torch.distributed.init_process_group( 27 | backend=args.distributed_backend, init_method=args.distributed_init_method, 28 | world_size=args.distributed_world_size, rank=args.distributed_rank) 29 | else: 30 | torch.distributed.init_process_group( 31 | backend=args.distributed_backend, init_method=args.distributed_init_method, 32 | world_size=args.distributed_world_size) 33 | 34 | args.distributed_rank = torch.distributed.get_rank() 35 | if not is_master(args): 36 | suppress_output() 37 | 38 | return args.distributed_rank 39 | 40 | 41 | def suppress_output(): 42 | """Suppress printing on the current device. Force printing with `force=True`.""" 43 | import builtins as __builtin__ 44 | builtin_print = __builtin__.print 45 | 46 | def print(*args, **kwargs): 47 | if 'force' in kwargs: 48 | force = kwargs.pop('force') 49 | if force: 50 | builtin_print(*args, **kwargs) 51 | 52 | __builtin__.print = print 53 | 54 | 55 | def all_gather_list(data, max_size=16384): 56 | """Gathers arbitrary data from all nodes into a list.""" 57 | world_size = torch.distributed.get_world_size() 58 | if not hasattr(all_gather_list, '_in_buffer') or \ 59 | max_size != all_gather_list._in_buffer.size(): 60 | all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size) 61 | all_gather_list._out_buffers = [ 62 | torch.cuda.ByteTensor(max_size) 63 | for i in range(world_size) 64 | ] 65 | in_buffer = all_gather_list._in_buffer 66 | out_buffers = all_gather_list._out_buffers 67 | 68 | enc = pickle.dumps(data) 69 | enc_size = len(enc) 70 | if enc_size + 2 > max_size: 71 | raise ValueError('encoded data exceeds max_size: {}'.format(enc_size + 2)) 72 | assert max_size < 255*256 73 | in_buffer[0] = enc_size // 255 # this encoding works for max_size < 65k 74 | in_buffer[1] = enc_size % 255 75 | in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc)) 76 | 77 | torch.distributed.all_gather(out_buffers, in_buffer.cuda()) 78 | 79 | result = [] 80 | for i in range(world_size): 81 | out_buffer = out_buffers[i] 82 | size = (255 * utils.item(out_buffer[0])) + utils.item(out_buffer[1]) 83 | result.append( 84 | pickle.loads(bytes(out_buffer[2:size+2].tolist())) 85 | ) 86 | return result 87 | -------------------------------------------------------------------------------- /fairseq/meters.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
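# Lightweight logging meters: AverageMeter keeps a running average of a value,
# TimeMeter measures the rate of some event per second, and StopwatchMeter
# accumulates the total (and average) duration of repeated events.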
7 | 8 | import time 9 | 10 | 11 | class AverageMeter(object): 12 | """Computes and stores the average and current value""" 13 | def __init__(self): 14 | self.reset() 15 | 16 | def reset(self): 17 | self.val = 0 18 | self.avg = 0 19 | self.sum = 0 20 | self.count = 0 21 | 22 | def update(self, val, n=1): 23 | self.val = val 24 | self.sum += val * n 25 | self.count += n 26 | self.avg = self.sum / self.count 27 | 28 | 29 | class TimeMeter(object): 30 | """Computes the average occurrence of some event per second""" 31 | def __init__(self, init=0): 32 | self.reset(init) 33 | 34 | def reset(self, init=0): 35 | self.init = init 36 | self.start = time.time() 37 | self.n = 0 38 | 39 | def update(self, val=1): 40 | self.n += val 41 | 42 | @property 43 | def avg(self): 44 | return self.n / self.elapsed_time 45 | 46 | @property 47 | def elapsed_time(self): 48 | return self.init + (time.time() - self.start) 49 | 50 | 51 | class StopwatchMeter(object): 52 | """Computes the sum/avg duration of some event in seconds""" 53 | def __init__(self): 54 | self.reset() 55 | 56 | def start(self): 57 | self.start_time = time.time() 58 | 59 | def stop(self, n=1): 60 | if self.start_time is not None: 61 | delta = time.time() - self.start_time 62 | self.sum += delta 63 | self.n += n 64 | self.start_time = None 65 | 66 | def reset(self): 67 | self.sum = 0 68 | self.n = 0 69 | self.start_time = None 70 | 71 | @property 72 | def avg(self): 73 | return self.sum / self.n 74 | -------------------------------------------------------------------------------- /fairseq/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import argparse 9 | import importlib 10 | import os 11 | 12 | from .fairseq_decoder import FairseqDecoder # noqa: F401 13 | from .fairseq_encoder import FairseqEncoder # noqa: F401 14 | from .fairseq_incremental_decoder import FairseqIncrementalDecoder # noqa: F401 15 | from .fairseq_model import BaseFairseqModel, FairseqModel, FairseqLanguageModel # noqa: F401 16 | 17 | from .composite_encoder import CompositeEncoder # noqa: F401 18 | 19 | 20 | MODEL_REGISTRY = {} 21 | ARCH_MODEL_REGISTRY = {} 22 | ARCH_MODEL_INV_REGISTRY = {} 23 | ARCH_CONFIG_REGISTRY = {} 24 | 25 | 26 | def build_model(args, task): 27 | return ARCH_MODEL_REGISTRY[args.arch].build_model(args, task) 28 | 29 | 30 | def register_model(name): 31 | """ 32 | New model types can be added to fairseq with the :func:`register_model` 33 | function decorator. 34 | 35 | For example:: 36 | 37 | @register_model('lstm') 38 | class LSTM(FairseqModel): 39 | (...) 40 | 41 | .. note:: All models must implement the :class:`BaseFairseqModel` interface. 42 | Typically you will extend :class:`FairseqModel` for sequence-to-sequence 43 | tasks or :class:`FairseqLanguageModel` for language modeling tasks. 
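        Registering the class alone does not make it selectable on the command
        line: at least one architecture must also be registered for it with
        :func:`register_model_architecture` (see below), since
        :func:`build_model` resolves ``--arch`` through ``ARCH_MODEL_REGISTRY``.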
44 | 45 | Args: 46 | name (str): the name of the model 47 | """ 48 | 49 | def register_model_cls(cls): 50 | if name in MODEL_REGISTRY: 51 | raise ValueError('Cannot register duplicate model ({})'.format(name)) 52 | if not issubclass(cls, BaseFairseqModel): 53 | raise ValueError('Model ({}: {}) must extend BaseFairseqModel'.format(name, cls.__name__)) 54 | MODEL_REGISTRY[name] = cls 55 | return cls 56 | 57 | return register_model_cls 58 | 59 | 60 | def register_model_architecture(model_name, arch_name): 61 | """ 62 | New model architectures can be added to fairseq with the 63 | :func:`register_model_architecture` function decorator. After registration, 64 | model architectures can be selected with the ``--arch`` command-line 65 | argument. 66 | 67 | For example:: 68 | 69 | @register_model_architecture('lstm', 'lstm_luong_wmt_en_de') 70 | def lstm_luong_wmt_en_de(args): 71 | args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1000) 72 | (...) 73 | 74 | The decorated function should take a single argument *args*, which is a 75 | :class:`argparse.Namespace` of arguments parsed from the command-line. The 76 | decorated function should modify these arguments in-place to match the 77 | desired architecture. 78 | 79 | Args: 80 | model_name (str): the name of the Model (Model must already be 81 | registered) 82 | arch_name (str): the name of the model architecture (``--arch``) 83 | """ 84 | 85 | def register_model_arch_fn(fn): 86 | if model_name not in MODEL_REGISTRY: 87 | raise ValueError('Cannot register model architecture for unknown model type ({})'.format(model_name)) 88 | if arch_name in ARCH_MODEL_REGISTRY: 89 | raise ValueError('Cannot register duplicate model architecture ({})'.format(arch_name)) 90 | if not callable(fn): 91 | raise ValueError('Model architecture must be callable ({})'.format(arch_name)) 92 | ARCH_MODEL_REGISTRY[arch_name] = MODEL_REGISTRY[model_name] 93 | ARCH_MODEL_INV_REGISTRY.setdefault(model_name, []).append(arch_name) 94 | ARCH_CONFIG_REGISTRY[arch_name] = fn 95 | return fn 96 | 97 | return register_model_arch_fn 98 | 99 | 100 | # automatically import any Python files in the models/ directory 101 | for file in os.listdir(os.path.dirname(__file__)): 102 | if file.endswith('.py') and not file.startswith('_'): 103 | model_name = file[:file.find('.py')] 104 | module = importlib.import_module('fairseq.models.' + model_name) 105 | 106 | # extra `model_parser` for sphinx 107 | if model_name in MODEL_REGISTRY: 108 | parser = argparse.ArgumentParser(add_help=False) 109 | group_archs = parser.add_argument_group('Named architectures') 110 | group_archs.add_argument('--arch', choices=ARCH_MODEL_INV_REGISTRY[model_name]) 111 | group_args = parser.add_argument_group('Additional command-line arguments') 112 | MODEL_REGISTRY[model_name].add_args(group_args) 113 | globals()[model_name + '_parser'] = parser 114 | -------------------------------------------------------------------------------- /fairseq/models/composite_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from . 
import FairseqEncoder 9 | 10 | 11 | class CompositeEncoder(FairseqEncoder): 12 | """ 13 | A wrapper around a dictionary of :class:`FairseqEncoder` objects. 14 | 15 | We run forward on each encoder and return a dictionary of outputs. The first 16 | encoder's dictionary is used for initialization. 17 | 18 | Args: 19 | encoders (dict): a dictionary of :class:`FairseqEncoder` objects. 20 | """ 21 | 22 | def __init__(self, encoders): 23 | super().__init__(next(iter(encoders.values())).dictionary) 24 | self.encoders = encoders 25 | for key in self.encoders: 26 | self.add_module(key, self.encoders[key]) 27 | 28 | def forward(self, src_tokens, src_lengths): 29 | """ 30 | Args: 31 | src_tokens (LongTensor): tokens in the source language of shape 32 | `(batch, src_len)` 33 | src_lengths (LongTensor): lengths of each source sentence of shape 34 | `(batch)` 35 | 36 | Returns: 37 | dict: 38 | the outputs from each Encoder 39 | """ 40 | encoder_out = {} 41 | for key in self.encoders: 42 | encoder_out[key] = self.encoders[key](src_tokens, src_lengths) 43 | return encoder_out 44 | 45 | def reorder_encoder_out(self, encoder_out, new_order): 46 | """Reorder encoder output according to new_order.""" 47 | for key in self.encoders: 48 | encoder_out[key] = self.encoders[key].reorder_encoder_out(encoder_out[key], new_order) 49 | return encoder_out 50 | 51 | def max_positions(self): 52 | return min([self.encoders[key].max_positions() for key in self.encoders]) 53 | 54 | def upgrade_state_dict(self, state_dict): 55 | for key in self.encoders: 56 | self.encoders[key].upgrade_state_dict(state_dict) 57 | return state_dict 58 | -------------------------------------------------------------------------------- /fairseq/models/fairseq_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
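# Base decoder interface. In this copy of fairseq it additionally exposes
# get_re_normalized_probs(), which renormalizes the model's last output
# (net_output[-1]) with softmin instead of softmax; the criterions above use it
# to add an auxiliary loss term.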
7 |
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch
11 |
12 | class FairseqDecoder(nn.Module):
13 |     """Base class for decoders."""
14 |
15 |     def __init__(self, dictionary):
16 |         super().__init__()
17 |         self.dictionary = dictionary
18 |
19 |     def forward(self, prev_output_tokens, encoder_out):
20 |         """
21 |         Args:
22 |             prev_output_tokens (LongTensor): previous decoder outputs of shape
23 |                 `(batch, tgt_len)`, for input feeding/teacher forcing
24 |             encoder_out (Tensor, optional): output from the encoder, used for
25 |                 encoder-side attention
26 |
27 |         Returns:
28 |             tuple:
29 |                 - the last decoder layer's output of shape
30 |                   `(batch, tgt_len, vocab)`
31 |                 - the last decoder layer's attention weights of shape
32 |                   `(batch, tgt_len, src_len)`
33 |         """
34 |         raise NotImplementedError
35 |
36 |     def get_normalized_probs(self, net_output, log_probs, sample):
37 |         """Get normalized probabilities (or log probs) from a net's output."""
38 |
39 |         if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None:
40 |             assert sample is not None and 'target' in sample
41 |             out = self.adaptive_softmax.get_log_prob(net_output[0], sample['target'])
42 |             return out.exp_() if not log_probs else out
43 |
44 |         logits = net_output[0].float()
45 |
46 |         if log_probs:
47 |             return F.log_softmax(logits, dim=-1)
48 |         else:
49 |             return F.softmax(logits, dim=-1)
50 |
51 |     def get_re_normalized_probs(self, net_output, log_probs, sample):
52 |         """Get softmin-renormalized probabilities (or log probs) from the net's last output."""
53 |
54 |         if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None:
55 |             assert sample is not None and 'target' in sample
56 |             out = self.adaptive_softmax.get_log_prob(net_output[-1], sample['target'])
57 |             return out.exp_() if not log_probs else out
58 |
59 |         logits = net_output[-1].float()
60 |         # softmin(x) is softmax(-x), i.e. the distribution is renormalized
61 |         # "inversely" relative to get_normalized_probs(); honor the log_probs
62 |         # flag so callers can also request plain probabilities
63 |         probs = F.softmin(logits, dim=-1)
64 |         return torch.log(probs) if log_probs else probs
65 |
66 |     def max_positions(self):
67 |         """Maximum input length supported by the decoder."""
68 |         return 1e6  # an arbitrary large number
69 |
70 |     def upgrade_state_dict(self, state_dict):
71 |         """Upgrade a (possibly old) state dict for new versions of fairseq."""
72 |         return state_dict
73 |
--------------------------------------------------------------------------------
/fairseq/models/fairseq_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the LICENSE file in
5 | # the root directory of this source tree. An additional grant of patent rights
6 | # can be found in the PATENTS file in the same directory.
7 | 8 | import torch.nn as nn 9 | 10 | 11 | class FairseqEncoder(nn.Module): 12 | """Base class for encoders.""" 13 | 14 | def __init__(self, dictionary): 15 | super().__init__() 16 | self.dictionary = dictionary 17 | 18 | def forward(self, src_tokens, src_lengths): 19 | """ 20 | Args: 21 | src_tokens (LongTensor): tokens in the source language of shape 22 | `(batch, src_len)` 23 | src_lengths (LongTensor): lengths of each source sentence of shape 24 | `(batch)` 25 | """ 26 | raise NotImplementedError 27 | 28 | def reorder_encoder_out(self, encoder_out, new_order): 29 | """ 30 | Reorder encoder output according to `new_order`. 31 | 32 | Args: 33 | encoder_out: output from the ``forward()`` method 34 | new_order (LongTensor): desired order 35 | 36 | Returns: 37 | `encoder_out` rearranged according to `new_order` 38 | """ 39 | raise NotImplementedError 40 | 41 | def max_positions(self): 42 | """Maximum input length supported by the encoder.""" 43 | return 1e6 # an arbitrary large number 44 | 45 | def upgrade_state_dict(self, state_dict): 46 | """Upgrade a (possibly old) state dict for new versions of fairseq.""" 47 | return state_dict 48 | -------------------------------------------------------------------------------- /fairseq/models/fairseq_incremental_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from . import FairseqDecoder 9 | 10 | 11 | class FairseqIncrementalDecoder(FairseqDecoder): 12 | """Base class for incremental decoders. 13 | 14 | Incremental decoding is a special mode at inference time where the Model 15 | only receives a single timestep of input corresponding to the immediately 16 | previous output token (for input feeding) and must produce the next output 17 | *incrementally*. Thus the model must cache any long-term state that is 18 | needed about the sequence, e.g., hidden states, convolutional states, etc. 19 | 20 | Compared to the standard :class:`FairseqDecoder` interface, the incremental 21 | decoder interface allows :func:`forward` functions to take an extra keyword 22 | argument (*incremental_state*) that can be used to cache state across 23 | time-steps. 24 | 25 | The :class:`FairseqIncrementalDecoder` interface also defines the 26 | :func:`reorder_incremental_state` method, which is used during beam search 27 | to select and reorder the incremental state based on the selection of beams. 
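    A rough sketch of how an implementation typically caches state between time
    steps (the ``'input_buffer'`` key and the helper building the initial state
    are illustrative only; see :class:`LinearizedConvolution` for a concrete
    example of this pattern)::

        buf = utils.get_incremental_state(self, incremental_state, 'input_buffer')
        if buf is None:
            buf = make_initial_buffer()  # hypothetical helper
        # ... update buf with the current time step ...
        utils.set_incremental_state(self, incremental_state, 'input_buffer', buf)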
28 | """ 29 | 30 | def __init__(self, dictionary): 31 | super().__init__(dictionary) 32 | 33 | def forward(self, prev_output_tokens, encoder_out,model_type=None, incremental_state=None): 34 | """ 35 | Args: 36 | prev_output_tokens (LongTensor): previous decoder outputs of shape 37 | `(batch, tgt_len)`, for input feeding/teacher forcing 38 | encoder_out (Tensor, optional): output from the encoder, used for 39 | encoder-side attention 40 | incremental_state (dict): dictionary used for storing state during 41 | :ref:`Incremental decoding` 42 | 43 | Returns: 44 | tuple: 45 | - the last decoder layer's output of shape `(batch, tgt_len, 46 | vocab)` 47 | - the last decoder layer's attention weights of shape `(batch, 48 | tgt_len, src_len)` 49 | """ 50 | raise NotImplementedError 51 | 52 | def reorder_incremental_state(self, incremental_state, new_order): 53 | """Reorder incremental state. 54 | 55 | This should be called when the order of the input has changed from the 56 | previous time step. A typical use case is beam search, where the input 57 | order changes between time steps based on the selection of beams. 58 | """ 59 | def apply_reorder_incremental_state(module): 60 | if module != self and hasattr(module, 'reorder_incremental_state'): 61 | module.reorder_incremental_state( 62 | incremental_state, 63 | new_order, 64 | ) 65 | self.apply(apply_reorder_incremental_state) 66 | 67 | def set_beam_size(self, beam_size): 68 | """Sets the beam size in the decoder and all children.""" 69 | if getattr(self, '_beam_size', -1) != beam_size: 70 | def apply_set_beam_size(module): 71 | if module != self and hasattr(module, 'set_beam_size'): 72 | module.set_beam_size(beam_size) 73 | self.apply(apply_set_beam_size) 74 | self._beam_size = beam_size 75 | -------------------------------------------------------------------------------- /fairseq/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from .adaptive_softmax import AdaptiveSoftmax 9 | from .beamable_mm import BeamableMM 10 | from .character_token_embedder import CharacterTokenEmbedder 11 | from .conv_tbc import ConvTBC 12 | from .downsampled_multihead_attention import DownsampledMultiHeadAttention 13 | from .grad_multiply import GradMultiply 14 | from .highway import Highway 15 | from .learned_positional_embedding import LearnedPositionalEmbedding 16 | from .linearized_convolution import LinearizedConvolution 17 | from .multihead_attention import MultiheadAttention 18 | from .scalar_bias import ScalarBias 19 | from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding 20 | 21 | __all__ = [ 22 | 'AdaptiveSoftmax', 23 | 'BeamableMM', 24 | 'CharacterTokenEmbedder', 25 | 'ConvTBC', 26 | 'DownsampledMultiHeadAttention', 27 | 'GradMultiply', 28 | 'Highway', 29 | 'LearnedPositionalEmbedding', 30 | 'LinearizedConvolution', 31 | 'MultiheadAttention', 32 | 'ScalarBias', 33 | 'SinusoidalPositionalEmbedding', 34 | ] 35 | -------------------------------------------------------------------------------- /fairseq/modules/adaptive_softmax.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 
2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import nn 12 | 13 | 14 | class AdaptiveSoftmax(nn.Module): 15 | """ 16 | This is an implementation of the efficient softmax approximation for 17 | graphical processing units (GPU), described in the paper "Efficient softmax 18 | approximation for GPUs" (http://arxiv.org/abs/1609.04309). 19 | """ 20 | 21 | def __init__(self, vocab_size, input_dim, cutoff, dropout): 22 | super().__init__() 23 | 24 | if vocab_size > cutoff[-1]: 25 | cutoff = cutoff + [vocab_size] 26 | else: 27 | assert vocab_size == cutoff[ 28 | -1], 'cannot specify cutoff larger than vocab size' 29 | 30 | output_dim = cutoff[0] + len(cutoff) - 1 31 | 32 | self.vocab_size = vocab_size 33 | self.cutoff = cutoff 34 | self.dropout = dropout 35 | self.input_dim = input_dim 36 | 37 | self.lsm = nn.LogSoftmax(dim=1) 38 | self.head = nn.Linear(input_dim, output_dim, bias=False) 39 | self._make_tail(True) 40 | 41 | def init_weights(m): 42 | if hasattr(m, 'weight'): 43 | nn.init.xavier_uniform_(m.weight) 44 | 45 | self.apply(init_weights) 46 | 47 | self.register_buffer('version', torch.LongTensor([1])) 48 | # versions prior to 1 had a bug that offset indices on the head by 1 49 | self.buggy_offset = 0 50 | 51 | def _make_tail(self, fix_exponent): 52 | extra_denom = 1 if fix_exponent else 0 53 | 54 | self.tail = nn.ModuleList() 55 | for i in range(len(self.cutoff) - 1): 56 | self.tail.append( 57 | nn.Sequential( 58 | nn.Linear(self.input_dim, self.input_dim // 4 ** (i + extra_denom), bias=False), 59 | nn.Dropout(self.dropout), 60 | nn.Linear(self.input_dim // 4 ** (i + extra_denom), self.cutoff[i + 1] - self.cutoff[i], bias=False) 61 | ) 62 | ) 63 | 64 | def upgrade_state_dict_named(self, state_dict, name): 65 | version_name = name + '.version' 66 | if version_name not in state_dict: 67 | self.buggy_offset = 1 68 | self._make_tail(False) 69 | state_dict[version_name] = torch.LongTensor([1]) 70 | 71 | def adapt_target(self, target): 72 | """ 73 | In order to be efficient, the AdaptiveSoftMax does not compute the 74 | scores for all the word of the vocabulary for all the examples. It is 75 | thus necessary to call the method adapt_target of the AdaptiveSoftMax 76 | layer inside each forward pass. 
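        Returns a pair ``(new_target, target_idxs)``: ``new_target[0]`` is the
        head target, in which words beyond ``cutoff[0]`` are remapped to the head
        slot of their tail cluster, while ``new_target[i + 1]`` and
        ``target_idxs[i]`` hold the within-cluster targets and the positions that
        fall into tail cluster ``i`` (both ``None`` when that cluster is empty
        for the current batch).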
77 | """ 78 | 79 | target = target.view(-1) 80 | new_target = [target.clone()] 81 | target_idxs = [] 82 | 83 | for i in range(len(self.cutoff) - 1): 84 | mask = target.ge(self.cutoff[i]).mul(target.lt(self.cutoff[i + 1])) 85 | new_target[0][mask] = self.cutoff[0] + i - self.buggy_offset 86 | 87 | if mask.any(): 88 | target_idxs.append(mask.nonzero().squeeze(1)) 89 | new_target.append(target[mask].add(-self.cutoff[i])) 90 | else: 91 | target_idxs.append(None) 92 | new_target.append(None) 93 | 94 | return new_target, target_idxs 95 | 96 | def forward(self, input, target): 97 | """ 98 | Args: 99 | input: (b x t x d) 100 | target: (b x t) 101 | Returns: 102 | 2 lists: output for each cutoff section and new targets by cut off 103 | """ 104 | 105 | input = input.contiguous().view(-1, input.size(-1)) 106 | input = F.dropout(input, p=self.dropout, training=self.training) 107 | 108 | new_target, target_idxs = self.adapt_target(target) 109 | output = [self.head(input)] 110 | 111 | for i in range(len(target_idxs)): 112 | if target_idxs[i] is not None: 113 | output.append(self.tail[i](input.index_select(0, target_idxs[i]))) 114 | else: 115 | output.append(None) 116 | 117 | return output, new_target 118 | 119 | def get_log_prob(self, input, target): 120 | """ 121 | Computes the log probabilities for all the words of the vocabulary, 122 | given a 2D tensor of hidden vectors. 123 | """ 124 | 125 | bsz, length, dim = input.size() 126 | input = input.contiguous().view(-1, dim) 127 | 128 | if target is not None: 129 | _, target_idxs = self.adapt_target(target) 130 | else: 131 | target_idxs = None 132 | 133 | head_y = self.head(input) 134 | log_probs = head_y.new_zeros(input.size(0), self.vocab_size) 135 | 136 | head_sz = self.cutoff[0] + len(self.tail) 137 | log_probs[:, :head_sz] = self.lsm(head_y) 138 | tail_priors = log_probs[:, self.cutoff[0] - self.buggy_offset: head_sz - self.buggy_offset].clone() 139 | 140 | for i in range(len(self.tail)): 141 | start = self.cutoff[i] 142 | end = self.cutoff[i + 1] 143 | 144 | if target_idxs is None: 145 | tail_out = log_probs[:, start:end] 146 | tail_out.copy_(self.tail[i](input)) 147 | log_probs[:, start:end] = self.lsm(tail_out).add_(tail_priors[:, i, None]) 148 | elif target_idxs[i] is not None: 149 | idxs = target_idxs[i] 150 | tail_out = log_probs[idxs, start:end] 151 | tail_out.copy_(self.tail[i](input[idxs])) 152 | log_probs[idxs, start:end] = self.lsm(tail_out).add_(tail_priors[idxs, i, None]) 153 | 154 | log_probs = log_probs.view(bsz, length, -1) 155 | return log_probs 156 | -------------------------------------------------------------------------------- /fairseq/modules/beamable_mm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | class BeamableMM(nn.Module): 13 | """This module provides an optimized MM for beam decoding with attention. 14 | 15 | It leverage the fact that the source-side of the input is replicated beam 16 | times and the target-side of the input is of width one. This layer speeds up 17 | inference by replacing the inputs {(bsz x 1 x nhu), (bsz x sz2 x nhu)} 18 | with smaller inputs {(bsz/beam x beam x nhu), (bsz/beam x sz2 x nhu)}. 
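    The optimization only kicks in at inference time, when ``beam_size`` has
    been set (e.g. via :func:`set_beam_size`, which
    :func:`FairseqIncrementalDecoder.set_beam_size` propagates to submodules)
    and the target-side input has a single time step; otherwise the layer falls
    back to a regular batched ``bmm``.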
19 | """ 20 | def __init__(self, beam_size=None): 21 | super(BeamableMM, self).__init__() 22 | self.beam_size = beam_size 23 | 24 | def forward(self, input1, input2): 25 | if ( 26 | not self.training and # test mode 27 | self.beam_size is not None and # beam size is set 28 | input1.dim() == 3 and # only support batched input 29 | input1.size(1) == 1 # single time step update 30 | ): 31 | bsz, beam = input1.size(0), self.beam_size 32 | 33 | # bsz x 1 x nhu --> bsz/beam x beam x nhu 34 | input1 = input1[:, 0, :].unfold(0, beam, beam).transpose(2, 1) 35 | 36 | # bsz x sz2 x nhu --> bsz/beam x sz2 x nhu 37 | input2 = input2.unfold(0, beam, beam)[:, :, :, 0] 38 | 39 | # use non batched operation if bsz = beam 40 | if input1.size(0) == 1: 41 | output = torch.mm(input1[0, :, :], input2[0, :, :]) 42 | else: 43 | output = input1.bmm(input2) 44 | return output.view(bsz, 1, -1) 45 | else: 46 | return input1.bmm(input2) 47 | 48 | def set_beam_size(self, beam_size): 49 | self.beam_size = beam_size 50 | -------------------------------------------------------------------------------- /fairseq/modules/character_token_embedder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | from torch import nn 13 | from torch.nn.utils.rnn import pad_sequence 14 | 15 | from typing import List, Tuple 16 | 17 | from .highway import Highway 18 | from fairseq.data import Dictionary 19 | 20 | 21 | class CharacterTokenEmbedder(torch.nn.Module): 22 | def __init__( 23 | self, 24 | vocab: Dictionary, 25 | filters: List[Tuple[int, int]], 26 | char_embed_dim: int, 27 | word_embed_dim: int, 28 | highway_layers: int, 29 | max_char_len: int = 50, 30 | ): 31 | super(CharacterTokenEmbedder, self).__init__() 32 | 33 | self.embedding_dim = word_embed_dim 34 | self.char_embeddings = nn.Embedding(257, char_embed_dim, padding_idx=0) 35 | self.symbol_embeddings = nn.Parameter(torch.FloatTensor(2, word_embed_dim)) 36 | self.eos_idx, self.unk_idx = 0, 1 37 | 38 | self.convolutions = nn.ModuleList() 39 | for width, out_c in filters: 40 | self.convolutions.append( 41 | nn.Conv1d(char_embed_dim, out_c, kernel_size=width) 42 | ) 43 | 44 | final_dim = sum(f[1] for f in filters) 45 | 46 | self.highway = Highway(final_dim, highway_layers) 47 | self.projection = nn.Linear(final_dim, word_embed_dim) 48 | 49 | self.set_vocab(vocab, max_char_len) 50 | self.reset_parameters() 51 | 52 | def set_vocab(self, vocab, max_char_len): 53 | word_to_char = torch.LongTensor(len(vocab), max_char_len) 54 | 55 | truncated = 0 56 | for i in range(len(vocab)): 57 | if i < vocab.nspecial: 58 | char_idxs = [0] * max_char_len 59 | else: 60 | chars = vocab[i].encode() 61 | # +1 for padding 62 | char_idxs = [c + 1 for c in chars] + [0] * (max_char_len - len(chars)) 63 | if len(char_idxs) > max_char_len: 64 | truncated += 1 65 | char_idxs = char_idxs[:max_char_len] 66 | word_to_char[i] = torch.LongTensor(char_idxs) 67 | 68 | if truncated > 0: 69 | print('Truncated {} words longer than {} characters'.format(truncated, max_char_len)) 70 | 71 | self.vocab = vocab 72 | self.word_to_char = word_to_char 73 | 74 | @property 75 | def padding_idx(self): 76 | 
return self.vocab.pad() 77 | 78 | def reset_parameters(self): 79 | nn.init.xavier_normal_(self.char_embeddings.weight) 80 | nn.init.xavier_normal_(self.symbol_embeddings) 81 | nn.init.xavier_normal_(self.projection.weight) 82 | nn.init.constant_(self.char_embeddings.weight[self.char_embeddings.padding_idx], 0.) 83 | nn.init.constant_(self.projection.bias, 0.) 84 | 85 | def forward( 86 | self, 87 | words: torch.Tensor, 88 | ): 89 | self.word_to_char = self.word_to_char.type_as(words) 90 | 91 | flat_words = words.view(-1) 92 | word_embs = self._convolve(self.word_to_char[flat_words]) 93 | 94 | pads = flat_words.eq(self.vocab.pad()) 95 | if pads.any(): 96 | word_embs[pads] = 0 97 | 98 | eos = flat_words.eq(self.vocab.eos()) 99 | if eos.any(): 100 | word_embs[eos] = self.symbol_embeddings[self.eos_idx] 101 | 102 | unk = flat_words.eq(self.vocab.unk()) 103 | if unk.any(): 104 | word_embs[unk] = self.symbol_embeddings[self.unk_idx] 105 | 106 | return word_embs.view(words.size() + (-1,)) 107 | 108 | def _convolve( 109 | self, 110 | char_idxs: torch.Tensor, 111 | ): 112 | char_embs = self.char_embeddings(char_idxs) 113 | char_embs = char_embs.transpose(1, 2) # BTC -> BCT 114 | 115 | conv_result = [] 116 | 117 | for i, conv in enumerate(self.convolutions): 118 | x = conv(char_embs) 119 | x, _ = torch.max(x, -1) 120 | x = F.relu(x) 121 | conv_result.append(x) 122 | 123 | conv_result = torch.cat(conv_result, dim=-1) 124 | conv_result = self.highway(conv_result) 125 | 126 | return self.projection(conv_result) 127 | -------------------------------------------------------------------------------- /fairseq/modules/conv_tbc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | from torch.nn.modules.utils import _single 10 | 11 | 12 | class ConvTBC(torch.nn.Module): 13 | """1D convolution over an input of shape (time x batch x channel) 14 | 15 | The implementation uses gemm to perform the convolution. This implementation 16 | is faster than cuDNN for small kernel sizes. 17 | """ 18 | def __init__(self, in_channels, out_channels, kernel_size, padding=0): 19 | super(ConvTBC, self).__init__() 20 | self.in_channels = in_channels 21 | self.out_channels = out_channels 22 | self.kernel_size = _single(kernel_size) 23 | self.padding = _single(padding) 24 | 25 | self.weight = torch.nn.Parameter(torch.Tensor( 26 | self.kernel_size[0], in_channels, out_channels)) 27 | self.bias = torch.nn.Parameter(torch.Tensor(out_channels)) 28 | 29 | def forward(self, input): 30 | return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding[0]) 31 | 32 | def __repr__(self): 33 | s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}' 34 | ', padding={padding}') 35 | if self.bias is None: 36 | s += ', bias=False' 37 | s += ')' 38 | return s.format(name=self.__class__.__name__, **self.__dict__) 39 | -------------------------------------------------------------------------------- /fairseq/modules/grad_multiply.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | 10 | 11 | class GradMultiply(torch.autograd.Function): 12 | @staticmethod 13 | def forward(ctx, x, scale): 14 | ctx.scale = scale 15 | res = x.new(x) 16 | return res 17 | 18 | @staticmethod 19 | def backward(ctx, grad): 20 | return grad * ctx.scale, None 21 | -------------------------------------------------------------------------------- /fairseq/modules/highway.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | from torch import nn 12 | 13 | 14 | class Highway(torch.nn.Module): 15 | """ 16 | A `Highway layer `_. 17 | Adopted from the AllenNLP implementation. 18 | """ 19 | 20 | def __init__( 21 | self, 22 | input_dim: int, 23 | num_layers: int = 1 24 | ): 25 | super(Highway, self).__init__() 26 | self.input_dim = input_dim 27 | self.layers = nn.ModuleList([nn.Linear(input_dim, input_dim * 2) 28 | for _ in range(num_layers)]) 29 | self.activation = nn.ReLU() 30 | 31 | self.reset_parameters() 32 | 33 | def reset_parameters(self): 34 | for layer in self.layers: 35 | # As per comment in AllenNLP: 36 | # We should bias the highway layer to just carry its input forward. We do that by 37 | # setting the bias on `B(x)` to be positive, because that means `g` will be biased to 38 | # be high, so we will carry the input forward. The bias on `B(x)` is the second half 39 | # of the bias vector in each Linear layer. 40 | nn.init.constant_(layer.bias[self.input_dim:], 1) 41 | 42 | nn.init.constant_(layer.bias[:self.input_dim], 0) 43 | nn.init.xavier_normal_(layer.weight) 44 | 45 | def forward( 46 | self, 47 | x: torch.Tensor 48 | ): 49 | for layer in self.layers: 50 | projection = layer(x) 51 | proj_x, gate = projection.chunk(2, dim=-1) 52 | proj_x = self.activation(proj_x) 53 | gate = F.sigmoid(gate) 54 | x = gate * x + (1 - gate) * proj_x 55 | return x 56 | -------------------------------------------------------------------------------- /fairseq/modules/learned_positional_embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.nn as nn 9 | 10 | from fairseq import utils 11 | 12 | 13 | class LearnedPositionalEmbedding(nn.Embedding): 14 | """This module learns positional embeddings up to a fixed maximum size. 15 | 16 | Padding symbols are ignored, but it is necessary to specify whether padding 17 | is added on the left side (left_pad=True) or right side (left_pad=False). 
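    Position indices start at ``padding_idx + 1`` while padding entries keep the
    ``padding_idx`` embedding, so an instance created with ``num_embeddings``
    entries supports at most ``num_embeddings - padding_idx - 1`` real positions,
    which is what :func:`max_positions` reports.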
18 | """ 19 | 20 | def __init__(self, num_embeddings, embedding_dim, padding_idx, left_pad): 21 | super().__init__(num_embeddings, embedding_dim, padding_idx) 22 | self.left_pad = left_pad 23 | 24 | def forward(self, input, incremental_state=None): 25 | """Input is expected to be of size [bsz x seqlen].""" 26 | if incremental_state is not None: 27 | # positions is the same for every token when decoding a single step 28 | positions = input.data.new(1, 1).fill_(self.padding_idx + input.size(1)) 29 | else: 30 | positions = utils.make_positions(input.data, self.padding_idx, self.left_pad) 31 | return super().forward(positions) 32 | 33 | def max_positions(self): 34 | """Maximum number of supported positions.""" 35 | return self.num_embeddings - self.padding_idx - 1 36 | -------------------------------------------------------------------------------- /fairseq/modules/linearized_convolution.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | from fairseq import utils 12 | 13 | from .conv_tbc import ConvTBC 14 | 15 | 16 | class LinearizedConvolution(ConvTBC): 17 | """An optimized version of nn.Conv1d. 18 | 19 | At training time, this module uses ConvTBC, which is an optimized version 20 | of Conv1d. At inference time, it optimizes incremental generation (i.e., 21 | one time step at a time) by replacing the convolutions with linear layers. 22 | Note that the input order changes from training to inference. 23 | """ 24 | 25 | def __init__(self, in_channels, out_channels, kernel_size, **kwargs): 26 | super().__init__(in_channels, out_channels, kernel_size, **kwargs) 27 | self._linearized_weight = None 28 | self.register_backward_hook(self._clear_linearized_weight) 29 | 30 | def forward(self, input, incremental_state=None): 31 | """ 32 | Args: 33 | incremental_state: Used to buffer signal; if not None, then input is 34 | expected to contain a single frame. If the input order changes 35 | between time steps, call reorder_incremental_state. 
36 | Input: 37 | Time x Batch x Channel during training 38 | Batch x Time x Channel during inference 39 | """ 40 | if incremental_state is None: 41 | output = super().forward(input) 42 | if self.kernel_size[0] > 1 and self.padding[0] > 0: 43 | # remove future timesteps added by padding 44 | output = output[:-self.padding[0], :, :] 45 | return output 46 | 47 | # reshape weight 48 | weight = self._get_linearized_weight() 49 | kw = self.kernel_size[0] 50 | 51 | bsz = input.size(0) # input: bsz x len x dim 52 | if kw > 1: 53 | input = input.data 54 | input_buffer = self._get_input_buffer(incremental_state) 55 | if input_buffer is None: 56 | input_buffer = input.new(bsz, kw, input.size(2)).zero_() 57 | self._set_input_buffer(incremental_state, input_buffer) 58 | else: 59 | # shift buffer 60 | input_buffer[:, :-1, :] = input_buffer[:, 1:, :].clone() 61 | # append next input 62 | input_buffer[:, -1, :] = input[:, -1, :] 63 | input = input_buffer 64 | with torch.no_grad(): 65 | output = F.linear(input.view(bsz, -1), weight, self.bias) 66 | return output.view(bsz, 1, -1) 67 | 68 | def reorder_incremental_state(self, incremental_state, new_order): 69 | input_buffer = self._get_input_buffer(incremental_state) 70 | if input_buffer is not None: 71 | input_buffer = input_buffer.index_select(0, new_order) 72 | self._set_input_buffer(incremental_state, input_buffer) 73 | 74 | def _get_input_buffer(self, incremental_state): 75 | return utils.get_incremental_state(self, incremental_state, 'input_buffer') 76 | 77 | def _set_input_buffer(self, incremental_state, new_buffer): 78 | return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer) 79 | 80 | def _get_linearized_weight(self): 81 | if self._linearized_weight is None: 82 | kw = self.kernel_size[0] 83 | weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous() 84 | assert weight.size() == (self.out_channels, kw, self.in_channels) 85 | self._linearized_weight = weight.view(self.out_channels, -1) 86 | return self._linearized_weight 87 | 88 | def _clear_linearized_weight(self, *args): 89 | self._linearized_weight = None 90 | -------------------------------------------------------------------------------- /fairseq/modules/scalar_bias.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
7 | # 8 | 9 | import torch 10 | 11 | 12 | class ScalarBias(torch.autograd.Function): 13 | """ 14 | Adds a vector of scalars, used in self-attention mechanism to allow 15 | the model to optionally attend to this vector instead of the past 16 | """ 17 | 18 | @staticmethod 19 | def forward(ctx, input, dim, bias_init): 20 | size = list(input.size()) 21 | size[dim] += 1 22 | output = input.new(*size).fill_(bias_init) 23 | output.narrow(dim, 1, size[dim] - 1).copy_(input) 24 | ctx.dim = dim 25 | return output 26 | 27 | @staticmethod 28 | def backward(ctx, grad): 29 | return grad.narrow(ctx.dim, 1, grad.size(ctx.dim) - 1), None, None 30 | 31 | 32 | def scalar_bias(input, dim, bias_init=0): 33 | return ScalarBias.apply(input, dim, bias_init) 34 | -------------------------------------------------------------------------------- /fairseq/modules/sinusoidal_positional_embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import math 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | from fairseq import utils 14 | 15 | 16 | class SinusoidalPositionalEmbedding(nn.Module): 17 | """This module produces sinusoidal positional embeddings of any length. 18 | 19 | Padding symbols are ignored, but it is necessary to specify whether padding 20 | is added on the left side (left_pad=True) or right side (left_pad=False). 21 | """ 22 | 23 | def __init__(self, embedding_dim, padding_idx, left_pad, init_size=1024): 24 | super().__init__() 25 | self.embedding_dim = embedding_dim 26 | self.padding_idx = padding_idx 27 | self.left_pad = left_pad 28 | self.weights = SinusoidalPositionalEmbedding.get_embedding( 29 | init_size, 30 | embedding_dim, 31 | padding_idx, 32 | ) 33 | self.onnx_trace = False 34 | self.register_buffer('_float_tensor', torch.FloatTensor(1)) 35 | 36 | def prepare_for_onnx_export_(self): 37 | self.onnx_trace = True 38 | 39 | @staticmethod 40 | def get_embedding(num_embeddings, embedding_dim, padding_idx=None): 41 | """Build sinusoidal embeddings. 42 | 43 | This matches the implementation in tensor2tensor, but differs slightly 44 | from the description in Section 3.5 of "Attention Is All You Need". 
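        Concretely, for a position ``p`` the table stores all sines followed by
        all cosines, ``[sin(p * w_0), ..., sin(p * w_{d/2 - 1}), cos(p * w_0),
        ..., cos(p * w_{d/2 - 1})]`` with ``w_i = exp(-i * log(10000) / (d/2 - 1))``,
        rather than interleaving sin/cos pairs as described in the paper; an odd
        ``embedding_dim`` gets one zero column appended, and the ``padding_idx``
        row is zeroed out.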
45 | """ 46 | half_dim = embedding_dim // 2 47 | emb = math.log(10000) / (half_dim - 1) 48 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) 49 | emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) 50 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) 51 | if embedding_dim % 2 == 1: 52 | # zero pad 53 | emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) 54 | if padding_idx is not None: 55 | emb[padding_idx, :] = 0 56 | return emb 57 | 58 | def forward(self, input, incremental_state=None): 59 | """Input is expected to be of size [bsz x seqlen].""" 60 | # recompute/expand embeddings if needed 61 | bsz, seq_len = input.size() 62 | max_pos = self.padding_idx + 1 + seq_len 63 | if self.weights is None or max_pos > self.weights.size(0): 64 | self.weights = SinusoidalPositionalEmbedding.get_embedding( 65 | max_pos, 66 | self.embedding_dim, 67 | self.padding_idx, 68 | ) 69 | self.weights = self.weights.type_as(self._float_tensor) 70 | 71 | if incremental_state is not None: 72 | # positions is the same for every token when decoding a single step 73 | return self.weights[self.padding_idx + seq_len, :].expand(bsz, 1, -1) 74 | 75 | positions = utils.make_positions(input, self.padding_idx, self.left_pad, self.onnx_trace) 76 | if self.onnx_trace: 77 | bsz = torch.onnx.operators.shape_as_tensor(input)[0] 78 | seq_len = torch.onnx.operators.shape_as_tensor(input)[1] 79 | flat_embeddings = self.weights.detach().index_select(0, positions.view(-1)) 80 | embedding_shape = torch.cat((bsz.view(1), seq_len.view(1), torch.LongTensor([-1]))) 81 | embeddings = torch.onnx.operators.reshape_from_tensor_shape(flat_embeddings, embedding_shape) 82 | return embeddings 83 | return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach() 84 | 85 | def max_positions(self): 86 | """Maximum number of supported positions.""" 87 | return int(1e5) # an arbitrary large number 88 | -------------------------------------------------------------------------------- /fairseq/multiprocessing_pdb.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import multiprocessing 9 | import os 10 | import pdb 11 | import sys 12 | 13 | 14 | class MultiprocessingPdb(pdb.Pdb): 15 | """A Pdb wrapper that works in a multiprocessing environment. 16 | 17 | Usage: `from fairseq import pdb; pdb.set_trace()` 18 | """ 19 | 20 | _stdin_fd = sys.stdin.fileno() 21 | _stdin = None 22 | _stdin_lock = multiprocessing.Lock() 23 | 24 | def __init__(self): 25 | pdb.Pdb.__init__(self, nosigint=True) 26 | 27 | def _cmdloop(self): 28 | stdin_bak = sys.stdin 29 | with self._stdin_lock: 30 | try: 31 | if not self._stdin: 32 | self._stdin = os.fdopen(self._stdin_fd) 33 | sys.stdin = self._stdin 34 | self.cmdloop() 35 | finally: 36 | sys.stdin = stdin_bak 37 | 38 | 39 | pdb = MultiprocessingPdb() 40 | -------------------------------------------------------------------------------- /fairseq/optim/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 
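A quick standalone re-implementation (illustrative only, the helper name is chosen here) of the table built by get_embedding() above, to make the shapes concrete; only torch is assumed:

import math
import torch

def sinusoidal_table(num_embeddings, embedding_dim, padding_idx=None):
    half_dim = embedding_dim // 2
    # geometric progression of inverse frequencies, as in get_embedding()
    inv_freq = torch.exp(torch.arange(half_dim, dtype=torch.float)
                         * -(math.log(10000) / (half_dim - 1)))
    positions = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1)
    angles = positions * inv_freq.unsqueeze(0)               # [num_emb, half_dim]
    table = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embedding_dim % 2 == 1:                               # zero-pad odd dimensions
        table = torch.cat([table, torch.zeros(num_embeddings, 1)], dim=1)
    if padding_idx is not None:                              # padding gets the zero vector
        table[padding_idx, :] = 0
    return table

print(sinusoidal_table(1024, 512, padding_idx=1).shape)      # torch.Size([1024, 512])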
3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import importlib 9 | import os 10 | 11 | from .fairseq_optimizer import FairseqOptimizer 12 | 13 | 14 | OPTIMIZER_REGISTRY = {} 15 | OPTIMIZER_CLASS_NAMES = set() 16 | 17 | 18 | def build_optimizer(args, params): 19 | params = filter(lambda p: p.requires_grad, params) 20 | return OPTIMIZER_REGISTRY[args.optimizer](args, params) 21 | 22 | 23 | def register_optimizer(name): 24 | """Decorator to register a new optimizer.""" 25 | 26 | def register_optimizer_cls(cls): 27 | if name in OPTIMIZER_REGISTRY: 28 | raise ValueError('Cannot register duplicate optimizer ({})'.format(name)) 29 | if not issubclass(cls, FairseqOptimizer): 30 | raise ValueError('Optimizer ({}: {}) must extend FairseqOptimizer'.format(name, cls.__name__)) 31 | if cls.__name__ in OPTIMIZER_CLASS_NAMES: 32 | # We use the optimizer class name as a unique identifier in 33 | # checkpoints, so all optimizer must have unique class names. 34 | raise ValueError('Cannot register optimizer with duplicate class name ({})'.format(cls.__name__)) 35 | OPTIMIZER_REGISTRY[name] = cls 36 | OPTIMIZER_CLASS_NAMES.add(cls.__name__) 37 | return cls 38 | 39 | return register_optimizer_cls 40 | 41 | 42 | # automatically import any Python files in the optim/ directory 43 | for file in os.listdir(os.path.dirname(__file__)): 44 | if file.endswith('.py') and not file.startswith('_'): 45 | module = file[:file.find('.py')] 46 | importlib.import_module('fairseq.optim.' + module) 47 | -------------------------------------------------------------------------------- /fairseq/optim/adagrad.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.optim 9 | 10 | from . import FairseqOptimizer, register_optimizer 11 | 12 | 13 | @register_optimizer('adagrad') 14 | class Adagrad(FairseqOptimizer): 15 | def __init__(self, args, params): 16 | super().__init__(args, params) 17 | self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config) 18 | 19 | @property 20 | def optimizer_config(self): 21 | """ 22 | Return a kwarg dictionary that will be used to override optimizer 23 | args stored in checkpoints. This allows us to load a checkpoint and 24 | resume training using a different set of optimizer args, e.g., with a 25 | different learning rate. 26 | """ 27 | return { 28 | 'lr': self.args.lr[0], 29 | 'weight_decay': self.args.weight_decay, 30 | } 31 | -------------------------------------------------------------------------------- /fairseq/optim/adam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import math 9 | import torch 10 | import torch.optim 11 | 12 | from . 
import FairseqOptimizer, register_optimizer 13 | 14 | 15 | @register_optimizer('adam') 16 | class FairseqAdam(FairseqOptimizer): 17 | def __init__(self, args, params): 18 | super().__init__(args, params) 19 | self._optimizer = Adam(params, **self.optimizer_config) 20 | 21 | @staticmethod 22 | def add_args(parser): 23 | """Add optimizer-specific arguments to the parser.""" 24 | parser.add_argument('--adam-betas', default='(0.9, 0.999)', metavar='B', 25 | help='betas for Adam optimizer') 26 | parser.add_argument('--adam-eps', type=float, default=1e-8, metavar='D', 27 | help='epsilon for Adam optimizer') 28 | 29 | @property 30 | def optimizer_config(self): 31 | """ 32 | Return a kwarg dictionary that will be used to override optimizer 33 | args stored in checkpoints. This allows us to load a checkpoint and 34 | resume training using a different set of optimizer args, e.g., with a 35 | different learning rate. 36 | """ 37 | return { 38 | 'lr': self.args.lr[0], 39 | 'betas': eval(self.args.adam_betas), 40 | 'eps': self.args.adam_eps, 41 | 'weight_decay': self.args.weight_decay, 42 | } 43 | 44 | 45 | class Adam(torch.optim.Optimizer): 46 | """Implements Adam algorithm. 47 | 48 | This implementation is modified from torch.optim.Adam based on: 49 | `Fixed Weight Decay Regularization in Adam` 50 | (see https://arxiv.org/abs/1711.05101) 51 | 52 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 53 | 54 | Arguments: 55 | params (iterable): iterable of parameters to optimize or dicts defining 56 | parameter groups 57 | lr (float, optional): learning rate (default: 1e-3) 58 | betas (Tuple[float, float], optional): coefficients used for computing 59 | running averages of gradient and its square (default: (0.9, 0.999)) 60 | eps (float, optional): term added to the denominator to improve 61 | numerical stability (default: 1e-8) 62 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 63 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 64 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 65 | 66 | .. _Adam\: A Method for Stochastic Optimization: 67 | https://arxiv.org/abs/1412.6980 68 | .. _On the Convergence of Adam and Beyond: 69 | https://openreview.net/forum?id=ryQu7f-RZ 70 | """ 71 | 72 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 73 | weight_decay=0, amsgrad=False): 74 | defaults = dict(lr=lr, betas=betas, eps=eps, 75 | weight_decay=weight_decay, amsgrad=amsgrad) 76 | super(Adam, self).__init__(params, defaults) 77 | 78 | def step(self, closure=None): 79 | """Performs a single optimization step. 80 | 81 | Arguments: 82 | closure (callable, optional): A closure that reevaluates the model 83 | and returns the loss. 84 | """ 85 | loss = None 86 | if closure is not None: 87 | loss = closure() 88 | 89 | for group in self.param_groups: 90 | for p in group['params']: 91 | if p.grad is None: 92 | continue 93 | grad = p.grad.data 94 | if grad.is_sparse: 95 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 96 | amsgrad = group['amsgrad'] 97 | 98 | state = self.state[p] 99 | 100 | # State initialization 101 | if len(state) == 0: 102 | state['step'] = 0 103 | # Exponential moving average of gradient values 104 | state['exp_avg'] = torch.zeros_like(p.data) 105 | # Exponential moving average of squared gradient values 106 | state['exp_avg_sq'] = torch.zeros_like(p.data) 107 | if amsgrad: 108 | # Maintains max of all exp. moving avg. of sq. grad. 
values 109 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 110 | 111 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 112 | if amsgrad: 113 | max_exp_avg_sq = state['max_exp_avg_sq'] 114 | beta1, beta2 = group['betas'] 115 | 116 | state['step'] += 1 117 | 118 | # Decay the first and second moment running average coefficient 119 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 120 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 121 | if amsgrad: 122 | # Maintains the maximum of all 2nd moment running avg. till now 123 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 124 | # Use the max. for normalizing running avg. of gradient 125 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 126 | else: 127 | denom = exp_avg_sq.sqrt().add_(group['eps']) 128 | 129 | bias_correction1 = 1 - beta1 ** state['step'] 130 | bias_correction2 = 1 - beta2 ** state['step'] 131 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 132 | 133 | if group['weight_decay'] != 0: 134 | p.data.add_(-group['weight_decay'] * group['lr'], p.data) 135 | 136 | p.data.addcdiv_(-step_size, exp_avg, denom) 137 | 138 | return loss 139 | -------------------------------------------------------------------------------- /fairseq/optim/fairseq_optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.optim 9 | 10 | 11 | class FairseqOptimizer(object): 12 | 13 | def __init__(self, args, params): 14 | super().__init__() 15 | self.args = args 16 | self.params = params 17 | 18 | @staticmethod 19 | def add_args(parser): 20 | """Add optimizer-specific arguments to the parser.""" 21 | pass 22 | 23 | @property 24 | def optimizer(self): 25 | """Return a torch.optim.optimizer.Optimizer instance.""" 26 | if not hasattr(self, '_optimizer'): 27 | raise NotImplementedError 28 | if not isinstance(self._optimizer, torch.optim.Optimizer): 29 | raise ValueError('_optimizer must be an instance of torch.optim.Optimizer') 30 | return self._optimizer 31 | 32 | @property 33 | def optimizer_config(self): 34 | """ 35 | Return a kwarg dictionary that will be used to override optimizer 36 | args stored in checkpoints. This allows us to load a checkpoint and 37 | resume training using a different set of optimizer args, e.g., with a 38 | different learning rate. 39 | """ 40 | raise NotImplementedError 41 | 42 | def get_lr(self): 43 | """Return the current learning rate.""" 44 | return self.optimizer.param_groups[0]['lr'] 45 | 46 | def set_lr(self, lr): 47 | """Set the learning rate.""" 48 | for param_group in self.optimizer.param_groups: 49 | param_group['lr'] = lr 50 | 51 | def state_dict(self): 52 | """Return the optimizer's state dict.""" 53 | return self.optimizer.state_dict() 54 | 55 | def load_state_dict(self, state_dict, optimizer_overrides=None): 56 | """Load an optimizer state dict. 57 | 58 | In general we should prefer the configuration of the existing optimizer 59 | instance (e.g., learning rate) over that found in the state_dict. This 60 | allows us to resume training from a checkpoint using a new set of 61 | optimizer args. 
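The override behaviour described in this docstring boils down to updating every param group in place. A hedged illustration of the same idea with a bare torch optimizer, no fairseq objects involved:

import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
saved = opt.state_dict()                  # stands in for the checkpointed optimizer state

opt.load_state_dict(saved)
optimizer_overrides = {'lr': 0.0005}      # e.g. resume with a smaller learning rate
for group in opt.param_groups:            # mirrors the loop in load_state_dict() above
    group.update(optimizer_overrides)
print(opt.param_groups[0]['lr'])          # 0.0005, not the 0.1 stored in `saved`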
62 | """ 63 | self.optimizer.load_state_dict(state_dict) 64 | 65 | if optimizer_overrides is not None and len(optimizer_overrides) > 0: 66 | # override learning rate, momentum, etc. with latest values 67 | for group in self.optimizer.param_groups: 68 | group.update(optimizer_overrides) 69 | 70 | def step(self, closure=None): 71 | """Performs a single optimization step.""" 72 | return self.optimizer.step(closure) 73 | 74 | def zero_grad(self): 75 | """Clears the gradients of all optimized parameters.""" 76 | return self.optimizer.zero_grad() 77 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import importlib 9 | import os 10 | 11 | from .fairseq_lr_scheduler import FairseqLRScheduler 12 | 13 | 14 | LR_SCHEDULER_REGISTRY = {} 15 | 16 | 17 | def build_lr_scheduler(args, optimizer): 18 | return LR_SCHEDULER_REGISTRY[args.lr_scheduler](args, optimizer) 19 | 20 | 21 | def register_lr_scheduler(name): 22 | """Decorator to register a new LR scheduler.""" 23 | 24 | def register_lr_scheduler_cls(cls): 25 | if name in LR_SCHEDULER_REGISTRY: 26 | raise ValueError('Cannot register duplicate LR scheduler ({})'.format(name)) 27 | if not issubclass(cls, FairseqLRScheduler): 28 | raise ValueError('LR Scheduler ({}: {}) must extend FairseqLRScheduler'.format(name, cls.__name__)) 29 | LR_SCHEDULER_REGISTRY[name] = cls 30 | return cls 31 | 32 | return register_lr_scheduler_cls 33 | 34 | 35 | # automatically import any Python files in the optim/lr_scheduler/ directory 36 | for file in os.listdir(os.path.dirname(__file__)): 37 | if file.endswith('.py') and not file.startswith('_'): 38 | module = file[:file.find('.py')] 39 | importlib.import_module('fairseq.optim.lr_scheduler.' + module) 40 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/cosine_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import math 9 | 10 | from . import FairseqLRScheduler, register_lr_scheduler 11 | 12 | 13 | @register_lr_scheduler('cosine') 14 | class CosineSchedule(FairseqLRScheduler): 15 | """Assign LR based on a cyclical schedule that follows the cosine function. 16 | See https://arxiv.org/pdf/1608.03983.pdf for details 17 | We also support a warmup phase where we linearly increase the learning rate 18 | from some initial learning rate (`--warmup-init-lr`) until the configured 19 | learning rate (`--lr`). 
20 | During warmup: 21 | lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates) 22 | lr = lrs[update_num] 23 | After warmup: 24 | lr = lr_min + 0.5*(lr_max - lr_min)*(1 + cos(t_curr / t_i)) 25 | where 26 | t_curr is current percentage of updates within the current period range 27 | t_i is the current period range, which is scaled by t_mul after every iteration 28 | """ 29 | 30 | def __init__(self, args, optimizer): 31 | super().__init__(args, optimizer) 32 | if len(args.lr) > 1: 33 | raise ValueError( 34 | 'Cannot use a fixed learning rate schedule with cosine.' 35 | ' Consider --lr-scheduler=fixed instead.' 36 | ) 37 | 38 | warmup_end_lr = args.max_lr 39 | if args.warmup_init_lr < 0: 40 | args.warmup_init_lr = args.lr[0] 41 | 42 | self.min_lr = args.lr[0] 43 | self.max_lr = args.max_lr 44 | 45 | assert self.max_lr > self.min_lr, 'max_lr must be more than lr' 46 | 47 | self.t_mult = args.t_mult 48 | self.period = args.lr_period_updates 49 | 50 | if args.warmup_updates > 0: 51 | # linearly warmup for the first args.warmup_updates 52 | self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates 53 | else: 54 | self.lr_step = 1 55 | 56 | self.warmup_updates = args.warmup_updates 57 | self.lr_shrink = args.lr_shrink 58 | 59 | # initial learning rate 60 | self.lr = args.warmup_init_lr 61 | self.optimizer.set_lr(self.lr) 62 | 63 | @staticmethod 64 | def add_args(parser): 65 | """Add arguments to the parser for this LR scheduler.""" 66 | parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', 67 | help='warmup the learning rate linearly for the first N updates') 68 | parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR', 69 | help='initial learning rate during warmup phase; default is args.lr') 70 | parser.add_argument('--max-lr', required=True, type=float, metavar='LR', 71 | help='max learning rate, must be more than args.lr') 72 | parser.add_argument('--t-mult', default=1, type=float, metavar='LR', 73 | help='factor to grow the length of each period') 74 | parser.add_argument('--lr-period-updates', default=5000, type=float, metavar='LR', 75 | help='initial number of updates per period') 76 | 77 | def step(self, epoch, val_loss=None): 78 | """Update the learning rate at the end of the given epoch.""" 79 | super().step(epoch, val_loss) 80 | # we don't change the learning rate at epoch boundaries 81 | return self.optimizer.get_lr() 82 | 83 | def step_update(self, num_updates): 84 | """Update the learning rate after each update.""" 85 | if num_updates < self.args.warmup_updates: 86 | self.lr = self.args.warmup_init_lr + num_updates * self.lr_step 87 | else: 88 | curr_updates = num_updates - self.args.warmup_updates 89 | if self.t_mult != 1: 90 | i = math.floor(math.log(1 - curr_updates / self.period * (1 - self.t_mult), self.t_mult)) 91 | t_i = self.t_mult ** i * self.period 92 | t_curr = curr_updates - (1 - self.t_mult ** i) / (1 - self.t_mult) * self.period 93 | else: 94 | i = math.floor(curr_updates / self.period) 95 | t_i = self.period 96 | t_curr = curr_updates - (self.period * i) 97 | 98 | lr_shrink = self.lr_shrink ** i 99 | min_lr = self.min_lr * lr_shrink 100 | max_lr = self.max_lr * lr_shrink 101 | 102 | self.lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_curr / t_i)) 103 | 104 | self.optimizer.set_lr(self.lr) 105 | return self.lr -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py: 
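For illustration only: a few concrete values from the post-warmup cosine formula of the scheduler above. The min/max/period numbers are made up and stand in for --lr, --max-lr and --lr-period-updates; this is a sketch, not the scheduler class itself.

import math

def cosine_lr(num_updates, min_lr=1e-5, max_lr=5e-4, period=5000, lr_shrink=1.0):
    i = num_updates // period                  # which restart cycle we are in
    t_curr = num_updates - i * period          # position inside the current cycle
    shrink = lr_shrink ** i
    lo, hi = min_lr * shrink, max_lr * shrink
    return lo + 0.5 * (hi - lo) * (1 + math.cos(math.pi * t_curr / period))

for step in (0, 2500, 4999, 5000):
    print(step, '{:.2e}'.format(cosine_lr(step)))
# starts at max_lr, reaches the min/max midpoint at 2500, decays toward min_lr,
# then jumps back to max_lr when the next period begins at 5000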
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from .. import FairseqOptimizer 9 | 10 | 11 | class FairseqLRScheduler(object): 12 | 13 | def __init__(self, args, optimizer): 14 | super().__init__() 15 | if not isinstance(optimizer, FairseqOptimizer): 16 | raise ValueError('optimizer must be an instance of FairseqOptimizer') 17 | self.args = args 18 | self.optimizer = optimizer 19 | self.best = None 20 | 21 | @staticmethod 22 | def add_args(parser): 23 | """Add arguments to the parser for this LR scheduler.""" 24 | pass 25 | 26 | def state_dict(self): 27 | """Return the LR scheduler state dict.""" 28 | return {'best': self.best} 29 | 30 | def load_state_dict(self, state_dict): 31 | """Load an LR scheduler state dict.""" 32 | self.best = state_dict['best'] 33 | 34 | def step(self, epoch, val_loss=None): 35 | """Update the learning rate at the end of the given epoch.""" 36 | if val_loss is not None: 37 | if self.best is None: 38 | self.best = val_loss 39 | else: 40 | self.best = min(self.best, val_loss) 41 | 42 | def step_update(self, num_updates): 43 | """Update the learning rate after each update.""" 44 | return self.optimizer.get_lr() 45 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/fixed_schedule.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from . import FairseqLRScheduler, register_lr_scheduler 9 | 10 | 11 | @register_lr_scheduler('fixed') 12 | class FixedSchedule(FairseqLRScheduler): 13 | """Decay the LR on a fixed schedule.""" 14 | 15 | def __init__(self, args, optimizer): 16 | super().__init__(args, optimizer) 17 | 18 | # set defaults 19 | args.warmup_updates = getattr(args, 'warmup_updates', 0) or 0 20 | 21 | self.lr = args.lr[0] 22 | if args.warmup_updates > 0: 23 | self.warmup_factor = 1. 
/ args.warmup_updates
24 |         else:
25 |             self.warmup_factor = 1
26 | 
27 |     @staticmethod
28 |     def add_args(parser):
29 |         """Add arguments to the parser for this LR scheduler."""
30 |         parser.add_argument('--force-anneal', '--fa', type=int, metavar='N',
31 |                             help='force annealing at specified epoch')
32 |         parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
33 |                             help='warmup the learning rate linearly for the first N updates')
34 | 
35 |     def get_next_lr(self, epoch):
36 |         lrs = self.args.lr
37 |         if self.args.force_anneal is None or epoch < self.args.force_anneal:
38 |             # use fixed LR schedule
39 |             next_lr = lrs[min(epoch, len(lrs) - 1)]
40 |         else:
41 |             # anneal based on lr_shrink
42 |             next_lr = lrs[-1] * self.args.lr_shrink ** (epoch + 1 - self.args.force_anneal)
43 |         return next_lr
44 | 
45 |     def step(self, epoch, val_loss=None):
46 |         """Update the learning rate at the end of the given epoch."""
47 |         super().step(epoch, val_loss)
48 |         self.lr = self.get_next_lr(epoch)
49 |         self.optimizer.set_lr(self.warmup_factor * self.lr)
50 |         return self.optimizer.get_lr()
51 | 
52 |     def step_update(self, num_updates):
53 |         """Update the learning rate after each update."""
54 |         if self.args.warmup_updates > 0 and num_updates <= self.args.warmup_updates:
55 |             self.warmup_factor = num_updates / float(self.args.warmup_updates)
56 |             self.optimizer.set_lr(self.warmup_factor * self.lr)
57 |         return self.optimizer.get_lr()
58 | 
--------------------------------------------------------------------------------
/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the LICENSE file in
5 | # the root directory of this source tree. An additional grant of patent rights
6 | # can be found in the PATENTS file in the same directory.
7 | 
8 | from . import FairseqLRScheduler, register_lr_scheduler
9 | 
10 | 
11 | @register_lr_scheduler('inverse_sqrt')
12 | class InverseSquareRootSchedule(FairseqLRScheduler):
13 |     """Decay the LR based on the inverse square root of the update number.
14 | 
15 |     We also support a warmup phase where we linearly increase the learning rate
16 |     from some initial learning rate (`--warmup-init-lr`) until the configured
17 |     learning rate (`--lr`). Thereafter we decay proportional to the number of
18 |     updates, with a decay factor set to align with the configured learning rate.
19 | 
20 |     During warmup:
21 | 
22 |         lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
23 |         lr = lrs[update_num]
24 | 
25 |     After warmup:
26 | 
27 |         lr = decay_factor / sqrt(update_num)
28 | 
29 |     where
30 | 
31 |         decay_factor = args.lr * sqrt(args.warmup_updates)
32 |     """
33 | 
34 |     def __init__(self, args, optimizer):
35 |         super().__init__(args, optimizer)
36 |         if len(args.lr) > 1:
37 |             raise ValueError(
38 |                 'Cannot use a fixed learning rate schedule with inverse_sqrt.'
39 |                 ' Consider --lr-scheduler=fixed instead.'
40 |             )
41 |         warmup_end_lr = args.lr[0]
42 |         if args.warmup_init_lr < 0:
43 |             args.warmup_init_lr = warmup_end_lr
44 | 
45 |         # linearly warmup for the first args.warmup_updates
46 |         self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates
47 | 
48 |         # then, decay prop.
to the inverse square root of the update number 49 | self.decay_factor = warmup_end_lr * args.warmup_updates**0.5 50 | 51 | # initial learning rate 52 | self.lr = args.warmup_init_lr 53 | 54 | self.optimizer.set_lr(self.lr) 55 | 56 | @staticmethod 57 | def add_args(parser): 58 | """Add arguments to the parser for this LR scheduler.""" 59 | parser.add_argument('--warmup-updates', default=4000, type=int, metavar='N', 60 | help='warmup the learning rate linearly for the first N updates') 61 | parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR', 62 | help='initial learning rate during warmup phase; default is args.lr') 63 | 64 | def step(self, epoch, val_loss=None): 65 | """Update the learning rate at the end of the given epoch.""" 66 | super().step(epoch, val_loss) 67 | # we don't change the learning rate at epoch boundaries 68 | return self.optimizer.get_lr() 69 | 70 | def step_update(self, num_updates): 71 | """Update the learning rate after each update.""" 72 | if num_updates < self.args.warmup_updates: 73 | self.lr = self.args.warmup_init_lr + num_updates*self.lr_step 74 | else: 75 | self.lr = self.decay_factor * num_updates**-0.5 76 | 77 | self.optimizer.set_lr(self.lr) 78 | return self.lr 79 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.optim.lr_scheduler 9 | 10 | from . import FairseqLRScheduler, register_lr_scheduler 11 | 12 | 13 | @register_lr_scheduler('reduce_lr_on_plateau') 14 | class ReduceLROnPlateau(FairseqLRScheduler): 15 | """Decay the LR by a factor every time the validation loss plateaus.""" 16 | 17 | def __init__(self, args, optimizer): 18 | super().__init__(args, optimizer) 19 | if len(args.lr) > 1: 20 | raise ValueError( 21 | 'Cannot use a fixed learning rate schedule with reduce_lr_on_plateau.' 22 | ' Consider --lr-scheduler=fixed instead.' 23 | ) 24 | self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 25 | self.optimizer.optimizer, patience=0, factor=args.lr_shrink) 26 | 27 | def state_dict(self): 28 | """Return the LR scheduler state dict.""" 29 | return { 30 | 'best': self.lr_scheduler.best, 31 | 'last_epoch': self.lr_scheduler.last_epoch, 32 | } 33 | 34 | def load_state_dict(self, state_dict): 35 | """Load an LR scheduler state dict.""" 36 | self.lr_scheduler.best = state_dict['best'] 37 | if 'last_epoch' in state_dict: 38 | self.lr_scheduler.last_epoch = state_dict['last_epoch'] 39 | 40 | def step(self, epoch, val_loss=None): 41 | """Update the learning rate at the end of the given epoch.""" 42 | if val_loss is not None: 43 | self.lr_scheduler.step(val_loss, epoch) 44 | else: 45 | self.lr_scheduler.last_epoch = epoch 46 | return self.optimizer.get_lr() 47 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/triangular_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 
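Worked values for the inverse_sqrt schedule above, as a standalone sketch (warmup-updates=4000 mirrors the default flag; the other numbers are chosen here for illustration):

def inverse_sqrt_lr(num_updates, lr=5e-4, warmup_updates=4000, warmup_init_lr=1e-7):
    if num_updates < warmup_updates:
        # linear warmup from warmup_init_lr up to lr
        return warmup_init_lr + num_updates * (lr - warmup_init_lr) / warmup_updates
    # afterwards decay proportionally to 1/sqrt(num_updates)
    decay_factor = lr * warmup_updates ** 0.5
    return decay_factor * num_updates ** -0.5

for step in (0, 2000, 4000, 16000, 64000):
    print(step, '{:.2e}'.format(inverse_sqrt_lr(step)))
# warmup ends at 4000 with 5.00e-04, then 16000 gives 2.50e-04 and 64000 gives 1.25e-04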
3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import math 9 | 10 | from . import FairseqLRScheduler, register_lr_scheduler 11 | 12 | 13 | @register_lr_scheduler('triangular') 14 | class TriangularSchedule(FairseqLRScheduler): 15 | """Assign LR based on a triangular cyclical schedule. 16 | 17 | See https://arxiv.org/pdf/1506.01186.pdf for details 18 | 19 | """ 20 | 21 | def __init__(self, args, optimizer): 22 | super().__init__(args, optimizer) 23 | if len(args.lr) > 1: 24 | raise ValueError( 25 | 'Cannot use a fixed learning rate schedule with triangular.' 26 | ' Consider --lr-scheduler=fixed instead.' 27 | ) 28 | 29 | lr = args.lr[0] 30 | 31 | assert args.max_lr > lr, 'max_lr must be more than lr' 32 | self.min_lr = lr 33 | self.max_lr = args.max_lr 34 | self.stepsize = args.lr_period_updates // 2 35 | self.lr_shrink = args.lr_shrink 36 | self.shrink_min = args.shrink_min 37 | 38 | # initial learning rate 39 | self.lr = self.min_lr 40 | self.optimizer.set_lr(self.lr) 41 | 42 | @staticmethod 43 | def add_args(parser): 44 | """Add arguments to the parser for this LR scheduler.""" 45 | parser.add_argument('--max-lr', required=True, type=float, metavar='LR', 46 | help='max learning rate, must be more than args.lr') 47 | parser.add_argument('--lr-period-updates', default=5000, type=float, metavar='LR', 48 | help='initial number of updates per period (cycle length)') 49 | parser.add_argument('--shrink-min', action='store_true', 50 | help='if set, also shrinks min lr') 51 | 52 | def step(self, epoch, val_loss=None): 53 | """Update the learning rate at the end of the given epoch.""" 54 | super().step(epoch, val_loss) 55 | # we don't change the learning rate at epoch boundaries 56 | return self.optimizer.get_lr() 57 | 58 | def step_update(self, num_updates): 59 | """Update the learning rate after each update.""" 60 | cycle = math.floor(num_updates / (2 * self.stepsize)) 61 | 62 | lr_shrink = self.lr_shrink ** cycle 63 | max_lr = self.max_lr * lr_shrink 64 | if self.shrink_min: 65 | min_lr = self.min_lr * lr_shrink 66 | else: 67 | min_lr = self.min_lr 68 | 69 | x = abs(num_updates / self.stepsize - 2 * (cycle + 1) + 1) 70 | self.lr = min_lr + (max_lr - min_lr) * max(0, (1 - x)) 71 | 72 | self.optimizer.set_lr(self.lr) 73 | return self.lr 74 | -------------------------------------------------------------------------------- /fairseq/optim/nag.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from torch.optim.optimizer import Optimizer, required 9 | 10 | from . import FairseqOptimizer, register_optimizer 11 | 12 | 13 | @register_optimizer('nag') 14 | class FairseqNAG(FairseqOptimizer): 15 | def __init__(self, args, params): 16 | super().__init__(args, params) 17 | self._optimizer = NAG(params, **self.optimizer_config) 18 | 19 | @property 20 | def optimizer_config(self): 21 | """ 22 | Return a kwarg dictionary that will be used to override optimizer 23 | args stored in checkpoints. 
This allows us to load a checkpoint and 24 | resume training using a different set of optimizer args, e.g., with a 25 | different learning rate. 26 | """ 27 | return { 28 | 'lr': self.args.lr[0], 29 | 'momentum': self.args.momentum, 30 | 'weight_decay': self.args.weight_decay, 31 | } 32 | 33 | 34 | class NAG(Optimizer): 35 | def __init__(self, params, lr=required, momentum=0, weight_decay=0): 36 | defaults = dict(lr=lr, lr_old=lr, momentum=momentum, weight_decay=weight_decay) 37 | super(NAG, self).__init__(params, defaults) 38 | 39 | def step(self, closure=None): 40 | """Performs a single optimization step. 41 | 42 | Arguments: 43 | closure (callable, optional): A closure that reevaluates the model 44 | and returns the loss. 45 | """ 46 | loss = None 47 | if closure is not None: 48 | loss = closure() 49 | 50 | for group in self.param_groups: 51 | weight_decay = group['weight_decay'] 52 | momentum = group['momentum'] 53 | lr = group['lr'] 54 | lr_old = group.get('lr_old', lr) 55 | lr_correct = lr / lr_old 56 | 57 | for p in group['params']: 58 | if p.grad is None: 59 | continue 60 | 61 | d_p = p.grad.data 62 | param_state = self.state[p] 63 | if 'momentum_buffer' not in param_state: 64 | param_state['momentum_buffer'] = d_p.clone().zero_() 65 | 66 | buf = param_state['momentum_buffer'] 67 | 68 | if weight_decay != 0: 69 | p.data.mul_(1 - lr * weight_decay) 70 | p.data.add_(momentum * momentum * lr_correct, buf) 71 | p.data.add_(-(1 + momentum) * lr, d_p) 72 | 73 | buf.mul_(momentum * lr_correct).add_(-lr, d_p) 74 | 75 | group['lr_old'] = lr 76 | 77 | return loss 78 | -------------------------------------------------------------------------------- /fairseq/optim/sgd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.optim 9 | 10 | from . import FairseqOptimizer, register_optimizer 11 | 12 | 13 | @register_optimizer('sgd') 14 | class SGD(FairseqOptimizer): 15 | def __init__(self, args, params): 16 | super().__init__(args, params) 17 | self._optimizer = torch.optim.SGD(params, **self.optimizer_config) 18 | 19 | @property 20 | def optimizer_config(self): 21 | """ 22 | Return a kwarg dictionary that will be used to override optimizer 23 | args stored in checkpoints. This allows us to load a checkpoint and 24 | resume training using a different set of optimizer args, e.g., with a 25 | different learning rate. 26 | """ 27 | return { 28 | 'lr': self.args.lr[0], 29 | 'momentum': self.args.momentum, 30 | 'weight_decay': self.args.weight_decay, 31 | } 32 | -------------------------------------------------------------------------------- /fairseq/sequence_scorer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
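A hedged usage sketch (not from the repository) of how the optimizer registry is typically driven from parsed args, assuming fairseq itself is importable; the Namespace fields are the ones FairseqNAG above reads and the values are made up:

from argparse import Namespace

import torch
from fairseq import optim             # assumes fairseq is installed/importable

model = torch.nn.Linear(8, 8)
args = Namespace(optimizer='nag', lr=[0.25], momentum=0.99, weight_decay=0.0)
optimizer = optim.build_optimizer(args, model.parameters())

loss = model(torch.randn(4, 8)).sum()
loss.backward()
optimizer.step()                      # one NAG update
optimizer.zero_grad()
print(optimizer.get_lr())             # 0.25, i.e. args.lr[0]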
7 | 8 | import torch 9 | 10 | from fairseq import utils 11 | 12 | 13 | class SequenceScorer(object): 14 | """Scores the target for a given source sentence.""" 15 | 16 | def __init__(self, models, tgt_dict): 17 | self.models = models 18 | self.pad = tgt_dict.pad() 19 | 20 | def cuda(self): 21 | for model in self.models: 22 | model.cuda() 23 | return self 24 | 25 | def score_batched_itr(self, data_itr, cuda=False, timer=None): 26 | """Iterate over a batched dataset and yield scored translations.""" 27 | for sample in data_itr: 28 | s = utils.move_to_cuda(sample) if cuda else sample 29 | if timer is not None: 30 | timer.start() 31 | pos_scores, attn = self.score(s) 32 | for i, id in enumerate(s['id'].data): 33 | # remove padding from ref 34 | src = utils.strip_pad(s['net_input']['src_tokens'].data[i, :], self.pad) 35 | ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None 36 | tgt_len = ref.numel() 37 | pos_scores_i = pos_scores[i][:tgt_len] 38 | score_i = pos_scores_i.sum() / tgt_len 39 | if attn is not None: 40 | attn_i = attn[i] 41 | _, alignment = attn_i.max(dim=0) 42 | else: 43 | attn_i = alignment = None 44 | hypos = [{ 45 | 'tokens': ref, 46 | 'score': score_i, 47 | 'attention': attn_i, 48 | 'alignment': alignment, 49 | 'positional_scores': pos_scores_i, 50 | }] 51 | if timer is not None: 52 | timer.stop(s['ntokens']) 53 | # return results in the same format as SequenceGenerator 54 | yield id, src, ref, hypos 55 | 56 | def score(self, sample): 57 | """Score a batch of translations.""" 58 | net_input = sample['net_input'] 59 | 60 | # compute scores for each model in the ensemble 61 | avg_probs = None 62 | avg_attn = None 63 | for model in self.models: 64 | with torch.no_grad(): 65 | model.eval() 66 | decoder_out = model.forward(**net_input) 67 | attn = decoder_out[1] 68 | 69 | probs = model.get_normalized_probs(decoder_out, log_probs=False, sample=sample).data 70 | if avg_probs is None: 71 | avg_probs = probs 72 | else: 73 | avg_probs.add_(probs) 74 | if attn is not None: 75 | attn = attn.data 76 | if avg_attn is None: 77 | avg_attn = attn 78 | else: 79 | avg_attn.add_(attn) 80 | avg_probs.div_(len(self.models)) 81 | avg_probs.log_() 82 | if avg_attn is not None: 83 | avg_attn.div_(len(self.models)) 84 | avg_probs = avg_probs.gather( 85 | dim=2, 86 | index=sample['target'].data.unsqueeze(-1), 87 | ) 88 | return avg_probs.squeeze(2), avg_attn 89 | -------------------------------------------------------------------------------- /fairseq/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in # the root directory of this source tree. An additional grant of patent rights 5 | # can be found in the PATENTS file in the same directory. 6 | 7 | import argparse 8 | import importlib 9 | import os 10 | 11 | from .fairseq_task import FairseqTask 12 | 13 | 14 | TASK_REGISTRY = {} 15 | TASK_CLASS_NAMES = set() 16 | 17 | 18 | def setup_task(args): 19 | return TASK_REGISTRY[args.task].setup_task(args) 20 | 21 | 22 | def register_task(name): 23 | """ 24 | New tasks can be added to fairseq with the 25 | :func:`~fairseq.tasks.register_task` function decorator. 26 | 27 | For example:: 28 | 29 | @register_task('classification') 30 | class ClassificationTask(FairseqTask): 31 | (...) 32 | 33 | .. 
note:: 34 | 35 | All Tasks must implement the :class:`~fairseq.tasks.FairseqTask` 36 | interface. 37 | 38 | Please see the 39 | 40 | Args: 41 | name (str): the name of the task 42 | """ 43 | 44 | def register_task_cls(cls): 45 | if name in TASK_REGISTRY: 46 | raise ValueError('Cannot register duplicate task ({})'.format(name)) 47 | if not issubclass(cls, FairseqTask): 48 | raise ValueError('Task ({}: {}) must extend FairseqTask'.format(name, cls.__name__)) 49 | if cls.__name__ in TASK_CLASS_NAMES: 50 | raise ValueError('Cannot register task with duplicate class name ({})'.format(cls.__name__)) 51 | TASK_REGISTRY[name] = cls 52 | TASK_CLASS_NAMES.add(cls.__name__) 53 | return cls 54 | 55 | return register_task_cls 56 | 57 | 58 | # automatically import any Python files in the tasks/ directory 59 | for file in os.listdir(os.path.dirname(__file__)): 60 | if file.endswith('.py') and not file.startswith('_'): 61 | task_name = file[:file.find('.py')] 62 | importlib.import_module('fairseq.tasks.' + task_name) 63 | 64 | # expose `task_parser` for sphinx 65 | if task_name in TASK_REGISTRY: 66 | parser = argparse.ArgumentParser(add_help=False) 67 | group_task = parser.add_argument_group('Task name') 68 | group_task.add_argument( 69 | '--task', metavar=task_name, 70 | help='Enable this task with: ``--task=' + task_name + '``' 71 | ) 72 | group_args = parser.add_argument_group('Additional command-line arguments') 73 | TASK_REGISTRY[task_name].add_args(group_args) 74 | globals()[task_name + '_parser'] = parser 75 | -------------------------------------------------------------------------------- /fairseq/tasks/language_modeling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import itertools 9 | import numpy as np 10 | import os 11 | 12 | from torch.utils.data import ConcatDataset 13 | 14 | from fairseq.data import ( 15 | Dictionary, IndexedInMemoryDataset, IndexedRawTextDataset, 16 | MonolingualDataset, TokenBlockDataset, 17 | ) 18 | 19 | from . import FairseqTask, register_task 20 | 21 | 22 | @register_task('language_modeling') 23 | class LanguageModelingTask(FairseqTask): 24 | """ 25 | Train a language model. 26 | 27 | Args: 28 | dictionary (Dictionary): the dictionary for the language model 29 | 30 | .. note:: 31 | 32 | The language modeling task is compatible with :mod:`train.py `, 33 | :mod:`generate.py `, :mod:`interactive.py ` and 34 | :mod:`eval_lm.py `. 35 | 36 | The language modeling task provides the following additional command-line 37 | arguments: 38 | 39 | .. argparse:: 40 | :ref: fairseq.tasks.language_modeling_parser 41 | :prog: 42 | """ 43 | 44 | @staticmethod 45 | def add_args(parser): 46 | """Add task-specific arguments to the parser.""" 47 | parser.add_argument('data', help='path to data directory') 48 | parser.add_argument('--sample-break-mode', 49 | choices=['none', 'complete', 'eos'], 50 | help='If omitted or "none", fills each sample with tokens-per-sample ' 51 | 'tokens. If set to "complete", splits samples only at the end ' 52 | 'of sentence, but may include multiple sentences per sample. 
' 53 | 'If set to "eos", includes only one sentence per sample.') 54 | parser.add_argument('--tokens-per-sample', default=1024, type=int, 55 | help='max number of tokens per sample for LM dataset') 56 | parser.add_argument('--raw-text', default=False, action='store_true', 57 | help='load raw text dataset') 58 | 59 | def __init__(self, args, dictionary): 60 | super().__init__(args) 61 | self.dictionary = dictionary 62 | 63 | @classmethod 64 | def setup_task(cls, args, **kwargs): 65 | """Setup the task (e.g., load dictionaries). 66 | 67 | Args: 68 | args (argparse.Namespace): parsed command-line arguments 69 | """ 70 | dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt')) 71 | print('| dictionary: {} types'.format(len(dictionary))) 72 | return cls(args, dictionary) 73 | 74 | def load_dataset(self, split, combine=False): 75 | """Load a given dataset split. 76 | 77 | Args: 78 | split (str): name of the split (e.g., train, valid, test) 79 | """ 80 | 81 | loaded_datasets = [] 82 | 83 | for k in itertools.count(): 84 | split_k = split + (str(k) if k > 0 else '') 85 | path = os.path.join(self.args.data, split_k) 86 | 87 | if self.args.raw_text and IndexedRawTextDataset.exists(path): 88 | ds = IndexedRawTextDataset(path, self.dictionary) 89 | tokens = [t for l in ds.tokens_list for t in l] 90 | elif not self.args.raw_text and IndexedInMemoryDataset.exists(path): 91 | ds = IndexedInMemoryDataset(path, fix_lua_indexing=True) 92 | tokens = ds.buffer 93 | else: 94 | if k > 0: 95 | break 96 | else: 97 | raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data)) 98 | 99 | loaded_datasets.append( 100 | TokenBlockDataset( 101 | tokens, ds.sizes, self.args.tokens_per_sample, self.args.sample_break_mode, 102 | include_targets=True 103 | )) 104 | 105 | print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1]))) 106 | 107 | if not combine: 108 | break 109 | 110 | if len(loaded_datasets) == 1: 111 | dataset = loaded_datasets[0] 112 | sizes = dataset.sizes 113 | else: 114 | dataset = ConcatDataset(loaded_datasets) 115 | sizes = np.concatenate([ds.sizes for ds in loaded_datasets]) 116 | 117 | self.datasets[split] = MonolingualDataset(dataset, sizes, self.dictionary, shuffle=False) 118 | 119 | @property 120 | def target_dictionary(self): 121 | """Return the :class:`~fairseq.data.Dictionary` for the language 122 | model.""" 123 | return self.dictionary 124 | -------------------------------------------------------------------------------- /fairseq/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
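What the --sample-break-mode "none" help text above amounts to conceptually: the binarized corpus is treated as one long token stream and cut into fixed-length blocks. A toy stand-in for illustration, not TokenBlockDataset itself:

tokens = list(range(23))              # stand-in for one long binarized token stream
tokens_per_sample = 8

blocks = [tokens[i:i + tokens_per_sample]
          for i in range(0, len(tokens), tokens_per_sample)]
for block in blocks:
    print(len(block), block)          # 8, 8 and 7 tokens; sentence boundaries are ignored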
7 | 8 | from collections import Counter 9 | import re 10 | 11 | import torch 12 | 13 | 14 | SPACE_NORMALIZER = re.compile("\s+") 15 | 16 | 17 | def tokenize_line(line): 18 | line = SPACE_NORMALIZER.sub(" ", line) 19 | line = line.strip() 20 | return line.split() 21 | 22 | 23 | class Tokenizer: 24 | 25 | @staticmethod 26 | def add_file_to_dictionary(filename, dict, tokenize): 27 | with open(filename, 'r') as f: 28 | for line in f: 29 | for word in tokenize(line): 30 | dict.add_symbol(word) 31 | dict.add_symbol(dict.eos_word) 32 | 33 | 34 | 35 | 36 | @staticmethod 37 | def binarize(filename, dict, consumer, tokenize=tokenize_line, 38 | append_eos=True, reverse_order=False): 39 | nseq, ntok = 0, 0 40 | replaced = Counter() 41 | 42 | def replaced_consumer(word, idx): 43 | if idx == dict.unk_index and word != dict.unk_word: 44 | replaced.update([word]) 45 | 46 | with open(filename, 'r') as f: 47 | for line in f: 48 | ids = Tokenizer.tokenize( 49 | line=line, 50 | dict=dict, 51 | tokenize=tokenize, 52 | add_if_not_exist=False, 53 | consumer=replaced_consumer, 54 | append_eos=append_eos, 55 | reverse_order=reverse_order, 56 | ) 57 | nseq += 1 58 | 59 | consumer(ids) 60 | ntok += len(ids) 61 | return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': len(replaced)} 62 | 63 | @staticmethod 64 | def tokenize(line, dict, tokenize=tokenize_line, add_if_not_exist=True, 65 | consumer=None, append_eos=True, reverse_order=False): 66 | words = tokenize(line) 67 | if reverse_order: 68 | words = list(reversed(words)) 69 | nwords = len(words) 70 | ids = torch.IntTensor(nwords + 1 if append_eos else nwords) 71 | 72 | for i, word in enumerate(words): 73 | if add_if_not_exist: 74 | idx = dict.add_symbol(word) 75 | else: 76 | idx = dict.index(word) 77 | if consumer is not None: 78 | consumer(word, idx) 79 | ids[i] = idx 80 | if append_eos: 81 | ids[nwords] = dict.eos_index 82 | return ids 83 | -------------------------------------------------------------------------------- /multiprocessing_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 8 | 9 | import os 10 | import random 11 | import signal 12 | import torch 13 | 14 | from fairseq import distributed_utils, options 15 | 16 | from train import main as single_process_main 17 | 18 | 19 | def main(args): 20 | # Set distributed training parameters for a single node. 21 | args.distributed_world_size = torch.cuda.device_count() 22 | args.distributed_init_method = 'tcp://localhost:{port}'.format( 23 | port=random.randint(10000, 20000)) 24 | 25 | mp = torch.multiprocessing.get_context('spawn') 26 | 27 | # Create a thread to listen for errors in the child processes. 28 | error_queue = mp.SimpleQueue() 29 | error_handler = ErrorHandler(error_queue) 30 | 31 | # Train with multiprocessing. 
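The whitespace normalization performed by tokenize_line() in tokenizer.py above, reproduced standalone (illustrative sketch) so the behaviour is concrete:

import re

SPACE_NORMALIZER = re.compile(r"\s+")

def tokenize_line(line):
    # collapse any run of whitespace to a single space, then split on spaces
    return SPACE_NORMALIZER.sub(" ", line).strip().split()

print(tokenize_line("  Hello \t world \n"))   # ['Hello', 'world']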
32 | procs = [] 33 | for i in range(args.distributed_world_size): 34 | args.distributed_rank = i 35 | args.device_id = i 36 | procs.append(mp.Process(target=run, args=(args, error_queue, ), daemon=True)) 37 | procs[i].start() 38 | error_handler.add_child(procs[i].pid) 39 | for p in procs: 40 | p.join() 41 | 42 | 43 | def run(args, error_queue): 44 | try: 45 | args.distributed_rank = distributed_utils.distributed_init(args) 46 | single_process_main(args) 47 | except KeyboardInterrupt: 48 | pass # killed by parent, do nothing 49 | except Exception: 50 | # propagate exception to parent process, keeping original traceback 51 | import traceback 52 | error_queue.put((args.distributed_rank, traceback.format_exc())) 53 | 54 | 55 | class ErrorHandler(object): 56 | """A class that listens for exceptions in children processes and propagates 57 | the tracebacks to the parent process.""" 58 | 59 | def __init__(self, error_queue): 60 | import signal 61 | import threading 62 | self.error_queue = error_queue 63 | self.children_pids = [] 64 | self.error_thread = threading.Thread(target=self.error_listener, daemon=True) 65 | self.error_thread.start() 66 | signal.signal(signal.SIGUSR1, self.signal_handler) 67 | 68 | def add_child(self, pid): 69 | self.children_pids.append(pid) 70 | 71 | def error_listener(self): 72 | (rank, original_trace) = self.error_queue.get() 73 | self.error_queue.put((rank, original_trace)) 74 | os.kill(os.getpid(), signal.SIGUSR1) 75 | 76 | def signal_handler(self, signalnum, stackframe): 77 | for pid in self.children_pids: 78 | os.kill(pid, signal.SIGINT) # kill children processes 79 | (rank, original_trace) = self.error_queue.get() 80 | msg = "\n\n-- Tracebacks above this line can probably be ignored --\n\n" 81 | msg += original_trace 82 | raise Exception(msg) 83 | 84 | 85 | if __name__ == '__main__': 86 | parser = options.get_training_parser() 87 | args = options.parse_args_and_arch(parser) 88 | main(args) 89 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cffi 2 | numpy 3 | torch 4 | tqdm 5 | -------------------------------------------------------------------------------- /score.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 8 | """ 9 | BLEU scoring of generated translations against reference translations. 
10 | """
11 | 
12 | import argparse
13 | import os
14 | import sys
15 | 
16 | from fairseq import bleu, tokenizer
17 | from fairseq.data import dictionary
18 | 
19 | 
20 | def get_parser():
21 |     parser = argparse.ArgumentParser(description='Command-line script for BLEU scoring.')
22 |     parser.add_argument('-s', '--sys', default='-', help='system output')
23 |     parser.add_argument('-r', '--ref', required=True, help='references')
24 |     parser.add_argument('-o', '--order', default=4, metavar='N',
25 |                         type=int, help='consider ngrams up to this order')
26 |     parser.add_argument('--ignore-case', action='store_true',
27 |                         help='case-insensitive scoring')
28 |     return parser
29 | 
30 | 
31 | def main():
32 |     parser = get_parser()
33 |     args = parser.parse_args()
34 |     print(args)
35 | 
36 |     assert args.sys == '-' or os.path.exists(args.sys), \
37 |         "System output file {} does not exist".format(args.sys)
38 |     assert os.path.exists(args.ref), \
39 |         "Reference file {} does not exist".format(args.ref)
40 | 
41 |     dict = dictionary.Dictionary()
42 | 
43 |     def readlines(fd):
44 |         for line in fd.readlines():
45 |             if args.ignore_case:
46 |                 line = line.lower()
47 |             yield line
48 | 
49 |     def score(fdsys):
50 |         with open(args.ref) as fdref:
51 |             scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk())
52 |             for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)):
53 |                 sys_tok = tokenizer.Tokenizer.tokenize(sys_tok, dict)
54 |                 ref_tok = tokenizer.Tokenizer.tokenize(ref_tok, dict)
55 |                 scorer.add(ref_tok, sys_tok)
56 |             print(scorer.result_string(args.order))
57 | 
58 |     if args.sys == '-':
59 |         score(sys.stdin)
60 |     else:
61 |         with open(args.sys, 'r') as f:
62 |             score(f)
63 | 
64 | 
65 | if __name__ == '__main__':
66 |     main()
67 | 
--------------------------------------------------------------------------------
/scripts/average_checkpoints.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | 
3 | import argparse
4 | import collections
5 | import torch
6 | import os
7 | import re
8 | 
9 | 
10 | def average_checkpoints(inputs):
11 |     """Loads checkpoints from inputs and returns a model with averaged weights.
12 | 
13 |     Args:
14 |       inputs: An iterable of string paths of checkpoints to load from.
15 | 
16 |     Returns:
17 |       A dict of string keys mapping to various values. The 'model' key
18 |       from the returned dict should correspond to an OrderedDict mapping
19 |       string parameter names to torch Tensors.
20 |     """
21 |     params_dict = collections.OrderedDict()
22 |     params_keys = None
23 |     new_state = None
24 |     for f in inputs:
25 |         state = torch.load(
26 |             f,
27 |             map_location=(
28 |                 lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
29 |             ),
30 |         )
31 |         # Copies over the settings from the first checkpoint
32 |         if new_state is None:
33 |             new_state = state
34 | 
35 |         model_params = state['model']
36 | 
37 |         model_params_keys = list(model_params.keys())
38 |         if params_keys is None:
39 |             params_keys = model_params_keys
40 |         elif params_keys != model_params_keys:
41 |             raise KeyError(
42 |                 'For checkpoint {}, expected list of params: {}, '
43 |                 'but found: {}'.format(f, params_keys, model_params_keys)
44 |             )
45 | 
46 |         for k in params_keys:
47 |             if k not in params_dict:
48 |                 params_dict[k] = []
49 |             p = model_params[k]
50 |             if isinstance(p, torch.HalfTensor):
51 |                 p = p.float()
52 |             params_dict[k].append(p)
53 | 
54 |     averaged_params = collections.OrderedDict()
55 |     # v should be a list of torch Tensor.
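Aside (illustrative sketch): the loop that follows is an elementwise mean over the per-checkpoint copies of each parameter; the same arithmetic on a toy tensor list:

import torch

copies = [torch.tensor([1., 2.]), torch.tensor([3., 4.]), torch.tensor([5., 6.])]
summed = None
for x in copies:                       # same accumulation as the loop below
    summed = summed + x if summed is not None else x
print(summed / len(copies))            # tensor([3., 4.])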
56 |     for k, v in params_dict.items():
57 |         summed_v = None
58 |         for x in v:
59 |             summed_v = summed_v + x if summed_v is not None else x
60 |         averaged_params[k] = summed_v / len(v)
61 |     new_state['model'] = averaged_params
62 |     return new_state
63 | 
64 | 
65 | def last_n_checkpoints(paths, n, update_based):
66 |     assert len(paths) == 1
67 |     path = paths[0]
68 |     if update_based:
69 |         pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt')
70 |     else:
71 |         pt_regexp = re.compile(r'checkpoint(\d+)\.pt')
72 |     files = os.listdir(path)
73 | 
74 |     entries = []
75 |     for f in files:
76 |         m = pt_regexp.fullmatch(f)
77 |         if m is not None:
78 |             entries.append((int(m.group(1)), m.group(0)))
79 |     if len(entries) < n:
80 |         raise Exception('Found {} checkpoint files but need at least {}'.format(len(entries), n))
81 |     return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]]
82 | 
83 | 
84 | def main():
85 |     parser = argparse.ArgumentParser(
86 |         description='Tool to average the params of input checkpoints to '
87 |                     'produce a new checkpoint',
88 |     )
89 | 
90 |     parser.add_argument(
91 |         '--inputs',
92 |         required=True,
93 |         nargs='+',
94 |         help='Input checkpoint file paths.',
95 |     )
96 |     parser.add_argument(
97 |         '--output',
98 |         required=True,
99 |         metavar='FILE',
100 |         help='Write the new checkpoint containing the averaged weights to this '
101 |              'path.',
102 |     )
103 |     num_group = parser.add_mutually_exclusive_group()
104 |     num_group.add_argument(
105 |         '--num-epoch-checkpoints',
106 |         type=int,
107 |         help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, '
108 |              'and average last this many of them.',
109 |     )
110 |     num_group.add_argument(
111 |         '--num-update-checkpoints',
112 |         type=int,
113 |         help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by input, '
114 |              'and average last this many of them.',
115 |     )
116 |     args = parser.parse_args()
117 |     print(args)
118 | 
119 |     num = None
120 |     is_update_based = False
121 |     if args.num_update_checkpoints is not None:
122 |         num = args.num_update_checkpoints
123 |         is_update_based = True
124 |     elif args.num_epoch_checkpoints is not None:
125 |         num = args.num_epoch_checkpoints
126 | 
127 |     if num is not None:
128 |         args.inputs = last_n_checkpoints(args.inputs, num, is_update_based)
129 |         print('averaging checkpoints: ', args.inputs)
130 | 
131 |     new_state = average_checkpoints(args.inputs)
132 |     torch.save(new_state, args.output)
133 |     print('Finished writing averaged checkpoint to {}.'.format(args.output))
134 | 
135 | 
136 | if __name__ == '__main__':
137 |     main()
138 | 
--------------------------------------------------------------------------------
/scripts/build_sym_alignment.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the LICENSE file in
5 | # the root directory of this source tree. An additional grant of patent rights
6 | # can be found in the PATENTS file in the same directory.
7 | #
8 | 
9 | """
10 | Use this script in order to build symmetric alignments for your translation
11 | dataset.
12 | This script depends on fast_align and mosesdecoder tools. You will need to
13 | build those before running the script.
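Stepping back to average_checkpoints.py above, a self-contained illustration of how last_n_checkpoints() selects files: epoch checkpoints match checkpoint<N>.pt, update-based ones match checkpoint_<epoch>_<updates>.pt, and the newest n are kept. The file names below are made up.

import re

files = ['checkpoint1.pt', 'checkpoint2.pt', 'checkpoint10.pt', 'checkpoint_best.pt']
pt_regexp = re.compile(r'checkpoint(\d+)\.pt')   # epoch checkpoints; '_best' does not match

entries = []
for name in files:
    m = pt_regexp.fullmatch(name)
    if m is not None:
        entries.append((int(m.group(1)), name))
last_two = [name for _, name in sorted(entries, reverse=True)[:2]]
print(last_two)                        # ['checkpoint10.pt', 'checkpoint2.pt']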
14 | fast_align: 15 | github: http://github.com/clab/fast_align 16 | instructions: follow the instructions in README.md 17 | mosesdecoder: 18 | github: http://github.com/moses-smt/mosesdecoder 19 | instructions: http://www.statmt.org/moses/?n=Development.GetStarted 20 | The script produces the following files under --output_dir: 21 | text.joined - concatenation of lines from the source_file and the 22 | target_file. 23 | align.forward - forward pass of fast_align. 24 | align.backward - backward pass of fast_align. 25 | aligned.sym_heuristic - symmetrized alignment. 26 | """ 27 | 28 | import argparse 29 | import os 30 | from itertools import zip_longest 31 | 32 | 33 | def main(): 34 | parser = argparse.ArgumentParser(description='symmetric alignment builer') 35 | parser.add_argument('--fast_align_dir', 36 | help='path to fast_align build directory') 37 | parser.add_argument('--mosesdecoder_dir', 38 | help='path to mosesdecoder root directory') 39 | parser.add_argument('--sym_heuristic', 40 | help='heuristic to use for symmetrization', 41 | default='grow-diag-final-and') 42 | parser.add_argument('--source_file', 43 | help='path to a file with sentences ' 44 | 'in the source language') 45 | parser.add_argument('--target_file', 46 | help='path to a file with sentences ' 47 | 'in the target language') 48 | parser.add_argument('--output_dir', 49 | help='output directory') 50 | args = parser.parse_args() 51 | 52 | fast_align_bin = os.path.join(args.fast_align_dir, 'fast_align') 53 | symal_bin = os.path.join(args.mosesdecoder_dir, 'bin', 'symal') 54 | sym_fast_align_bin = os.path.join( 55 | args.mosesdecoder_dir, 'scripts', 'ems', 56 | 'support', 'symmetrize-fast-align.perl') 57 | 58 | # create joined file 59 | joined_file = os.path.join(args.output_dir, 'text.joined') 60 | with open(args.source_file, 'r') as src, open(args.target_file, 'r') as tgt: 61 | with open(joined_file, 'w') as joined: 62 | for s, t in zip_longest(src, tgt): 63 | print('{} ||| {}'.format(s.strip(), t.strip()), file=joined) 64 | 65 | bwd_align_file = os.path.join(args.output_dir, 'align.backward') 66 | 67 | # run forward alignment 68 | fwd_align_file = os.path.join(args.output_dir, 'align.forward') 69 | fwd_fast_align_cmd = '{FASTALIGN} -i {JOINED} -d -o -v > {FWD}'.format( 70 | FASTALIGN=fast_align_bin, 71 | JOINED=joined_file, 72 | FWD=fwd_align_file) 73 | assert os.system(fwd_fast_align_cmd) == 0 74 | 75 | # run backward alignment 76 | bwd_align_file = os.path.join(args.output_dir, 'align.backward') 77 | bwd_fast_align_cmd = '{FASTALIGN} -i {JOINED} -d -o -v -r > {BWD}'.format( 78 | FASTALIGN=fast_align_bin, 79 | JOINED=joined_file, 80 | BWD=bwd_align_file) 81 | assert os.system(bwd_fast_align_cmd) == 0 82 | 83 | # run symmetrization 84 | sym_out_file = os.path.join(args.output_dir, 'aligned') 85 | sym_cmd = '{SYMFASTALIGN} {FWD} {BWD} {SRC} {TGT} {OUT} {HEURISTIC} {SYMAL}'.format( 86 | SYMFASTALIGN=sym_fast_align_bin, 87 | FWD=fwd_align_file, 88 | BWD=bwd_align_file, 89 | SRC=args.source_file, 90 | TGT=args.target_file, 91 | OUT=sym_out_file, 92 | HEURISTIC=args.sym_heuristic, 93 | SYMAL=symal_bin 94 | ) 95 | assert os.system(sym_cmd) == 0 96 | 97 | 98 | if __name__ == '__main__': 99 | main() 100 | -------------------------------------------------------------------------------- /scripts/convert_dictionary.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 
3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | -- Usage: convert_dictionary.lua 9 | require 'fairseq' 10 | require 'torch' 11 | require 'paths' 12 | 13 | if #arg < 1 then 14 | print('usage: convert_dictionary.lua ') 15 | os.exit(1) 16 | end 17 | if not paths.filep(arg[1]) then 18 | print('error: file does not exit: ' .. arg[1]) 19 | os.exit(1) 20 | end 21 | 22 | dict = torch.load(arg[1]) 23 | dst = paths.basename(arg[1]):gsub('.th7', '.txt') 24 | assert(dst:match('.txt$')) 25 | 26 | f = io.open(dst, 'w') 27 | for idx, symbol in ipairs(dict.index_to_symbol) do 28 | if idx > dict.cutoff then 29 | break 30 | end 31 | f:write(symbol) 32 | f:write(' ') 33 | f:write(dict.index_to_freq[idx]) 34 | f:write('\n') 35 | end 36 | f:close() 37 | -------------------------------------------------------------------------------- /scripts/convert_model.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | -- Usage: convert_model.lua 9 | require 'torch' 10 | local fairseq = require 'fairseq' 11 | 12 | model = torch.load(arg[1]) 13 | 14 | function find_weight_norm(container, module) 15 | for _, wn in ipairs(container:listModules()) do 16 | if torch.type(wn) == 'nn.WeightNorm' and wn.modules[1] == module then 17 | return wn 18 | end 19 | end 20 | end 21 | 22 | function push_state(dict, key, module) 23 | if torch.type(module) == 'nn.Linear' then 24 | local wn = find_weight_norm(model.module, module) 25 | assert(wn) 26 | dict[key .. '.weight_v'] = wn.v:float() 27 | dict[key .. '.weight_g'] = wn.g:float() 28 | elseif torch.type(module) == 'nn.TemporalConvolutionTBC' then 29 | local wn = find_weight_norm(model.module, module) 30 | assert(wn) 31 | local v = wn.v:float():view(wn.viewOut):transpose(2, 3) 32 | dict[key .. '.weight_v'] = v 33 | dict[key .. '.weight_g'] = wn.g:float():view(module.weight:size(3), 1, 1) 34 | else 35 | dict[key .. '.weight'] = module.weight:float() 36 | end 37 | if module.bias then 38 | dict[key .. '.bias'] = module.bias:float() 39 | end 40 | end 41 | 42 | encoder_dict = {} 43 | decoder_dict = {} 44 | combined_dict = {} 45 | 46 | function encoder_state(encoder) 47 | luts = encoder:findModules('nn.LookupTable') 48 | push_state(encoder_dict, 'embed_tokens', luts[1]) 49 | push_state(encoder_dict, 'embed_positions', luts[2]) 50 | 51 | fcs = encoder:findModules('nn.Linear') 52 | assert(#fcs >= 2) 53 | local nInputPlane = fcs[1].weight:size(1) 54 | push_state(encoder_dict, 'fc1', table.remove(fcs, 1)) 55 | push_state(encoder_dict, 'fc2', table.remove(fcs, #fcs)) 56 | 57 | for i, module in ipairs(encoder:findModules('nn.TemporalConvolutionTBC')) do 58 | push_state(encoder_dict, 'convolutions.' .. tostring(i - 1), module) 59 | if nInputPlane ~= module.weight:size(3) / 2 then 60 | push_state(encoder_dict, 'projections.' .. 
tostring(i - 1), table.remove(fcs, 1)) 61 | end 62 | nInputPlane = module.weight:size(3) / 2 63 | end 64 | assert(#fcs == 0) 65 | end 66 | 67 | function decoder_state(decoder) 68 | luts = decoder:findModules('nn.LookupTable') 69 | push_state(decoder_dict, 'embed_tokens', luts[1]) 70 | push_state(decoder_dict, 'embed_positions', luts[2]) 71 | 72 | fcs = decoder:findModules('nn.Linear') 73 | local nInputPlane = fcs[1].weight:size(1) 74 | push_state(decoder_dict, 'fc1', table.remove(fcs, 1)) 75 | push_state(decoder_dict, 'fc2', fcs[#fcs - 1]) 76 | push_state(decoder_dict, 'fc3', fcs[#fcs]) 77 | 78 | table.remove(fcs, #fcs) 79 | table.remove(fcs, #fcs) 80 | 81 | for i, module in ipairs(decoder:findModules('nn.TemporalConvolutionTBC')) do 82 | if nInputPlane ~= module.weight:size(3) / 2 then 83 | push_state(decoder_dict, 'projections.' .. tostring(i - 1), table.remove(fcs, 1)) 84 | end 85 | nInputPlane = module.weight:size(3) / 2 86 | 87 | local prefix = 'attention.' .. tostring(i - 1) 88 | push_state(decoder_dict, prefix .. '.in_projection', table.remove(fcs, 1)) 89 | push_state(decoder_dict, prefix .. '.out_projection', table.remove(fcs, 1)) 90 | push_state(decoder_dict, 'convolutions.' .. tostring(i - 1), module) 91 | end 92 | assert(#fcs == 0) 93 | end 94 | 95 | 96 | _encoder = model.module.modules[2] 97 | _decoder = model.module.modules[3] 98 | 99 | encoder_state(_encoder) 100 | decoder_state(_decoder) 101 | 102 | for k, v in pairs(encoder_dict) do 103 | combined_dict['encoder.' .. k] = v 104 | end 105 | for k, v in pairs(decoder_dict) do 106 | combined_dict['decoder.' .. k] = v 107 | end 108 | 109 | 110 | torch.save('state_dict.t7', combined_dict) 111 | -------------------------------------------------------------------------------- /scripts/read_binarized.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 8 | # 9 | 10 | import argparse 11 | 12 | from fairseq.data import dictionary 13 | from fairseq.data import IndexedDataset 14 | 15 | 16 | def get_parser(): 17 | parser = argparse.ArgumentParser( 18 | description='writes text from binarized file to stdout') 19 | parser.add_argument('--dict', metavar='FP', required=True, help='dictionary containing known words') 20 | parser.add_argument('--input', metavar='FP', required=True, help='binarized file to read') 21 | 22 | return parser 23 | 24 | 25 | def main(args): 26 | dict = dictionary.Dictionary.load(args.dict) 27 | ds = IndexedDataset(args.input, fix_lua_indexing=True) 28 | for tensor_line in ds: 29 | print(dict.string(tensor_line)) 30 | 31 | 32 | if __name__ == '__main__': 33 | parser = get_parser() 34 | args = parser.parse_args() 35 | main(args) 36 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 
8 | 9 | from setuptools import setup, find_packages, Extension 10 | import sys 11 | 12 | 13 | if sys.version_info < (3,): 14 | sys.exit('Sorry, Python3 is required for fairseq.') 15 | 16 | with open('README.md') as f: 17 | readme = f.read() 18 | 19 | with open('LICENSE') as f: 20 | license = f.read() 21 | 22 | with open('requirements.txt') as f: 23 | reqs = f.read() 24 | 25 | 26 | bleu = Extension( 27 | 'fairseq.libbleu', 28 | sources=[ 29 | 'fairseq/clib/libbleu/libbleu.cpp', 30 | 'fairseq/clib/libbleu/module.cpp', 31 | ], 32 | extra_compile_args=['-std=c++11'], 33 | ) 34 | 35 | 36 | setup( 37 | name='fairseq', 38 | version='0.5.0', 39 | description='Facebook AI Research Sequence-to-Sequence Toolkit', 40 | long_description=readme, 41 | license=license, 42 | install_requires=reqs.strip().split('\n'), 43 | packages=find_packages(), 44 | ext_modules=[bleu], 45 | test_suite='tests', 46 | ) 47 | -------------------------------------------------------------------------------- /tests/test_average_checkpoints.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import collections 9 | import os 10 | import tempfile 11 | import unittest 12 | 13 | import numpy as np 14 | import torch 15 | 16 | from scripts.average_checkpoints import average_checkpoints 17 | 18 | 19 | class TestAverageCheckpoints(unittest.TestCase): 20 | def test_average_checkpoints(self): 21 | params_0 = collections.OrderedDict( 22 | [ 23 | ('a', torch.DoubleTensor([100.0])), 24 | ('b', torch.FloatTensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])), 25 | ('c', torch.IntTensor([7, 8, 9])), 26 | ] 27 | ) 28 | params_1 = collections.OrderedDict( 29 | [ 30 | ('a', torch.DoubleTensor([1.0])), 31 | ('b', torch.FloatTensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])), 32 | ('c', torch.IntTensor([2, 2, 2])), 33 | ] 34 | ) 35 | params_avg = collections.OrderedDict( 36 | [ 37 | ('a', torch.DoubleTensor([50.5])), 38 | ('b', torch.FloatTensor([[1.0, 1.5, 2.0], [2.5, 3.0, 3.5]])), 39 | # We expect truncation for integer division 40 | ('c', torch.IntTensor([4, 5, 5])), 41 | ] 42 | ) 43 | 44 | fd_0, path_0 = tempfile.mkstemp() 45 | fd_1, path_1 = tempfile.mkstemp() 46 | torch.save(collections.OrderedDict([('model', params_0)]), path_0) 47 | torch.save(collections.OrderedDict([('model', params_1)]), path_1) 48 | 49 | output = average_checkpoints([path_0, path_1])['model'] 50 | 51 | os.close(fd_0) 52 | os.remove(path_0) 53 | os.close(fd_1) 54 | os.remove(path_1) 55 | 56 | for (k_expected, v_expected), (k_out, v_out) in zip( 57 | params_avg.items(), output.items()): 58 | self.assertEqual( 59 | k_expected, k_out, 'Key mismatch - expected {} but found {}. 
' 60 | '(Expected list of keys: {} vs actual list of keys: {})'.format( 61 | k_expected, k_out, params_avg.keys(), output.keys() 62 | ) 63 | ) 64 | np.testing.assert_allclose( 65 | v_expected.numpy(), 66 | v_out.numpy(), 67 | err_msg='Tensor value mismatch for key {}'.format(k_expected) 68 | ) 69 | 70 | 71 | if __name__ == '__main__': 72 | unittest.main() 73 | -------------------------------------------------------------------------------- /tests/test_character_token_embedder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | import unittest 10 | 11 | from fairseq.data import Dictionary 12 | from fairseq.modules import CharacterTokenEmbedder 13 | 14 | 15 | class TestCharacterTokenEmbedder(unittest.TestCase): 16 | def test_character_token_embedder(self): 17 | vocab = Dictionary() 18 | vocab.add_symbol('hello') 19 | vocab.add_symbol('there') 20 | 21 | embedder = CharacterTokenEmbedder(vocab, [(2, 16), (4, 32), (8, 64), (16, 2)], 64, 5, 2) 22 | 23 | test_sents = [['hello', 'unk', 'there'], ['there'], ['hello', 'there']] 24 | max_len = max(len(s) for s in test_sents) 25 | input = torch.LongTensor(len(test_sents), max_len + 2).fill_(vocab.pad()) 26 | for i in range(len(test_sents)): 27 | input[i][0] = vocab.eos() 28 | for j in range(len(test_sents[i])): 29 | input[i][j + 1] = vocab.index(test_sents[i][j]) 30 | input[i][j + 2] = vocab.eos() 31 | embs = embedder(input) 32 | 33 | assert embs.size() == (len(test_sents), max_len + 2, 5) 34 | self.assertAlmostEqual(embs[0][0], embs[1][0]) 35 | self.assertAlmostEqual(embs[0][0], embs[0][-1]) 36 | self.assertAlmostEqual(embs[0][1], embs[2][1]) 37 | self.assertAlmostEqual(embs[0][3], embs[1][1]) 38 | 39 | embs.sum().backward() 40 | assert embedder.char_embeddings.weight.grad is not None 41 | 42 | def assertAlmostEqual(self, t1, t2): 43 | self.assertEqual(t1.size(), t2.size(), "size mismatch") 44 | self.assertLess((t1 - t2).abs().max(), 1e-6) 45 | 46 | 47 | if __name__ == '__main__': 48 | unittest.main() 49 | -------------------------------------------------------------------------------- /tests/test_convtbc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
7 | 8 | import torch 9 | import unittest 10 | from fairseq.modules import ConvTBC 11 | import torch.nn as nn 12 | 13 | 14 | class TestConvTBC(unittest.TestCase): 15 | 16 | def test_convtbc(self): 17 | # ksz, in_channels, out_channels 18 | conv_tbc = ConvTBC(4, 5, kernel_size=3, padding=1) 19 | # out_channels, in_channels, ksz 20 | conv1d = nn.Conv1d(4, 5, kernel_size=3, padding=1) 21 | 22 | conv_tbc.weight.data.copy_(conv1d.weight.data.transpose(0, 2)) 23 | conv_tbc.bias.data.copy_(conv1d.bias.data) 24 | 25 | input_tbc = torch.randn(7, 2, 4, requires_grad=True) 26 | input1d = input_tbc.data.transpose(0, 1).transpose(1, 2) 27 | input1d.requires_grad = True 28 | 29 | output_tbc = conv_tbc(input_tbc) 30 | output1d = conv1d(input1d) 31 | 32 | self.assertAlmostEqual(output_tbc.data.transpose(0, 1).transpose(1, 2), output1d.data) 33 | 34 | grad_tbc = torch.randn(output_tbc.size()) 35 | grad1d = grad_tbc.transpose(0, 1).transpose(1, 2).contiguous() 36 | 37 | output_tbc.backward(grad_tbc) 38 | output1d.backward(grad1d) 39 | 40 | self.assertAlmostEqual(conv_tbc.weight.grad.data.transpose(0, 2), conv1d.weight.grad.data) 41 | self.assertAlmostEqual(conv_tbc.bias.grad.data, conv1d.bias.grad.data) 42 | self.assertAlmostEqual(input_tbc.grad.data.transpose(0, 1).transpose(1, 2), input1d.grad.data) 43 | 44 | def assertAlmostEqual(self, t1, t2): 45 | self.assertEqual(t1.size(), t2.size(), "size mismatch") 46 | self.assertLess((t1 - t2).abs().max(), 1e-4) 47 | 48 | 49 | if __name__ == '__main__': 50 | unittest.main() 51 | -------------------------------------------------------------------------------- /tests/test_dictionary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
7 | 8 | import tempfile 9 | import unittest 10 | 11 | import torch 12 | 13 | from fairseq.data import Dictionary 14 | from fairseq.tokenizer import Tokenizer 15 | 16 | 17 | class TestDictionary(unittest.TestCase): 18 | 19 | def test_finalize(self): 20 | txt = [ 21 | 'A B C D', 22 | 'B C D', 23 | 'C D', 24 | 'D', 25 | ] 26 | ref_ids1 = list(map(torch.IntTensor, [ 27 | [4, 5, 6, 7, 2], 28 | [5, 6, 7, 2], 29 | [6, 7, 2], 30 | [7, 2], 31 | ])) 32 | ref_ids2 = list(map(torch.IntTensor, [ 33 | [7, 6, 5, 4, 2], 34 | [6, 5, 4, 2], 35 | [5, 4, 2], 36 | [4, 2], 37 | ])) 38 | 39 | # build dictionary 40 | d = Dictionary() 41 | for line in txt: 42 | Tokenizer.tokenize(line, d, add_if_not_exist=True) 43 | 44 | def get_ids(dictionary): 45 | ids = [] 46 | for line in txt: 47 | ids.append(Tokenizer.tokenize(line, dictionary, add_if_not_exist=False)) 48 | return ids 49 | 50 | def assertMatch(ids, ref_ids): 51 | for toks, ref_toks in zip(ids, ref_ids): 52 | self.assertEqual(toks.size(), ref_toks.size()) 53 | self.assertEqual(0, (toks != ref_toks).sum().item()) 54 | 55 | ids = get_ids(d) 56 | assertMatch(ids, ref_ids1) 57 | 58 | # check finalized dictionary 59 | d.finalize() 60 | finalized_ids = get_ids(d) 61 | assertMatch(finalized_ids, ref_ids2) 62 | 63 | # write to disk and reload 64 | with tempfile.NamedTemporaryFile(mode='w') as tmp_dict: 65 | d.save(tmp_dict.name) 66 | d = Dictionary.load(tmp_dict.name) 67 | reload_ids = get_ids(d) 68 | assertMatch(reload_ids, ref_ids2) 69 | assertMatch(finalized_ids, reload_ids) 70 | 71 | 72 | if __name__ == '__main__': 73 | unittest.main() 74 | -------------------------------------------------------------------------------- /tests/test_iterators.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import unittest 9 | 10 | from fairseq.data import iterators 11 | 12 | 13 | class TestIterators(unittest.TestCase): 14 | 15 | def test_counting_iterator(self): 16 | x = list(range(10)) 17 | itr = iterators.CountingIterator(x) 18 | self.assertTrue(itr.has_next()) 19 | self.assertEqual(next(itr), 0) 20 | self.assertEqual(next(itr), 1) 21 | itr.skip(3) 22 | self.assertEqual(next(itr), 5) 23 | itr.skip(3) 24 | self.assertEqual(next(itr), 9) 25 | self.assertFalse(itr.has_next()) 26 | 27 | 28 | if __name__ == '__main__': 29 | unittest.main() 30 | -------------------------------------------------------------------------------- /tests/test_label_smoothing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
7 | 8 | import argparse 9 | import copy 10 | import unittest 11 | 12 | import torch 13 | 14 | from fairseq.criterions.cross_entropy import CrossEntropyCriterion 15 | from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion 16 | 17 | import tests.utils as test_utils 18 | 19 | 20 | class TestLabelSmoothing(unittest.TestCase): 21 | 22 | def setUp(self): 23 | # build dictionary 24 | self.d = test_utils.dummy_dictionary(3) 25 | vocab = len(self.d) 26 | self.assertEqual(vocab, 4 + 3) # 4 special + 3 tokens 27 | self.assertEqual(self.d.pad(), 1) 28 | self.assertEqual(self.d.eos(), 2) 29 | self.assertEqual(self.d.unk(), 3) 30 | pad, eos, unk, w1, w2, w3 = 1, 2, 3, 4, 5, 6 # noqa: F841 31 | 32 | # build dataset 33 | self.data = [ 34 | # the first batch item has padding 35 | {'source': torch.LongTensor([w1, eos]), 'target': torch.LongTensor([w1, eos])}, 36 | {'source': torch.LongTensor([w1, eos]), 'target': torch.LongTensor([w1, w1, eos])}, 37 | ] 38 | self.sample = next(test_utils.dummy_dataloader(self.data)) 39 | 40 | # build model 41 | self.args = argparse.Namespace() 42 | self.args.sentence_avg = False 43 | self.args.probs = torch.FloatTensor([ 44 | # pad eos unk w1 w2 w3 45 | [0.05, 0.05, 0.1, 0.05, 0.3, 0.4, 0.05], 46 | [0.05, 0.10, 0.2, 0.05, 0.2, 0.3, 0.10], 47 | [0.05, 0.15, 0.3, 0.05, 0.1, 0.2, 0.15], 48 | ]).unsqueeze(0).expand(2, 3, 7) # add batch dimension 49 | self.task = test_utils.TestTranslationTask.setup_task(self.args, self.d, self.d) 50 | self.model = self.task.build_model(self.args) 51 | 52 | def test_nll_loss(self): 53 | self.args.label_smoothing = 0.1 54 | nll_crit = CrossEntropyCriterion(self.args, self.task) 55 | smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task) 56 | nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample) 57 | smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample) 58 | self.assertLess(abs(nll_loss - nll_logging_output['loss']), 1e-6) 59 | self.assertLess(abs(nll_loss - smooth_logging_output['nll_loss']), 1e-6) 60 | 61 | def test_padding(self): 62 | self.args.label_smoothing = 0.1 63 | crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task) 64 | loss, _, logging_output = crit(self.model, self.sample) 65 | 66 | def get_one_no_padding(idx): 67 | # create a new sample with just a single batch item so that there's 68 | # no padding 69 | sample1 = next(test_utils.dummy_dataloader([self.data[idx]])) 70 | args1 = copy.copy(self.args) 71 | args1.probs = args1.probs[idx, :, :].unsqueeze(0) 72 | model1 = self.task.build_model(args1) 73 | loss1, _, _ = crit(model1, sample1) 74 | return loss1 75 | 76 | loss1 = get_one_no_padding(0) 77 | loss2 = get_one_no_padding(1) 78 | self.assertAlmostEqual(loss, loss1 + loss2) 79 | 80 | def test_reduction(self): 81 | self.args.label_smoothing = 0.1 82 | crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task) 83 | loss, _, logging_output = crit(self.model, self.sample, reduce=True) 84 | unreduced_loss, _, _ = crit(self.model, self.sample, reduce=False) 85 | self.assertAlmostEqual(loss, unreduced_loss.sum()) 86 | 87 | def test_zero_eps(self): 88 | self.args.label_smoothing = 0.0 89 | nll_crit = CrossEntropyCriterion(self.args, self.task) 90 | smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task) 91 | nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample) 92 | smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample) 93 
| self.assertAlmostEqual(nll_loss, smooth_loss) 94 | 95 | def assertAlmostEqual(self, t1, t2): 96 | self.assertEqual(t1.size(), t2.size(), "size mismatch") 97 | self.assertLess((t1 - t2).abs().max(), 1e-6) 98 | 99 | 100 | if __name__ == '__main__': 101 | unittest.main() 102 | -------------------------------------------------------------------------------- /tests/test_sequence_scorer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import argparse 9 | import unittest 10 | 11 | import torch 12 | 13 | from fairseq.sequence_scorer import SequenceScorer 14 | 15 | import tests.utils as test_utils 16 | 17 | 18 | class TestSequenceScorer(unittest.TestCase): 19 | 20 | def test_sequence_scorer(self): 21 | # construct dummy dictionary 22 | d = test_utils.dummy_dictionary(vocab_size=2) 23 | self.assertEqual(d.pad(), 1) 24 | self.assertEqual(d.eos(), 2) 25 | self.assertEqual(d.unk(), 3) 26 | eos = d.eos() 27 | w1 = 4 28 | w2 = 5 29 | 30 | # construct dataloader 31 | data = [ 32 | { 33 | 'source': torch.LongTensor([w1, w2, eos]), 34 | 'target': torch.LongTensor([w1, w2, w1, eos]), 35 | }, 36 | { 37 | 'source': torch.LongTensor([w2, eos]), 38 | 'target': torch.LongTensor([w2, w1, eos]), 39 | }, 40 | { 41 | 'source': torch.LongTensor([w2, eos]), 42 | 'target': torch.LongTensor([w2, eos]), 43 | }, 44 | ] 45 | data_itr = test_utils.dummy_dataloader(data) 46 | 47 | # specify expected output probabilities 48 | args = argparse.Namespace() 49 | unk = 0. 
50 | args.beam_probs = [ 51 | # step 0: 52 | torch.FloatTensor([ 53 | # eos w1 w2 54 | [0.0, unk, 0.6, 0.4], # sentence 1 55 | [0.0, unk, 0.4, 0.6], # sentence 2 56 | [0.0, unk, 0.7, 0.3], # sentence 3 57 | ]), 58 | # step 1: 59 | torch.FloatTensor([ 60 | # eos w1 w2 61 | [0.0, unk, 0.2, 0.7], # sentence 1 62 | [0.0, unk, 0.8, 0.2], # sentence 2 63 | [0.7, unk, 0.1, 0.2], # sentence 3 64 | ]), 65 | # step 2: 66 | torch.FloatTensor([ 67 | # eos w1 w2 68 | [0.10, unk, 0.50, 0.4], # sentence 1 69 | [0.15, unk, 0.15, 0.7], # sentence 2 70 | [0.00, unk, 0.00, 0.0], # sentence 3 71 | ]), 72 | # step 3: 73 | torch.FloatTensor([ 74 | # eos w1 w2 75 | [0.9, unk, 0.05, 0.05], # sentence 1 76 | [0.0, unk, 0.00, 0.0], # sentence 2 77 | [0.0, unk, 0.00, 0.0], # sentence 3 78 | ]), 79 | ] 80 | expected_scores = [ 81 | [0.6, 0.7, 0.5, 0.9], # sentence 1 82 | [0.6, 0.8, 0.15], # sentence 2 83 | [0.3, 0.7], # sentence 3 84 | ] 85 | 86 | task = test_utils.TestTranslationTask.setup_task(args, d, d) 87 | model = task.build_model(args) 88 | scorer = SequenceScorer([model], task.target_dictionary) 89 | for id, _src, _ref, hypos in scorer.score_batched_itr(data_itr): 90 | self.assertHypoTokens(hypos[0], data[id]['target']) 91 | self.assertHypoScore(hypos[0], expected_scores[id]) 92 | 93 | def assertHypoTokens(self, hypo, tokens): 94 | self.assertTensorEqual(hypo['tokens'], torch.LongTensor(tokens)) 95 | 96 | def assertHypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.): 97 | pos_scores = torch.FloatTensor(pos_probs).log() 98 | self.assertAlmostEqual(hypo['positional_scores'], pos_scores) 99 | self.assertEqual(pos_scores.numel(), hypo['tokens'].numel()) 100 | score = pos_scores.sum() 101 | if normalized: 102 | score /= pos_scores.numel()**lenpen 103 | self.assertLess(abs(score - hypo['score']), 1e-6) 104 | 105 | def assertAlmostEqual(self, t1, t2): 106 | self.assertEqual(t1.size(), t2.size(), "size mismatch") 107 | self.assertLess((t1 - t2).abs().max(), 1e-4) 108 | 109 | def assertTensorEqual(self, t1, t2): 110 | self.assertEqual(t1.size(), t2.size(), "size mismatch") 111 | self.assertEqual(t1.ne(t2).long().sum(), 0) 112 | 113 | 114 | if __name__ == '__main__': 115 | unittest.main() 116 | -------------------------------------------------------------------------------- /tests/test_train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
7 | 8 | import contextlib 9 | from io import StringIO 10 | import unittest 11 | from unittest.mock import MagicMock, patch 12 | 13 | import torch 14 | 15 | from fairseq import data 16 | 17 | import train 18 | 19 | 20 | def mock_trainer(epoch, num_updates, iterations_in_epoch): 21 | trainer = MagicMock() 22 | trainer.load_checkpoint.return_value = { 23 | 'train_iterator': { 24 | 'epoch': epoch, 25 | 'iterations_in_epoch': iterations_in_epoch, 26 | 'shuffle': False, 27 | }, 28 | } 29 | trainer.get_num_updates.return_value = num_updates 30 | return trainer 31 | 32 | 33 | def mock_dict(): 34 | d = MagicMock() 35 | d.pad.return_value = 1 36 | d.eos.return_value = 2 37 | d.unk.return_value = 3 38 | return d 39 | 40 | 41 | def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch): 42 | tokens = torch.LongTensor(list(range(epoch_size))) 43 | tokens_ds = data.TokenBlockDataset(tokens, [len(tokens)], 1, include_targets=False) 44 | trainer = mock_trainer(epoch, num_updates, iterations_in_epoch) 45 | dataset = data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False) 46 | epoch_itr = data.EpochBatchIterator( 47 | dataset=dataset, 48 | collate_fn=dataset.collater, 49 | batch_sampler=[[i] for i in range(epoch_size)], 50 | ) 51 | return trainer, epoch_itr 52 | 53 | 54 | class TestLoadCheckpoint(unittest.TestCase): 55 | 56 | def setUp(self): 57 | self.args_mock = MagicMock() 58 | self.args_mock.optimizer_overrides = '{}' 59 | self.patches = { 60 | 'os.makedirs': MagicMock(), 61 | 'os.path.join': MagicMock(), 62 | 'os.path.isfile': MagicMock(return_value=True), 63 | } 64 | self.applied_patches = [patch(p, d) for p, d in self.patches.items()] 65 | [p.start() for p in self.applied_patches] 66 | 67 | 68 | def test_load_partial_checkpoint(self): 69 | with contextlib.redirect_stdout(StringIO()): 70 | trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50) 71 | 72 | train.load_checkpoint(self.args_mock, trainer, epoch_itr) 73 | self.assertEqual(epoch_itr.epoch, 2) 74 | self.assertEqual(epoch_itr.iterations_in_epoch, 50) 75 | 76 | itr = epoch_itr.next_epoch_itr(shuffle=False) 77 | self.assertEqual(epoch_itr.epoch, 2) 78 | self.assertEqual(epoch_itr.iterations_in_epoch, 50) 79 | 80 | self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 50) 81 | self.assertEqual(epoch_itr.iterations_in_epoch, 51) 82 | 83 | def test_load_full_checkpoint(self): 84 | with contextlib.redirect_stdout(StringIO()): 85 | trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150) 86 | 87 | train.load_checkpoint(self.args_mock, trainer, epoch_itr) 88 | itr = epoch_itr.next_epoch_itr(shuffle=False) 89 | 90 | self.assertEqual(epoch_itr.epoch, 3) 91 | self.assertEqual(epoch_itr.iterations_in_epoch, 0) 92 | self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0) 93 | 94 | def test_load_no_checkpoint(self): 95 | with contextlib.redirect_stdout(StringIO()): 96 | trainer, epoch_itr = get_trainer_and_epoch_itr(0, 150, 0, 0) 97 | self.patches['os.path.isfile'].return_value = False 98 | 99 | train.load_checkpoint(self.args_mock, trainer, epoch_itr) 100 | itr = epoch_itr.next_epoch_itr(shuffle=False) 101 | 102 | self.assertEqual(epoch_itr.epoch, 1) 103 | self.assertEqual(epoch_itr.iterations_in_epoch, 0) 104 | self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0) 105 | 106 | def tearDown(self): 107 | patch.stopall() 108 | 109 | 110 | if __name__ == '__main__': 111 | unittest.main() 112 | 
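For reference, the data pipeline that test_train.py mocks around can be wired together directly. The sketch below is not part of the repository; it reuses only the calls that appear in get_trainer_and_epoch_itr above and assumes a plain fairseq.data.Dictionary in place of the test's mock_dict().

import torch
from fairseq import data
from fairseq.data import Dictionary

# Build a tiny token stream and wrap it as in get_trainer_and_epoch_itr above;
# the token ids here are arbitrary illustration values.
tokens = torch.LongTensor(list(range(10)))
tokens_ds = data.TokenBlockDataset(tokens, [len(tokens)], 1, include_targets=False)
# Dictionary() stands in for the MagicMock dictionary used by the test.
dataset = data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, Dictionary(), shuffle=False)

# One sample per batch, mirroring the batch_sampler used in the test.
epoch_itr = data.EpochBatchIterator(
    dataset=dataset,
    collate_fn=dataset.collater,
    batch_sampler=[[i] for i in range(10)],
)

itr = epoch_itr.next_epoch_itr(shuffle=False)
first_batch = next(itr)
print(first_batch['net_input']['src_tokens'])       # tokens of the first (unshuffled) sample
print(epoch_itr.epoch, epoch_itr.iterations_in_epoch)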
-------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import unittest 9 | 10 | import torch 11 | 12 | from fairseq import utils 13 | 14 | 15 | class TestUtils(unittest.TestCase): 16 | 17 | def test_convert_padding_direction(self): 18 | pad = 1 19 | left_pad = torch.LongTensor([ 20 | [2, 3, 4, 5, 6], 21 | [1, 7, 8, 9, 10], 22 | [1, 1, 1, 11, 12], 23 | ]) 24 | right_pad = torch.LongTensor([ 25 | [2, 3, 4, 5, 6], 26 | [7, 8, 9, 10, 1], 27 | [11, 12, 1, 1, 1], 28 | ]) 29 | 30 | self.assertAlmostEqual( 31 | right_pad, 32 | utils.convert_padding_direction( 33 | left_pad, 34 | pad, 35 | left_to_right=True, 36 | ), 37 | ) 38 | self.assertAlmostEqual( 39 | left_pad, 40 | utils.convert_padding_direction( 41 | right_pad, 42 | pad, 43 | right_to_left=True, 44 | ), 45 | ) 46 | 47 | def test_make_positions(self): 48 | pad = 1 49 | left_pad_input = torch.LongTensor([ 50 | [9, 9, 9, 9, 9], 51 | [1, 9, 9, 9, 9], 52 | [1, 1, 1, 9, 9], 53 | ]) 54 | left_pad_output = torch.LongTensor([ 55 | [2, 3, 4, 5, 6], 56 | [1, 2, 3, 4, 5], 57 | [1, 1, 1, 2, 3], 58 | ]) 59 | right_pad_input = torch.LongTensor([ 60 | [9, 9, 9, 9, 9], 61 | [9, 9, 9, 9, 1], 62 | [9, 9, 1, 1, 1], 63 | ]) 64 | right_pad_output = torch.LongTensor([ 65 | [2, 3, 4, 5, 6], 66 | [2, 3, 4, 5, 1], 67 | [2, 3, 1, 1, 1], 68 | ]) 69 | 70 | self.assertAlmostEqual( 71 | left_pad_output, 72 | utils.make_positions(left_pad_input, pad, left_pad=True), 73 | ) 74 | self.assertAlmostEqual( 75 | right_pad_output, 76 | utils.make_positions(right_pad_input, pad, left_pad=False), 77 | ) 78 | 79 | def assertAlmostEqual(self, t1, t2): 80 | self.assertEqual(t1.size(), t2.size(), "size mismatch") 81 | self.assertLess(utils.item((t1 - t2).abs().max()), 1e-4) 82 | 83 | 84 | if __name__ == '__main__': 85 | unittest.main() 86 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
7 | 8 | import torch 9 | 10 | from fairseq import utils 11 | from fairseq.data import Dictionary 12 | from fairseq.data.language_pair_dataset import collate 13 | from fairseq.models import ( 14 | FairseqEncoder, 15 | FairseqIncrementalDecoder, 16 | FairseqModel, 17 | ) 18 | from fairseq.tasks import FairseqTask 19 | 20 | 21 | def dummy_dictionary(vocab_size, prefix='token_'): 22 | d = Dictionary() 23 | for i in range(vocab_size): 24 | token = prefix + str(i) 25 | d.add_symbol(token) 26 | d.finalize(padding_factor=1) # don't add extra padding symbols 27 | return d 28 | 29 | 30 | def dummy_dataloader( 31 | samples, 32 | padding_idx=1, 33 | eos_idx=2, 34 | batch_size=None, 35 | ): 36 | if batch_size is None: 37 | batch_size = len(samples) 38 | 39 | # add any missing data to samples 40 | for i, sample in enumerate(samples): 41 | if 'id' not in sample: 42 | sample['id'] = i 43 | 44 | # create dataloader 45 | dataset = TestDataset(samples) 46 | dataloader = torch.utils.data.DataLoader( 47 | dataset, 48 | batch_size=batch_size, 49 | collate_fn=(lambda samples: collate(samples, padding_idx, eos_idx)), 50 | ) 51 | return iter(dataloader) 52 | 53 | 54 | class TestDataset(torch.utils.data.Dataset): 55 | 56 | def __init__(self, data): 57 | super().__init__() 58 | self.data = data 59 | 60 | def __getitem__(self, index): 61 | return self.data[index] 62 | 63 | def __len__(self): 64 | return len(self.data) 65 | 66 | 67 | class TestTranslationTask(FairseqTask): 68 | 69 | def __init__(self, args, src_dict, tgt_dict, model): 70 | super().__init__(args) 71 | self.src_dict = src_dict 72 | self.tgt_dict = tgt_dict 73 | self.model = model 74 | 75 | @classmethod 76 | def setup_task(cls, args, src_dict=None, tgt_dict=None, model=None): 77 | return cls(args, src_dict, tgt_dict, model) 78 | 79 | def build_model(self, args): 80 | return TestModel.build_model(args, self) 81 | 82 | @property 83 | def source_dictionary(self): 84 | return self.src_dict 85 | 86 | @property 87 | def target_dictionary(self): 88 | return self.tgt_dict 89 | 90 | 91 | class TestModel(FairseqModel): 92 | def __init__(self, encoder, decoder): 93 | super().__init__(encoder, decoder) 94 | 95 | @classmethod 96 | def build_model(cls, args, task): 97 | encoder = TestEncoder(args, task.source_dictionary) 98 | decoder = TestIncrementalDecoder(args, task.target_dictionary) 99 | return cls(encoder, decoder) 100 | 101 | 102 | class TestEncoder(FairseqEncoder): 103 | def __init__(self, args, dictionary): 104 | super().__init__(dictionary) 105 | self.args = args 106 | 107 | def forward(self, src_tokens, src_lengths): 108 | return src_tokens 109 | 110 | def reorder_encoder_out(self, encoder_out, new_order): 111 | return encoder_out.index_select(0, new_order) 112 | 113 | 114 | class TestIncrementalDecoder(FairseqIncrementalDecoder): 115 | def __init__(self, args, dictionary): 116 | super().__init__(dictionary) 117 | assert hasattr(args, 'beam_probs') or hasattr(args, 'probs') 118 | args.max_decoder_positions = getattr(args, 'max_decoder_positions', 100) 119 | self.args = args 120 | 121 | def forward(self, prev_output_tokens, encoder_out, incremental_state=None): 122 | if incremental_state is not None: 123 | prev_output_tokens = prev_output_tokens[:, -1:] 124 | bbsz = prev_output_tokens.size(0) 125 | vocab = len(self.dictionary) 126 | src_len = encoder_out.size(1) 127 | tgt_len = prev_output_tokens.size(1) 128 | 129 | # determine number of steps 130 | if incremental_state is not None: 131 | # cache step number 132 | step = utils.get_incremental_state(self, 
incremental_state, 'step') 133 | if step is None: 134 | step = 0 135 | utils.set_incremental_state(self, incremental_state, 'step', step + 1) 136 | steps = [step] 137 | else: 138 | steps = list(range(tgt_len)) 139 | 140 | # define output in terms of raw probs 141 | if hasattr(self.args, 'probs'): 142 | assert self.args.probs.dim() == 3, \ 143 | 'expected probs to have size bsz*steps*vocab' 144 | probs = self.args.probs.index_select(1, torch.LongTensor(steps)) 145 | else: 146 | probs = torch.FloatTensor(bbsz, len(steps), vocab).zero_() 147 | for i, step in enumerate(steps): 148 | # args.beam_probs gives the probability for every vocab element, 149 | # starting with eos, then unknown, and then the rest of the vocab 150 | if step < len(self.args.beam_probs): 151 | probs[:, i, self.dictionary.eos():] = self.args.beam_probs[step] 152 | else: 153 | probs[:, i, self.dictionary.eos()] = 1.0 154 | 155 | # random attention 156 | attn = torch.rand(bbsz, tgt_len, src_len) 157 | 158 | return probs, attn 159 | 160 | def get_normalized_probs(self, net_output, log_probs, _): 161 | # the decoder returns probabilities directly 162 | probs = net_output[0] 163 | if log_probs: 164 | return probs.log() 165 | else: 166 | return probs 167 | 168 | def max_positions(self): 169 | return self.args.max_decoder_positions 170 | --------------------------------------------------------------------------------
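The helpers in tests/utils.py are typically composed as in tests/test_sequence_scorer.py above: build a dummy dictionary, wrap a few hand-written samples in a dataloader, set up TestTranslationTask with explicit output probabilities, and score with SequenceScorer. A minimal sketch follows; it is not part of the repository, and the sample tensors and probabilities are made-up illustration values.

import argparse

import torch

import tests.utils as test_utils
from fairseq.sequence_scorer import SequenceScorer

# dummy dictionary: 4 special symbols plus 2 tokens (w1=4, w2=5)
d = test_utils.dummy_dictionary(vocab_size=2)
w1, eos = 4, d.eos()

# one hand-written sentence pair (illustration values)
data = [{'source': torch.LongTensor([w1, eos]),
         'target': torch.LongTensor([w1, eos])}]
data_itr = test_utils.dummy_dataloader(data)

# TestIncrementalDecoder expects either args.probs or args.beam_probs;
# beam_probs columns start at the eos index: [eos, unk, w1, w2].
args = argparse.Namespace()
args.beam_probs = [
    torch.FloatTensor([[0.0, 0.0, 0.6, 0.4]]),  # step 0: P(w1)=0.6
]

task = test_utils.TestTranslationTask.setup_task(args, d, d)
model = task.build_model(args)

scorer = SequenceScorer([model], task.target_dictionary)
for sample_id, src, ref, hypos in scorer.score_batched_itr(data_itr):
    print(hypos[0]['tokens'], hypos[0]['positional_scores'], hypos[0]['score'])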