├── LICENSE ├── README.md ├── encdec ├── LICENSE ├── README.md ├── distributed_train.py ├── docs │ ├── Makefile │ ├── _static │ │ └── theme_overrides.css │ ├── command_line_tools.rst │ ├── conf.py │ ├── criterions.rst │ ├── data.rst │ ├── docutils.conf │ ├── getting_started.rst │ ├── index.rst │ ├── lr_scheduler.rst │ ├── make.bat │ ├── models.rst │ ├── modules.rst │ ├── optim.rst │ ├── overview.rst │ ├── requirements.txt │ ├── tasks.rst │ ├── tutorial_classifying_names.rst │ └── tutorial_simple_lstm.rst ├── eval_lm.py ├── examples │ ├── .gitignore │ ├── language_model │ │ ├── README.md │ │ └── prepare-wikitext-103.sh │ ├── stories │ │ └── README.md │ └── translation │ │ ├── README.md │ │ ├── prepare-iwslt14.sh │ │ ├── prepare-wmt14en2de.sh │ │ └── prepare-wmt14en2fr.sh ├── fairseq │ ├── __init__.py │ ├── bleu.py │ ├── clib │ │ └── libbleu │ │ │ ├── libbleu.cpp │ │ │ └── module.cpp │ ├── criterions │ │ ├── __init__.py │ │ ├── adaptive_loss.py │ │ ├── cross_entropy.py │ │ ├── fairseq_criterion.py │ │ └── label_smoothed_cross_entropy.py │ ├── data │ │ ├── __init__.py │ │ ├── append_eos_dataset.py │ │ ├── backtranslation_dataset.py │ │ ├── concat_dataset.py │ │ ├── data_utils.py │ │ ├── dictionary.py │ │ ├── fairseq_dataset.py │ │ ├── indexed_dataset.py │ │ ├── iterators.py │ │ ├── language_pair_dataset.py │ │ ├── monolingual_dataset.py │ │ ├── noising.py │ │ └── token_block_dataset.py │ ├── distributed_utils.py │ ├── meters.py │ ├── models │ │ ├── __init__.py │ │ ├── composite_encoder.py │ │ ├── distributed_fairseq_model.py │ │ ├── fairseq_decoder.py │ │ ├── fairseq_encoder.py │ │ ├── fairseq_incremental_decoder.py │ │ ├── fairseq_model.py │ │ ├── fconv.py │ │ ├── fconv_self_att.py │ │ ├── lstm.py │ │ └── transformer.py │ ├── modules │ │ ├── __init__.py │ │ ├── adaptive_softmax.py │ │ ├── beamable_mm.py │ │ ├── character_token_embedder.py │ │ ├── conv_tbc.py │ │ ├── downsampled_multihead_attention.py │ │ ├── grad_multiply.py │ │ ├── highway.py │ │ ├── learned_positional_embedding.py │ │ ├── linearized_convolution.py │ │ ├── multihead_attention.py │ │ ├── scalar_bias.py │ │ └── sinusoidal_positional_embedding.py │ ├── multiprocessing_pdb.py │ ├── optim │ │ ├── __init__.py │ │ ├── adagrad.py │ │ ├── adam.py │ │ ├── fairseq_optimizer.py │ │ ├── fp16_optimizer.py │ │ ├── lr_scheduler │ │ │ ├── __init__.py │ │ │ ├── cosine_lr_scheduler.py │ │ │ ├── fairseq_lr_scheduler.py │ │ │ ├── fixed_schedule.py │ │ │ ├── inverse_square_root_schedule.py │ │ │ ├── reduce_lr_on_plateau.py │ │ │ └── triangular_lr_scheduler.py │ │ ├── nag.py │ │ └── sgd.py │ ├── options.py │ ├── progress_bar.py │ ├── search.py │ ├── sequence_generator.py │ ├── sequence_scorer.py │ ├── tasks │ │ ├── __init__.py │ │ ├── fairseq_task.py │ │ ├── language_modeling.py │ │ └── translation.py │ ├── tokenizer.py │ ├── trainer.py │ └── utils.py ├── generate.py ├── interactive.py ├── multiprocessing_train.py ├── preprocess.py ├── requirements.txt ├── rerank.py ├── score.py ├── scripts │ ├── __init__.py │ ├── average_checkpoints.py │ ├── build_sym_alignment.py │ ├── convert_dictionary.lua │ ├── convert_model.lua │ └── read_binarized.py ├── setup.py ├── tests │ ├── __init__.py │ ├── test_average_checkpoints.py │ ├── test_backtranslation_dataset.py │ ├── test_binaries.py │ ├── test_character_token_embedder.py │ ├── test_convtbc.py │ ├── test_dictionary.py │ ├── test_iterators.py │ ├── test_label_smoothing.py │ ├── test_noising.py │ ├── test_reproducibility.py │ ├── test_sequence_generator.py │ ├── test_sequence_scorer.py │ ├── test_train.py │ 
├── test_utils.py │ └── utils.py └── train.py └── eval ├── LICENSE ├── README.md ├── calculate_variance_from_fixlength.py ├── eval.sh ├── make_rouge.py └── prepare4rouge-simple.pl /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, Sho Takase 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Positional Encoding to Control Output Sequence Length 2 | 3 | This repository contains the source files we used in our paper 4 | >[Positional Encoding to Control Output Sequence Length](https://www.aclweb.org/anthology/N19-1401) 5 | 6 | >Sho Takase, Naoaki Okazaki 7 | 8 | > Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies 9 | 10 | 11 | ## Requirements 12 | 13 | - Python 3.6 or later for training 14 | - Python 2.7 for calculating ROUGE 15 | - PyTorch 0.4 16 | - To use a newer version of PyTorch (e.g., 1.4.0), please use [this code](https://github.com/takase/alone_seq2seq) without the one-emb option. 17 | 18 | ## Test data 19 | 20 | Test data used in our paper, provided separately for each output length 21 | 22 | - [https://drive.google.com/open?id=1teets0SZ82cdwQG0s454Y7JFuoutOawb](https://drive.google.com/open?id=1teets0SZ82cdwQG0s454Y7JFuoutOawb) 23 | - Each file contains lines of the form ```SOURCE PART <tab> HEADLINE``` 24 | 25 | ## Pre-trained model 26 | 27 | The following file contains a pre-trained LRPE + PE model trained on the English dataset. This model outputs ``` @@@@ ``` in place of a space, i.e., as a word segmentation marker. 28 | 29 | The file also contains the BPE code used to split plain English text into BPE units with [this code](https://github.com/rsennrich/subword-nmt).
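As a rough sketch of how the pre-trained model is meant to be used (the file names `bpe_code`, `input.txt`, and `output.bpe` below are hypothetical placeholders; use the actual files shipped in the archive linked below, and note that the final `sed` line is only an illustration of undoing the segmentation): the input text is first segmented with subword-nmt's `apply_bpe.py`, and after decoding the `@@ ` BPE joiners and the ``` @@@@ ``` space marker are converted back to plain text.

```
# segment plain English text (tokenized like the training data) into BPE units
# using the BPE code file distributed with the pre-trained model
python subword-nmt/apply_bpe.py -c bpe_code < input.txt > input.bpe

# after decoding with generate.py (see encdec/README.md), remove the "@@ " BPE
# joiners and turn the @@@@ marker back into an ordinary space
sed -e 's/@@ //g' -e 's/ @@@@ / /g' < output.bpe > output.txt
```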
30 | 31 | [https://drive.google.com/file/d/15Sy8rv6Snw6Nso7T5MxYHSAZDdieXpE7/view?usp=sharing](https://drive.google.com/file/d/15Sy8rv6Snw6Nso7T5MxYHSAZDdieXpE7/view?usp=sharing) 32 | 33 | ## Acknowledgements 34 | 35 | A large portion of this repo is borrowed from the following repos: [https://github.com/pytorch/fairseq](https://github.com/pytorch/fairseq) and [https://github.com/facebookarchive/NAMAS](https://github.com/facebookarchive/NAMAS). 36 | -------------------------------------------------------------------------------- /encdec/LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For fairseq software 4 | 5 | Copyright (c) 2017-present, Facebook, Inc. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /encdec/README.md: -------------------------------------------------------------------------------- 1 | ## Preprocessing 2 | 3 | - Construction of binarized data with shared vocabulary 4 | 5 | - Input data is plain text such as following example 6 | 7 | ``` 8 | australia 's current account deficit shrunk by a record 7.07 billion dollars -lrb- 6.04 billion us -rrb- in the june quarter due to soaring commodity prices , figures released monday showed . 9 | at least two people were killed in a suspected bomb attack on a passenger bus in the strife-torn southern philippines on monday , the military said . 10 | australian shares closed down 0.3 percent monday following a weak lead from the united states and lower commodity prices , dealers said . 
11 | ``` 12 | 13 | ``` 14 | python preprocess.py --source-lang SOURCE_SUFFIX --target-lang TARGET_SUFFIX 15 | --trainpref PREFIX_PATH_TO_TRAIN_DATA --validpref PREFIX_PATH_TO_VALID_DATA 16 | --joined-dictionary --destdir PREPROCESS_PATH 17 | ``` 18 | 19 | - If the source file of the training data is named text.source and the target file is named text.target, set SOURCE_SUFFIX=source, TARGET_SUFFIX=target, and PREFIX_PATH_TO_TRAIN_DATA=text 20 | 21 | - Preprocessing the test file 22 | 23 | ``` 24 | python preprocess.py --source-lang SOURCE_SUFFIX --target-lang TARGET_SUFFIX 25 | --tgtdict PATH_TO_TARGET_DICT --srcdict PATH_TO_SOURCE_DICT 26 | --testpref PREFIX_PATH_TO_TEST_DATA --destdir PREPROCESS_TEST_PATH 27 | ``` 28 | 29 | ## Training 30 | 31 | - E.g., training Transformer+LRPE+PE on a 4-GPU machine 32 | 33 | - +LRPE: --represent-length-by-lrpe 34 | 35 | - +LDPE: --represent-length-by-ldpe 36 | 37 | - +PE: --ordinary-sinpos 38 | 39 | ``` 40 | python train.py PREPROCESS_PATH --source-lang SOURCE_SUFFIX --target-lang TARGET_SUFFIX 41 | --arch transformer_wmt_en_de --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 42 | --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 --lr 0.001 --min-lr 1e-09 43 | --dropout 0.3 --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 44 | --max-tokens 3584 --seed 2723 --max-epoch 100 --update-freq 16 --share-all-embeddings 45 | --represent-length-by-lrpe --ordinary-sinpos --save-dir PATH_TO_SAVE_MODEL 46 | ``` 47 | 48 | - If you train on a single GPU, change --update-freq from 16 to 64 49 | 50 | - Averaging the last 10 checkpoints 51 | 52 | ``` 53 | python scripts/average_checkpoints.py --inputs PATH_TO_SAVE_MODEL --num-epoch-checkpoints 10 --output PATH_TO_AVERAGED_MODEL 54 | ``` 55 | 56 | ## Generation 57 | 58 | 1. Generate headlines under a constraint of 75 characters ('TAB' in the command below stands for a literal tab character) 59 | 60 | ``` 61 | python generate.py PREPROCESS_TEST_PATH --source-lang SOURCE_SUFFIX --target-lang TARGET_SUFFIX 62 | --path PATH_TO_AVERAGED_MODEL --desired-length 75 --batch-size 32 --beam 5 63 | | grep '^H' | sed 's/^H\-//g' | sort -t 'TAB' -k1,1 -n | cut -f 3- 64 | ``` 65 | 66 | 2. Generate n-best headlines and re-rank them 67 | 68 | - Generate n-best headlines (n = 20 in the following example) 69 | 70 | ``` 71 | python generate.py PREPROCESS_TEST_PATH --source-lang SOURCE_SUFFIX --target-lang TARGET_SUFFIX 72 | --path PATH_TO_AVERAGED_MODEL --batch-size 32 --beam 20 --nbest 20 --desired-length 75 > nbest.txt 73 | ``` 74 | 75 | - Re-rank the n-best headlines 76 | 77 | ``` 78 | python rerank.py --cand nbest.txt -m --source SOURCE_FILE 79 | ``` 80 | -------------------------------------------------------------------------------- /encdec/distributed_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 8 | 9 | import os 10 | import socket 11 | import subprocess 12 | 13 | from train import main as single_process_main 14 | from fairseq import distributed_utils, options 15 | 16 | 17 | def main(args): 18 | if args.distributed_init_method is None and args.distributed_port > 0: 19 | # We can determine the init method automatically for Slurm.
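# On Slurm, SLURM_JOB_NODELIST holds the job's compact node list (e.g. 'node[017-020]');
# 'scontrol show hostnames' expands it into one hostname per line, and the first
# hostname becomes the rendezvous host of the tcp:// init method built below.
# SLURM_PROCID gives this task's global rank and SLURM_LOCALID its index on the
# local node, which selects the GPU to use.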
20 | node_list = os.environ.get('SLURM_JOB_NODELIST') 21 | if node_list is not None: 22 | try: 23 | hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', node_list]) 24 | args.distributed_init_method = 'tcp://{host}:{port}'.format( 25 | host=hostnames.split()[0].decode('utf-8'), 26 | port=args.distributed_port) 27 | args.distributed_rank = int(os.environ.get('SLURM_PROCID')) 28 | args.device_id = int(os.environ.get('SLURM_LOCALID')) 29 | except subprocess.CalledProcessError as e: # scontrol failed 30 | raise e 31 | except FileNotFoundError as e: # Slurm is not installed 32 | pass 33 | if args.distributed_init_method is None and args.distributed_port is None: 34 | raise ValueError('--distributed-init-method or --distributed-port ' 35 | 'must be specified for distributed training') 36 | 37 | args.distributed_rank = distributed_utils.distributed_init(args) 38 | print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank)) 39 | single_process_main(args) 40 | 41 | 42 | if __name__ == '__main__': 43 | parser = options.get_training_parser() 44 | args = options.parse_args_and_arch(parser) 45 | main(args) 46 | -------------------------------------------------------------------------------- /encdec/docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = python -msphinx 7 | SPHINXPROJ = fairseq 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /encdec/docs/_static/theme_overrides.css: -------------------------------------------------------------------------------- 1 | .wy-table-responsive table td kbd { 2 | white-space: nowrap; 3 | } 4 | .wy-table-responsive table td { 5 | white-space: normal !important; 6 | } 7 | .wy-table-responsive { 8 | overflow: visible !important; 9 | } 10 | -------------------------------------------------------------------------------- /encdec/docs/command_line_tools.rst: -------------------------------------------------------------------------------- 1 | .. _Command-line Tools: 2 | 3 | Command-line Tools 4 | ================== 5 | 6 | Fairseq provides several command-line tools for training and evaluating models: 7 | 8 | - :ref:`preprocess.py`: Data pre-processing: build vocabularies and binarize training data 9 | - :ref:`train.py`: Train a new model on one or multiple GPUs 10 | - :ref:`generate.py`: Translate pre-processed data with a trained model 11 | - :ref:`interactive.py`: Translate raw text with a trained model 12 | - :ref:`score.py`: BLEU scoring of generated translations against reference translations 13 | - :ref:`eval_lm.py`: Language model evaluation 14 | 15 | 16 | .. _preprocess.py: 17 | 18 | preprocess.py 19 | ~~~~~~~~~~~~~ 20 | .. automodule:: preprocess 21 | 22 | .. argparse:: 23 | :module: preprocess 24 | :func: get_parser 25 | :prog: preprocess.py 26 | 27 | 28 | .. _train.py: 29 | 30 | train.py 31 | ~~~~~~~~ 32 | .. 
automodule:: train 33 | 34 | .. argparse:: 35 | :module: fairseq.options 36 | :func: get_training_parser 37 | :prog: train.py 38 | 39 | 40 | .. _generate.py: 41 | 42 | generate.py 43 | ~~~~~~~~~~~ 44 | .. automodule:: generate 45 | 46 | .. argparse:: 47 | :module: fairseq.options 48 | :func: get_generation_parser 49 | :prog: generate.py 50 | 51 | 52 | .. _interactive.py: 53 | 54 | interactive.py 55 | ~~~~~~~~~~~~~~ 56 | .. automodule:: interactive 57 | 58 | .. argparse:: 59 | :module: fairseq.options 60 | :func: get_interactive_generation_parser 61 | :prog: interactive.py 62 | 63 | 64 | .. _score.py: 65 | 66 | score.py 67 | ~~~~~~~~ 68 | .. automodule:: score 69 | 70 | .. argparse:: 71 | :module: score 72 | :func: get_parser 73 | :prog: score.py 74 | 75 | 76 | .. _eval_lm.py: 77 | 78 | eval_lm.py 79 | ~~~~~~~~~~ 80 | .. automodule:: eval_lm 81 | 82 | .. argparse:: 83 | :module: fairseq.options 84 | :func: get_eval_lm_parser 85 | :prog: eval_lm.py 86 | -------------------------------------------------------------------------------- /encdec/docs/conf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | # fairseq documentation build configuration file, created by 5 | # sphinx-quickstart on Fri Aug 17 21:45:30 2018. 6 | # 7 | # This file is execfile()d with the current directory set to its 8 | # containing dir. 9 | # 10 | # Note that not all possible configuration values are present in this 11 | # autogenerated file. 12 | # 13 | # All configuration values have a default; values that are commented out 14 | # serve to show the default. 15 | 16 | # If extensions (or modules to document with autodoc) are in another directory, 17 | # add these directories to sys.path here. If the directory is relative to the 18 | # documentation root, use os.path.abspath to make it absolute, like shown here. 19 | 20 | import os 21 | import sys 22 | 23 | # source code directory, relative to this file, for sphinx-autobuild 24 | sys.path.insert(0, os.path.abspath('..')) 25 | 26 | source_suffix = ['.rst'] 27 | 28 | # -- General configuration ------------------------------------------------ 29 | 30 | # If your documentation needs a minimal Sphinx version, state it here. 31 | # 32 | # needs_sphinx = '1.0' 33 | 34 | # Add any Sphinx extension module names here, as strings. They can be 35 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 36 | # ones. 37 | extensions = [ 38 | 'sphinx.ext.autodoc', 39 | 'sphinx.ext.intersphinx', 40 | 'sphinx.ext.viewcode', 41 | 'sphinx.ext.napoleon', 42 | 'sphinxarg.ext', 43 | ] 44 | 45 | # Add any paths that contain templates here, relative to this directory. 46 | templates_path = ['_templates'] 47 | 48 | # The master toctree document. 49 | master_doc = 'index' 50 | 51 | # General information about the project. 52 | project = 'fairseq' 53 | copyright = '2018, Facebook AI Research (FAIR)' 54 | author = 'Facebook AI Research (FAIR)' 55 | 56 | github_doc_root = 'https://github.com/pytorch/fairseq/tree/master/docs/' 57 | 58 | # The version info for the project you're documenting, acts as replacement for 59 | # |version| and |release|, also used in various other places throughout the 60 | # built documents. 61 | # 62 | # The short X.Y version. 63 | version = '0.6.0' 64 | # The full version, including alpha/beta/rc tags. 65 | release = '0.6.0' 66 | 67 | # The language for content autogenerated by Sphinx. Refer to documentation 68 | # for a list of supported languages. 
69 | # 70 | # This is also used if you do content translation via gettext catalogs. 71 | # Usually you set "language" from the command line for these cases. 72 | language = None 73 | 74 | # List of patterns, relative to source directory, that match files and 75 | # directories to ignore when looking for source files. 76 | # This patterns also effect to html_static_path and html_extra_path 77 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 78 | 79 | # The name of the Pygments (syntax highlighting) style to use. 80 | pygments_style = 'sphinx' 81 | highlight_language = 'python' 82 | 83 | # If true, `todo` and `todoList` produce output, else they produce nothing. 84 | todo_include_todos = False 85 | 86 | 87 | # -- Options for HTML output ---------------------------------------------- 88 | 89 | # The theme to use for HTML and HTML Help pages. See the documentation for 90 | # a list of builtin themes. 91 | # 92 | html_theme = 'sphinx_rtd_theme' 93 | 94 | # Theme options are theme-specific and customize the look and feel of a theme 95 | # further. For a list of options available for each theme, see the 96 | # documentation. 97 | # 98 | # html_theme_options = {} 99 | 100 | # Add any paths that contain custom static files (such as style sheets) here, 101 | # relative to this directory. They are copied after the builtin static files, 102 | # so a file named "default.css" will overwrite the builtin "default.css". 103 | html_static_path = ['_static'] 104 | 105 | html_context = { 106 | 'css_files': [ 107 | '_static/theme_overrides.css', # override wide tables in RTD theme 108 | ], 109 | } 110 | 111 | # Custom sidebar templates, must be a dictionary that maps document names 112 | # to template names. 113 | # 114 | # This is required for the alabaster theme 115 | # refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars 116 | #html_sidebars = { 117 | # '**': [ 118 | # 'about.html', 119 | # 'navigation.html', 120 | # 'relations.html', # needs 'show_related': True theme option to display 121 | # 'searchbox.html', 122 | # 'donate.html', 123 | # ] 124 | #} 125 | 126 | 127 | # Example configuration for intersphinx: refer to the Python standard library. 128 | intersphinx_mapping = { 129 | 'numpy': ('http://docs.scipy.org/doc/numpy/', None), 130 | 'python': ('https://docs.python.org/', None), 131 | 'torch': ('https://pytorch.org/docs/master/', None), 132 | } 133 | -------------------------------------------------------------------------------- /encdec/docs/criterions.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | .. _Criterions: 5 | 6 | Criterions 7 | ========== 8 | 9 | .. automodule:: fairseq.criterions 10 | :members: 11 | .. autoclass:: fairseq.criterions.FairseqCriterion 12 | :members: 13 | :undoc-members: 14 | -------------------------------------------------------------------------------- /encdec/docs/data.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | .. module:: fairseq.data 5 | 6 | Data Loading and Utilities 7 | ========================== 8 | 9 | .. _datasets: 10 | 11 | Datasets 12 | -------- 13 | 14 | **Datasets** define the data format and provide helpers for creating 15 | mini-batches. 16 | 17 | .. autoclass:: fairseq.data.FairseqDataset 18 | :members: 19 | .. autoclass:: fairseq.data.LanguagePairDataset 20 | :members: 21 | .. 
autoclass:: fairseq.data.MonolingualDataset 22 | :members: 23 | 24 | 25 | Dictionary 26 | ---------- 27 | 28 | .. autoclass:: fairseq.data.Dictionary 29 | :members: 30 | 31 | 32 | Iterators 33 | --------- 34 | 35 | .. autoclass:: fairseq.data.CountingIterator 36 | :members: 37 | .. autoclass:: fairseq.data.EpochBatchIterator 38 | :members: 39 | .. autoclass:: fairseq.data.GroupedIterator 40 | :members: 41 | .. autoclass:: fairseq.data.ShardedIterator 42 | :members: 43 | -------------------------------------------------------------------------------- /encdec/docs/docutils.conf: -------------------------------------------------------------------------------- 1 | [writers] 2 | option-limit=0 3 | -------------------------------------------------------------------------------- /encdec/docs/index.rst: -------------------------------------------------------------------------------- 1 | .. fairseq documentation master file, created by 2 | sphinx-quickstart on Fri Aug 17 21:45:30 2018. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | :github_url: https://github.com/pytorch/fairseq 7 | 8 | 9 | fairseq documentation 10 | ===================== 11 | 12 | Fairseq is a sequence modeling toolkit written in `PyTorch 13 | `_ that allows researchers and developers to 14 | train custom models for translation, summarization, language modeling and other 15 | text generation tasks. 16 | 17 | .. toctree:: 18 | :maxdepth: 1 19 | :caption: Getting Started 20 | 21 | getting_started 22 | command_line_tools 23 | 24 | .. toctree:: 25 | :maxdepth: 1 26 | :caption: Extending Fairseq 27 | 28 | overview 29 | tutorial_simple_lstm 30 | tutorial_classifying_names 31 | 32 | .. toctree:: 33 | :maxdepth: 2 34 | :caption: Library Reference 35 | 36 | tasks 37 | models 38 | criterions 39 | optim 40 | lr_scheduler 41 | data 42 | modules 43 | 44 | 45 | Indices and tables 46 | ================== 47 | 48 | * :ref:`genindex` 49 | * :ref:`search` 50 | -------------------------------------------------------------------------------- /encdec/docs/lr_scheduler.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | .. _Learning Rate Schedulers: 5 | 6 | Learning Rate Schedulers 7 | ======================== 8 | 9 | TODO 10 | 11 | .. automodule:: fairseq.optim.lr_scheduler 12 | :members: 13 | -------------------------------------------------------------------------------- /encdec/docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=python -msphinx 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=fairseq 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed, 20 | echo.then set the SPHINXBUILD environment variable to point to the full 21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the 22 | echo.Sphinx directory to PATH. 23 | echo. 
24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /encdec/docs/models.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | .. module:: fairseq.models 5 | 6 | .. _Models: 7 | 8 | Models 9 | ====== 10 | 11 | A Model defines the neural network's ``forward()`` method and encapsulates all 12 | of the learnable parameters in the network. Each model also provides a set of 13 | named *architectures* that define the precise network configuration (e.g., 14 | embedding dimension, number of layers, etc.). 15 | 16 | Both the model type and architecture are selected via the ``--arch`` 17 | command-line argument. Once selected, a model may expose additional command-line 18 | arguments for further configuration. 19 | 20 | .. note:: 21 | 22 | All fairseq Models extend :class:`BaseFairseqModel`, which in turn extends 23 | :class:`torch.nn.Module`. Thus any fairseq Model can be used as a 24 | stand-alone Module in other PyTorch code. 25 | 26 | 27 | Convolutional Neural Networks (CNN) 28 | ----------------------------------- 29 | 30 | .. module:: fairseq.models.fconv 31 | .. autoclass:: fairseq.models.fconv.FConvModel 32 | :members: 33 | .. autoclass:: fairseq.models.fconv.FConvEncoder 34 | :members: 35 | :undoc-members: 36 | .. autoclass:: fairseq.models.fconv.FConvDecoder 37 | :members: 38 | 39 | 40 | Long Short-Term Memory (LSTM) networks 41 | -------------------------------------- 42 | 43 | .. module:: fairseq.models.lstm 44 | .. autoclass:: fairseq.models.lstm.LSTMModel 45 | :members: 46 | .. autoclass:: fairseq.models.lstm.LSTMEncoder 47 | :members: 48 | .. autoclass:: fairseq.models.lstm.LSTMDecoder 49 | :members: 50 | 51 | 52 | Transformer (self-attention) networks 53 | ------------------------------------- 54 | 55 | .. module:: fairseq.models.transformer 56 | .. autoclass:: fairseq.models.transformer.TransformerModel 57 | :members: 58 | .. autoclass:: fairseq.models.transformer.TransformerEncoder 59 | :members: 60 | .. autoclass:: fairseq.models.transformer.TransformerEncoderLayer 61 | :members: 62 | .. autoclass:: fairseq.models.transformer.TransformerDecoder 63 | :members: 64 | .. autoclass:: fairseq.models.transformer.TransformerDecoderLayer 65 | :members: 66 | 67 | 68 | Adding new models 69 | ----------------- 70 | 71 | .. currentmodule:: fairseq.models 72 | .. autofunction:: fairseq.models.register_model 73 | .. autofunction:: fairseq.models.register_model_architecture 74 | .. autoclass:: fairseq.models.BaseFairseqModel 75 | :members: 76 | :undoc-members: 77 | .. autoclass:: fairseq.models.FairseqModel 78 | :members: 79 | :undoc-members: 80 | .. autoclass:: fairseq.models.FairseqLanguageModel 81 | :members: 82 | :undoc-members: 83 | .. autoclass:: fairseq.models.FairseqEncoder 84 | :members: 85 | .. autoclass:: fairseq.models.CompositeEncoder 86 | :members: 87 | .. autoclass:: fairseq.models.FairseqDecoder 88 | :members: 89 | 90 | 91 | .. _Incremental decoding: 92 | 93 | Incremental decoding 94 | -------------------- 95 | 96 | .. 
autoclass:: fairseq.models.FairseqIncrementalDecoder 97 | :members: 98 | :undoc-members: 99 | -------------------------------------------------------------------------------- /encdec/docs/modules.rst: -------------------------------------------------------------------------------- 1 | Modules 2 | ======= 3 | 4 | Fairseq provides several stand-alone :class:`torch.nn.Module` s that may be 5 | helpful when implementing a new :class:`FairseqModel`. 6 | 7 | .. automodule:: fairseq.modules 8 | :members: 9 | :undoc-members: 10 | -------------------------------------------------------------------------------- /encdec/docs/optim.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | .. _optimizers: 5 | 6 | Optimizers 7 | ========== 8 | 9 | .. automodule:: fairseq.optim 10 | :members: 11 | -------------------------------------------------------------------------------- /encdec/docs/overview.rst: -------------------------------------------------------------------------------- 1 | Overview 2 | ======== 3 | 4 | Fairseq can be extended through user-supplied `plug-ins 5 | `_. We support five kinds of 6 | plug-ins: 7 | 8 | - :ref:`Models` define the neural network architecture and encapsulate all of the 9 | learnable parameters. 10 | - :ref:`Criterions` compute the loss function given the model outputs and targets. 11 | - :ref:`Tasks` store dictionaries and provide helpers for loading/iterating over 12 | Datasets, initializing the Model/Criterion and calculating the loss. 13 | - :ref:`Optimizers` update the Model parameters based on the gradients. 14 | - :ref:`Learning Rate Schedulers` update the learning rate over the course of 15 | training. 16 | 17 | **Training Flow** 18 | 19 | Given a ``model``, ``criterion``, ``task``, ``optimizer`` and ``lr_scheduler``, 20 | fairseq implements the following high-level training flow:: 21 | 22 | for epoch in range(num_epochs): 23 | itr = task.get_batch_iterator(task.dataset('train')) 24 | for num_updates, batch in enumerate(itr): 25 | loss = criterion(model, batch) 26 | optimizer.backward(loss) 27 | optimizer.step() 28 | lr_scheduler.step_update(num_updates) 29 | lr_scheduler.step(epoch) 30 | 31 | **Registering new plug-ins** 32 | 33 | New plug-ins are *registered* through a set of ``@register`` function 34 | decorators, for example:: 35 | 36 | @register_model('my_lstm') 37 | class MyLSTM(FairseqModel): 38 | (...) 39 | 40 | Once registered, new plug-ins can be used with the existing :ref:`Command-line 41 | Tools`. See the Tutorial sections for more detailed walkthroughs of how to add 42 | new plug-ins. 43 | -------------------------------------------------------------------------------- /encdec/docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx<2.0 2 | sphinx-argparse 3 | -------------------------------------------------------------------------------- /encdec/docs/tasks.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | .. module:: fairseq.tasks 5 | 6 | .. _Tasks: 7 | 8 | Tasks 9 | ===== 10 | 11 | Tasks store dictionaries and provide helpers for loading/iterating over 12 | Datasets, initializing the Model/Criterion and calculating the loss. 13 | 14 | Tasks can be selected via the ``--task`` command-line argument. Once selected, a 15 | task may expose additional command-line arguments for further configuration. 
16 | 17 | Example usage:: 18 | 19 | # setup the task (e.g., load dictionaries) 20 | task = fairseq.tasks.setup_task(args) 21 | 22 | # build model and criterion 23 | model = task.build_model(args) 24 | criterion = task.build_criterion(args) 25 | 26 | # load datasets 27 | task.load_dataset('train') 28 | task.load_dataset('valid') 29 | 30 | # iterate over mini-batches of data 31 | batch_itr = task.get_batch_iterator( 32 | task.dataset('train'), max_tokens=4096, 33 | ) 34 | for batch in batch_itr: 35 | # compute the loss 36 | loss, sample_size, logging_output = task.get_loss( 37 | model, criterion, batch, 38 | ) 39 | loss.backward() 40 | 41 | 42 | Translation 43 | ----------- 44 | 45 | .. autoclass:: fairseq.tasks.translation.TranslationTask 46 | 47 | .. _language modeling: 48 | 49 | Language Modeling 50 | ----------------- 51 | 52 | .. autoclass:: fairseq.tasks.language_modeling.LanguageModelingTask 53 | 54 | 55 | Adding new tasks 56 | ---------------- 57 | 58 | .. autofunction:: fairseq.tasks.register_task 59 | .. autoclass:: fairseq.tasks.FairseqTask 60 | :members: 61 | :undoc-members: 62 | -------------------------------------------------------------------------------- /encdec/examples/.gitignore: -------------------------------------------------------------------------------- 1 | */* 2 | !*/*.sh 3 | !*/*.md 4 | -------------------------------------------------------------------------------- /encdec/examples/language_model/README.md: -------------------------------------------------------------------------------- 1 | Sample data processing scripts for the FAIR Sequence-to-Sequence Toolkit 2 | 3 | These scripts provide an example of pre-processing data for the Language Modeling task. 4 | 5 | # prepare-wikitext-103.sh 6 | 7 | Provides an example of pre-processing for [WikiText-103 language modeling task](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset): 8 | 9 | Example usage: 10 | ``` 11 | $ cd examples/language_model/ 12 | $ bash prepare-wikitext-103.sh 13 | $ cd ../.. 
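# prepare-wikitext-103.sh downloads wikitext-103-v1.zip and unzips it, so the
# tokenized files (wiki.train.tokens, wiki.valid.tokens, wiki.test.tokens) end up
# under examples/language_model/wikitext-103/, the $TEXT path used below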
14 | 15 | # Binarize the dataset: 16 | $ TEXT=examples/language_model/wikitext-103 17 | 18 | $ python preprocess.py --only-source \ 19 | --trainpref $TEXT/wiki.train.tokens --validpref $TEXT/wiki.valid.tokens --testpref $TEXT/wiki.test.tokens \ 20 | --destdir data-bin/wikitext-103 21 | 22 | # Train the model: 23 | # If it runs out of memory, try to reduce max-tokens and max-target-positions 24 | $ mkdir -p checkpoints/wikitext-103 25 | $ python train.py --task language_modeling data-bin/wikitext-103 \ 26 | --max-epoch 35 --arch fconv_lm_dauphin_wikitext103 --optimizer nag \ 27 | --lr 1.0 --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \ 28 | --clip-norm 0.1 --dropout 0.2 --weight-decay 5e-06 --criterion adaptive_loss \ 29 | --adaptive-softmax-cutoff 10000,20000,200000 --max-tokens 1024 --tokens-per-sample 1024 30 | 31 | # Evaluate: 32 | $ python eval_lm.py data-bin/wikitext-103 --path 'checkpoints/wiki103/checkpoint_best.pt' 33 | 34 | ``` 35 | -------------------------------------------------------------------------------- /encdec/examples/language_model/prepare-wikitext-103.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh 3 | 4 | URLS=( 5 | "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip" 6 | ) 7 | FILES=( 8 | "wikitext-103-v1.zip" 9 | ) 10 | 11 | for ((i=0;i<${#URLS[@]};++i)); do 12 | file=${FILES[i]} 13 | if [ -f $file ]; then 14 | echo "$file already exists, skipping download" 15 | else 16 | url=${URLS[i]} 17 | wget "$url" 18 | if [ -f $file ]; then 19 | echo "$url successfully downloaded." 20 | else 21 | echo "$url not successfully downloaded." 22 | exit -1 23 | fi 24 | if [ ${file: -4} == ".tgz" ]; then 25 | tar zxvf $file 26 | elif [ ${file: -4} == ".tar" ]; then 27 | tar xvf $file 28 | elif [ ${file: -4} == ".zip" ]; then 29 | unzip $file 30 | fi 31 | fi 32 | done 33 | cd .. 34 | -------------------------------------------------------------------------------- /encdec/examples/stories/README.md: -------------------------------------------------------------------------------- 1 | FAIR Sequence-to-Sequence Toolkit for Story Generation 2 | 3 | The following commands provide an example of pre-processing data, training a model, and generating text for story generation with the WritingPrompts dataset. 4 | 5 | The dataset can be downloaded like this: 6 | 7 | ``` 8 | curl https://s3.amazonaws.com/fairseq-py/data/writingPrompts.tar.gz | tar xvzf - 9 | ``` 10 | 11 | and contains a train, test, and valid split. The dataset is described here: https://arxiv.org/abs/1805.04833. We model only the first 1000 words of each story, including one newLine token. 
12 | 13 | 14 | Example usage: 15 | ``` 16 | # Preprocess the dataset: 17 | # Note that the dataset release is the full data, but the paper models the first 1000 words of each story 18 | # Here is some example code that can trim the dataset to the first 1000 words of each story 19 | $ python 20 | $ data = ["train", "test", "valid"] 21 | $ for name in data: 22 | $ with open(name + ".wp_target") as f: 23 | $ stories = f.readlines() 24 | $ stories = [" ".join(i.split()[0:1000]) for i in stories] 25 | $ with open(name + ".wp_target", "w") as o: 26 | $ for line in stories: 27 | $ o.write(line.strip() + "\n") 28 | 29 | # Binarize the dataset: 30 | $ TEXT=examples/stories/writingPrompts 31 | $ python preprocess.py --source-lang wp_source --target-lang wp_target \ 32 | --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \ 33 | --destdir data-bin/writingPrompts --padding-factor 1 --thresholdtgt 10 --thresholdsrc 10 34 | 35 | # Train the model: 36 | $ python train.py data-bin/writingPrompts -a fconv_self_att_wp --lr 0.25 --clip-norm 0.1 --max-tokens 1500 --lr-scheduler reduce_lr_on_plateau --decoder-attention True --encoder-attention False --criterion label_smoothed_cross_entropy --weight-decay .0000001 --label-smoothing 0 --source-lang wp_source --target-lang wp_target --gated-attention True --self-attention True --project-input True --pretrained False 37 | 38 | # Train a fusion model: 39 | # add the arguments: --pretrained True --pretrained-checkpoint path/to/checkpoint 40 | 41 | # Generate: 42 | # Note: to load the pretrained model at generation time, you need to pass in a model-override argument to communicate to the fusion model at generation time where you have placed the pretrained checkpoint. By default, it will load the exact path of the fusion model's pretrained model from training time. You should use model-override if you have moved the pretrained model (or are using our provided models). If you are generating from a non-fusion model, the model-override argument is not necessary. 43 | 44 | $ python generate.py data-bin/writingPrompts --path /path/to/trained/model/checkpoint_best.pt --batch-size 32 --beam 1 --sampling --sampling-topk 10 --sampling-temperature 0.8 --nbest 1 --model-overrides "{'pretrained_checkpoint':'/path/to/pretrained/model/checkpoint'}" 45 | ``` 46 | -------------------------------------------------------------------------------- /encdec/examples/translation/prepare-iwslt14.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh 4 | 5 | echo 'Cloning Moses github repository (for tokenization scripts)...' 6 | git clone https://github.com/moses-smt/mosesdecoder.git 7 | 8 | echo 'Cloning Subword NMT repository (for BPE pre-processing)...' 9 | git clone https://github.com/rsennrich/subword-nmt.git 10 | 11 | SCRIPTS=mosesdecoder/scripts 12 | TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl 13 | LC=$SCRIPTS/tokenizer/lowercase.perl 14 | CLEAN=$SCRIPTS/training/clean-corpus-n.perl 15 | BPEROOT=subword-nmt 16 | BPE_TOKENS=10000 17 | 18 | URL="https://wit3.fbk.eu/archive/2014-01/texts/de/en/de-en.tgz" 19 | GZ=de-en.tgz 20 | 21 | if [ ! -d "$SCRIPTS" ]; then 22 | echo "Please set SCRIPTS variable correctly to point to Moses scripts." 
23 | exit 24 | fi 25 | 26 | src=de 27 | tgt=en 28 | lang=de-en 29 | prep=iwslt14.tokenized.de-en 30 | tmp=$prep/tmp 31 | orig=orig 32 | 33 | mkdir -p $orig $tmp $prep 34 | 35 | echo "Downloading data from ${URL}..." 36 | cd $orig 37 | wget "$URL" 38 | 39 | if [ -f $GZ ]; then 40 | echo "Data successfully downloaded." 41 | else 42 | echo "Data not successfully downloaded." 43 | exit 44 | fi 45 | 46 | tar zxvf $GZ 47 | cd .. 48 | 49 | echo "pre-processing train data..." 50 | for l in $src $tgt; do 51 | f=train.tags.$lang.$l 52 | tok=train.tags.$lang.tok.$l 53 | 54 | cat $orig/$lang/$f | \ 55 | grep -v '<url>' | \ 56 | grep -v '<talkid>' | \ 57 | grep -v '<keywords>' | \ 58 | sed -e 's/<title>//g' | \ 59 | sed -e 's/<\/title>//g' | \ 60 | sed -e 's/<description>//g' | \ 61 | sed -e 's/<\/description>//g' | \ 62 | perl $TOKENIZER -threads 8 -l $l > $tmp/$tok 63 | echo "" 64 | done 65 | perl $CLEAN -ratio 1.5 $tmp/train.tags.$lang.tok $src $tgt $tmp/train.tags.$lang.clean 1 175 66 | for l in $src $tgt; do 67 | perl $LC < $tmp/train.tags.$lang.clean.$l > $tmp/train.tags.$lang.$l 68 | done 69 | 70 | echo "pre-processing valid/test data..." 71 | for l in $src $tgt; do 72 | for o in `ls $orig/$lang/IWSLT14.TED*.$l.xml`; do 73 | fname=${o##*/} 74 | f=$tmp/${fname%.*} 75 | echo $o $f 76 | grep '<seg id' $o | \ 77 | sed -e 's/<seg id="[0-9]*">\s*//g' | \ 78 | sed -e 's/\s*<\/seg>\s*//g' | \ 79 | sed -e "s/\’/\'/g" | \ 80 | perl $TOKENIZER -threads 8 -l $l | \ 81 | perl $LC > $f 82 | echo "" 83 | done 84 | done 85 | 86 | 87 | echo "creating train, valid, test..." 88 | for l in $src $tgt; do 89 | awk '{if (NR%23 == 0) print $0; }' $tmp/train.tags.de-en.$l > $tmp/valid.$l 90 | awk '{if (NR%23 != 0) print $0; }' $tmp/train.tags.de-en.$l > $tmp/train.$l 91 | 92 | cat $tmp/IWSLT14.TED.dev2010.de-en.$l \ 93 | $tmp/IWSLT14.TEDX.dev2012.de-en.$l \ 94 | $tmp/IWSLT14.TED.tst2010.de-en.$l \ 95 | $tmp/IWSLT14.TED.tst2011.de-en.$l \ 96 | $tmp/IWSLT14.TED.tst2012.de-en.$l \ 97 | > $tmp/test.$l 98 | done 99 | 100 | TRAIN=$tmp/train.en-de 101 | BPE_CODE=$prep/code 102 | rm -f $TRAIN 103 | for l in $src $tgt; do 104 | cat $tmp/train.$l >> $TRAIN 105 | done 106 | 107 | echo "learn_bpe.py on ${TRAIN}..." 108 | python $BPEROOT/learn_bpe.py -s $BPE_TOKENS < $TRAIN > $BPE_CODE 109 | 110 | for L in $src $tgt; do 111 | for f in train.$L valid.$L test.$L; do 112 | echo "apply_bpe.py to ${f}..." 113 | python $BPEROOT/apply_bpe.py -c $BPE_CODE < $tmp/$f > $prep/$f 114 | done 115 | done 116 | -------------------------------------------------------------------------------- /encdec/examples/translation/prepare-wmt14en2de.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh 3 | 4 | echo 'Cloning Moses github repository (for tokenization scripts)...' 5 | git clone https://github.com/moses-smt/mosesdecoder.git 6 | 7 | echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
8 | git clone https://github.com/rsennrich/subword-nmt.git 9 | 10 | SCRIPTS=mosesdecoder/scripts 11 | TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl 12 | CLEAN=$SCRIPTS/training/clean-corpus-n.perl 13 | NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl 14 | REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl 15 | BPEROOT=subword-nmt 16 | BPE_TOKENS=40000 17 | 18 | URLS=( 19 | "http://statmt.org/wmt13/training-parallel-europarl-v7.tgz" 20 | "http://statmt.org/wmt13/training-parallel-commoncrawl.tgz" 21 | "http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz" 22 | "http://data.statmt.org/wmt17/translation-task/dev.tgz" 23 | "http://statmt.org/wmt14/test-full.tgz" 24 | ) 25 | FILES=( 26 | "training-parallel-europarl-v7.tgz" 27 | "training-parallel-commoncrawl.tgz" 28 | "training-parallel-nc-v12.tgz" 29 | "dev.tgz" 30 | "test-full.tgz" 31 | ) 32 | CORPORA=( 33 | "training/europarl-v7.de-en" 34 | "commoncrawl.de-en" 35 | "training/news-commentary-v12.de-en" 36 | ) 37 | 38 | # This will make the dataset compatible to the one used in "Convolutional Sequence to Sequence Learning" 39 | # https://arxiv.org/abs/1705.03122 40 | if [ "$1" == "--icml17" ]; then 41 | URLS[2]="http://statmt.org/wmt14/training-parallel-nc-v9.tgz" 42 | FILES[2]="training-parallel-nc-v9.tgz" 43 | CORPORA[2]="training/news-commentary-v9.de-en" 44 | fi 45 | 46 | if [ ! -d "$SCRIPTS" ]; then 47 | echo "Please set SCRIPTS variable correctly to point to Moses scripts." 48 | exit 49 | fi 50 | 51 | src=en 52 | tgt=de 53 | lang=en-de 54 | prep=wmt14_en_de 55 | tmp=$prep/tmp 56 | orig=orig 57 | dev=dev/newstest2013 58 | 59 | mkdir -p $orig $tmp $prep 60 | 61 | cd $orig 62 | 63 | for ((i=0;i<${#URLS[@]};++i)); do 64 | file=${FILES[i]} 65 | if [ -f $file ]; then 66 | echo "$file already exists, skipping download" 67 | else 68 | url=${URLS[i]} 69 | wget "$url" 70 | if [ -f $file ]; then 71 | echo "$url successfully downloaded." 72 | else 73 | echo "$url not successfully downloaded." 74 | exit -1 75 | fi 76 | if [ ${file: -4} == ".tgz" ]; then 77 | tar zxvf $file 78 | elif [ ${file: -4} == ".tar" ]; then 79 | tar xvf $file 80 | fi 81 | fi 82 | done 83 | cd .. 84 | 85 | echo "pre-processing train data..." 86 | for l in $src $tgt; do 87 | rm $tmp/train.tags.$lang.tok.$l 88 | for f in "${CORPORA[@]}"; do 89 | cat $orig/$f.$l | \ 90 | perl $NORM_PUNC $l | \ 91 | perl $REM_NON_PRINT_CHAR | \ 92 | perl $TOKENIZER -threads 8 -a -l $l >> $tmp/train.tags.$lang.tok.$l 93 | done 94 | done 95 | 96 | echo "pre-processing test data..." 97 | for l in $src $tgt; do 98 | if [ "$l" == "$src" ]; then 99 | t="src" 100 | else 101 | t="ref" 102 | fi 103 | grep '<seg id' $orig/test-full/newstest2014-deen-$t.$l.sgm | \ 104 | sed -e 's/<seg id="[0-9]*">\s*//g' | \ 105 | sed -e 's/\s*<\/seg>\s*//g' | \ 106 | sed -e "s/\’/\'/g" | \ 107 | perl $TOKENIZER -threads 8 -a -l $l > $tmp/test.$l 108 | echo "" 109 | done 110 | 111 | echo "splitting train and valid..." 112 | for l in $src $tgt; do 113 | awk '{if (NR%100 == 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/valid.$l 114 | awk '{if (NR%100 != 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/train.$l 115 | done 116 | 117 | TRAIN=$tmp/train.de-en 118 | BPE_CODE=$prep/code 119 | rm -f $TRAIN 120 | for l in $src $tgt; do 121 | cat $tmp/train.$l >> $TRAIN 122 | done 123 | 124 | echo "learn_bpe.py on ${TRAIN}..." 
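# learn_bpe.py learns BPE_TOKENS (40000) merge operations on the concatenated
# English and German training text, producing a single joint BPE code shared by
# the source and target sides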
125 | python $BPEROOT/learn_bpe.py -s $BPE_TOKENS < $TRAIN > $BPE_CODE 126 | 127 | for L in $src $tgt; do 128 | for f in train.$L valid.$L test.$L; do 129 | echo "apply_bpe.py to ${f}..." 130 | python $BPEROOT/apply_bpe.py -c $BPE_CODE < $tmp/$f > $tmp/bpe.$f 131 | done 132 | done 133 | 134 | perl $CLEAN -ratio 1.5 $tmp/bpe.train $src $tgt $prep/train 1 250 135 | perl $CLEAN -ratio 1.5 $tmp/bpe.valid $src $tgt $prep/valid 1 250 136 | 137 | for L in $src $tgt; do 138 | cp $tmp/bpe.test.$L $prep/test.$L 139 | done 140 | -------------------------------------------------------------------------------- /encdec/examples/translation/prepare-wmt14en2fr.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh 3 | 4 | echo 'Cloning Moses github repository (for tokenization scripts)...' 5 | git clone https://github.com/moses-smt/mosesdecoder.git 6 | 7 | echo 'Cloning Subword NMT repository (for BPE pre-processing)...' 8 | git clone https://github.com/rsennrich/subword-nmt.git 9 | 10 | SCRIPTS=mosesdecoder/scripts 11 | TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl 12 | CLEAN=$SCRIPTS/training/clean-corpus-n.perl 13 | NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl 14 | REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl 15 | BPEROOT=subword-nmt 16 | BPE_TOKENS=40000 17 | 18 | URLS=( 19 | "http://statmt.org/wmt13/training-parallel-europarl-v7.tgz" 20 | "http://statmt.org/wmt13/training-parallel-commoncrawl.tgz" 21 | "http://statmt.org/wmt13/training-parallel-un.tgz" 22 | "http://statmt.org/wmt14/training-parallel-nc-v9.tgz" 23 | "http://statmt.org/wmt10/training-giga-fren.tar" 24 | "http://statmt.org/wmt14/test-full.tgz" 25 | ) 26 | FILES=( 27 | "training-parallel-europarl-v7.tgz" 28 | "training-parallel-commoncrawl.tgz" 29 | "training-parallel-un.tgz" 30 | "training-parallel-nc-v9.tgz" 31 | "training-giga-fren.tar" 32 | "test-full.tgz" 33 | ) 34 | CORPORA=( 35 | "training/europarl-v7.fr-en" 36 | "commoncrawl.fr-en" 37 | "un/undoc.2000.fr-en" 38 | "training/news-commentary-v9.fr-en" 39 | "giga-fren.release2.fixed" 40 | ) 41 | 42 | if [ ! -d "$SCRIPTS" ]; then 43 | echo "Please set SCRIPTS variable correctly to point to Moses scripts." 44 | exit 45 | fi 46 | 47 | src=en 48 | tgt=fr 49 | lang=en-fr 50 | prep=wmt14_en_fr 51 | tmp=$prep/tmp 52 | orig=orig 53 | 54 | mkdir -p $orig $tmp $prep 55 | 56 | cd $orig 57 | 58 | for ((i=0;i<${#URLS[@]};++i)); do 59 | file=${FILES[i]} 60 | if [ -f $file ]; then 61 | echo "$file already exists, skipping download" 62 | else 63 | url=${URLS[i]} 64 | wget "$url" 65 | if [ -f $file ]; then 66 | echo "$url successfully downloaded." 67 | else 68 | echo "$url not successfully downloaded." 69 | exit -1 70 | fi 71 | if [ ${file: -4} == ".tgz" ]; then 72 | tar zxvf $file 73 | elif [ ${file: -4} == ".tar" ]; then 74 | tar xvf $file 75 | fi 76 | fi 77 | done 78 | 79 | gunzip giga-fren.release2.fixed.*.gz 80 | cd .. 81 | 82 | echo "pre-processing train data..." 83 | for l in $src $tgt; do 84 | rm $tmp/train.tags.$lang.tok.$l 85 | for f in "${CORPORA[@]}"; do 86 | cat $orig/$f.$l | \ 87 | perl $NORM_PUNC $l | \ 88 | perl $REM_NON_PRINT_CHAR | \ 89 | perl $TOKENIZER -threads 8 -a -l $l >> $tmp/train.tags.$lang.tok.$l 90 | done 91 | done 92 | 93 | echo "pre-processing test data..." 
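# the newstest2014 test sets are distributed as SGML; the pipeline below keeps
# only the '<seg id' lines, strips the markup, normalizes apostrophes, and
# tokenizes with the Moses tokenizer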
94 | for l in $src $tgt; do 95 | if [ "$l" == "$src" ]; then 96 | t="src" 97 | else 98 | t="ref" 99 | fi 100 | grep '<seg id' $orig/test-full/newstest2014-fren-$t.$l.sgm | \ 101 | sed -e 's/<seg id="[0-9]*">\s*//g' | \ 102 | sed -e 's/\s*<\/seg>\s*//g' | \ 103 | sed -e "s/\’/\'/g" | \ 104 | perl $TOKENIZER -threads 8 -a -l $l > $tmp/test.$l 105 | echo "" 106 | done 107 | 108 | echo "splitting train and valid..." 109 | for l in $src $tgt; do 110 | awk '{if (NR%1333 == 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/valid.$l 111 | awk '{if (NR%1333 != 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/train.$l 112 | done 113 | 114 | TRAIN=$tmp/train.fr-en 115 | BPE_CODE=$prep/code 116 | rm -f $TRAIN 117 | for l in $src $tgt; do 118 | cat $tmp/train.$l >> $TRAIN 119 | done 120 | 121 | echo "learn_bpe.py on ${TRAIN}..." 122 | python $BPEROOT/learn_bpe.py -s $BPE_TOKENS < $TRAIN > $BPE_CODE 123 | 124 | for L in $src $tgt; do 125 | for f in train.$L valid.$L test.$L; do 126 | echo "apply_bpe.py to ${f}..." 127 | python $BPEROOT/apply_bpe.py -c $BPE_CODE < $tmp/$f > $tmp/bpe.$f 128 | done 129 | done 130 | 131 | perl $CLEAN -ratio 1.5 $tmp/bpe.train $src $tgt $prep/train 1 250 132 | perl $CLEAN -ratio 1.5 $tmp/bpe.valid $src $tgt $prep/valid 1 250 133 | 134 | for L in $src $tgt; do 135 | cp $tmp/bpe.test.$L $prep/test.$L 136 | done 137 | -------------------------------------------------------------------------------- /encdec/fairseq/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from .multiprocessing_pdb import pdb 9 | 10 | __all__ = ['pdb'] 11 | 12 | import fairseq.criterions 13 | import fairseq.models 14 | import fairseq.modules 15 | import fairseq.optim 16 | import fairseq.optim.lr_scheduler 17 | import fairseq.tasks 18 | -------------------------------------------------------------------------------- /encdec/fairseq/bleu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import ctypes 9 | import math 10 | import torch 11 | 12 | try: 13 | from fairseq import libbleu 14 | except ImportError as e: 15 | import sys 16 | sys.stderr.write('ERROR: missing libbleu.so. 
run `python setup.py install`\n') 17 | raise e 18 | 19 | 20 | C = ctypes.cdll.LoadLibrary(libbleu.__file__) 21 | 22 | 23 | class BleuStat(ctypes.Structure): 24 | _fields_ = [ 25 | ('reflen', ctypes.c_size_t), 26 | ('predlen', ctypes.c_size_t), 27 | ('match1', ctypes.c_size_t), 28 | ('count1', ctypes.c_size_t), 29 | ('match2', ctypes.c_size_t), 30 | ('count2', ctypes.c_size_t), 31 | ('match3', ctypes.c_size_t), 32 | ('count3', ctypes.c_size_t), 33 | ('match4', ctypes.c_size_t), 34 | ('count4', ctypes.c_size_t), 35 | ] 36 | 37 | 38 | class Scorer(object): 39 | def __init__(self, pad, eos, unk): 40 | self.stat = BleuStat() 41 | self.pad = pad 42 | self.eos = eos 43 | self.unk = unk 44 | self.reset() 45 | 46 | def reset(self, one_init=False): 47 | if one_init: 48 | C.bleu_one_init(ctypes.byref(self.stat)) 49 | else: 50 | C.bleu_zero_init(ctypes.byref(self.stat)) 51 | 52 | def add(self, ref, pred): 53 | if not isinstance(ref, torch.IntTensor): 54 | raise TypeError('ref must be a torch.IntTensor (got {})' 55 | .format(type(ref))) 56 | if not isinstance(pred, torch.IntTensor): 57 | raise TypeError('pred must be a torch.IntTensor(got {})' 58 | .format(type(pred))) 59 | 60 | # don't match unknown words 61 | rref = ref.clone() 62 | assert not rref.lt(0).any() 63 | rref[rref.eq(self.unk)] = -999 64 | 65 | rref = rref.contiguous().view(-1) 66 | pred = pred.contiguous().view(-1) 67 | 68 | C.bleu_add( 69 | ctypes.byref(self.stat), 70 | ctypes.c_size_t(rref.size(0)), 71 | ctypes.c_void_p(rref.data_ptr()), 72 | ctypes.c_size_t(pred.size(0)), 73 | ctypes.c_void_p(pred.data_ptr()), 74 | ctypes.c_int(self.pad), 75 | ctypes.c_int(self.eos)) 76 | 77 | def score(self, order=4): 78 | psum = sum(math.log(p) if p > 0 else float('-Inf') 79 | for p in self.precision()[:order]) 80 | return self.brevity() * math.exp(psum / order) * 100 81 | 82 | def precision(self): 83 | def ratio(a, b): 84 | return a / b if b > 0 else 0 85 | 86 | return [ 87 | ratio(self.stat.match1, self.stat.count1), 88 | ratio(self.stat.match2, self.stat.count2), 89 | ratio(self.stat.match3, self.stat.count3), 90 | ratio(self.stat.match4, self.stat.count4), 91 | ] 92 | 93 | def brevity(self): 94 | r = self.stat.reflen / self.stat.predlen 95 | return min(1, math.exp(1 - r)) 96 | 97 | def result_string(self, order=4): 98 | assert order <= 4, "BLEU scores for order > 4 aren't supported" 99 | fmt = 'BLEU{} = {:2.2f}, {:2.1f}' 100 | for _ in range(1, order): 101 | fmt += '/{:2.1f}' 102 | fmt += ' (BP={:.3f}, ratio={:.3f}, syslen={}, reflen={})' 103 | bleup = [p * 100 for p in self.precision()[:order]] 104 | return fmt.format(order, self.score(order=order), *bleup, 105 | self.brevity(), self.stat.predlen/self.stat.reflen, 106 | self.stat.predlen, self.stat.reflen) 107 | -------------------------------------------------------------------------------- /encdec/fairseq/clib/libbleu/libbleu.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2017-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include <map> 10 | #include <array> 11 | #include <cstring> 12 | #include <cstdio> 13 | 14 | typedef struct 15 | { 16 | size_t reflen; 17 | size_t predlen; 18 | size_t match1; 19 | size_t count1; 20 | size_t match2; 21 | size_t count2; 22 | size_t match3; 23 | size_t count3; 24 | size_t match4; 25 | size_t count4; 26 | } bleu_stat; 27 | 28 | // left trim (remove pad) 29 | void bleu_ltrim(size_t* len, int** sent, int pad) { 30 | size_t start = 0; 31 | while(start < *len) { 32 | if (*(*sent + start) != pad) { break; } 33 | start++; 34 | } 35 | *sent += start; 36 | *len -= start; 37 | } 38 | 39 | // right trim remove (eos) 40 | void bleu_rtrim(size_t* len, int** sent, int pad, int eos) { 41 | size_t end = *len - 1; 42 | while (end > 0) { 43 | if (*(*sent + end) != eos && *(*sent + end) != pad) { break; } 44 | end--; 45 | } 46 | *len = end + 1; 47 | } 48 | 49 | // left and right trim 50 | void bleu_trim(size_t* len, int** sent, int pad, int eos) { 51 | bleu_ltrim(len, sent, pad); 52 | bleu_rtrim(len, sent, pad, eos); 53 | } 54 | 55 | size_t bleu_hash(int len, int* data) { 56 | size_t h = 14695981039346656037ul; 57 | size_t prime = 0x100000001b3; 58 | char* b = (char*) data; 59 | size_t blen = sizeof(int) * len; 60 | 61 | while (blen-- > 0) { 62 | h ^= *b++; 63 | h *= prime; 64 | } 65 | 66 | return h; 67 | } 68 | 69 | void bleu_addngram( 70 | size_t *ntotal, size_t *nmatch, size_t n, 71 | size_t reflen, int* ref, size_t predlen, int* pred) { 72 | 73 | if (predlen < n) { return; } 74 | 75 | predlen = predlen - n + 1; 76 | (*ntotal) += predlen; 77 | 78 | if (reflen < n) { return; } 79 | 80 | reflen = reflen - n + 1; 81 | 82 | std::map<size_t, size_t> count; 83 | while (predlen > 0) { 84 | size_t w = bleu_hash(n, pred++); 85 | count[w]++; 86 | predlen--; 87 | } 88 | 89 | while (reflen > 0) { 90 | size_t w = bleu_hash(n, ref++); 91 | if (count[w] > 0) { 92 | (*nmatch)++; 93 | count[w] -=1; 94 | } 95 | reflen--; 96 | } 97 | } 98 | 99 | extern "C" { 100 | 101 | void bleu_zero_init(bleu_stat* stat) { 102 | std::memset(stat, 0, sizeof(bleu_stat)); 103 | } 104 | 105 | void bleu_one_init(bleu_stat* stat) { 106 | bleu_zero_init(stat); 107 | stat->count1 = 0; 108 | stat->count2 = 1; 109 | stat->count3 = 1; 110 | stat->count4 = 1; 111 | stat->match1 = 0; 112 | stat->match2 = 1; 113 | stat->match3 = 1; 114 | stat->match4 = 1; 115 | } 116 | 117 | void bleu_add( 118 | bleu_stat* stat, 119 | size_t reflen, int* ref, size_t predlen, int* pred, int pad, int eos) { 120 | 121 | bleu_trim(&reflen, &ref, pad, eos); 122 | bleu_trim(&predlen, &pred, pad, eos); 123 | stat->reflen += reflen; 124 | stat->predlen += predlen; 125 | 126 | bleu_addngram(&stat->count1, &stat->match1, 1, reflen, ref, predlen, pred); 127 | bleu_addngram(&stat->count2, &stat->match2, 2, reflen, ref, predlen, pred); 128 | bleu_addngram(&stat->count3, &stat->match3, 3, reflen, ref, predlen, pred); 129 | bleu_addngram(&stat->count4, &stat->match4, 4, reflen, ref, predlen, pred); 130 | } 131 | 132 | } 133 | -------------------------------------------------------------------------------- /encdec/fairseq/clib/libbleu/module.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2017-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include <Python.h> 10 | 11 | 12 | static PyMethodDef method_def[] = { 13 | {NULL, NULL, 0, NULL} 14 | }; 15 | 16 | static struct PyModuleDef module_def = { 17 | PyModuleDef_HEAD_INIT, 18 | "libbleu", /* name of module */ 19 | NULL, /* module documentation, may be NULL */ 20 | -1, /* size of per-interpreter state of the module, 21 | or -1 if the module keeps state in global variables. */ 22 | method_def 23 | }; 24 | 25 | 26 | #if PY_MAJOR_VERSION == 2 27 | PyMODINIT_FUNC init_libbleu() 28 | #else 29 | PyMODINIT_FUNC PyInit_libbleu() 30 | #endif 31 | { 32 | PyObject *m = PyModule_Create(&module_def); 33 | if (!m) { 34 | return NULL; 35 | } 36 | return m; 37 | } 38 | -------------------------------------------------------------------------------- /encdec/fairseq/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import importlib 9 | import os 10 | 11 | from .fairseq_criterion import FairseqCriterion 12 | 13 | 14 | CRITERION_REGISTRY = {} 15 | CRITERION_CLASS_NAMES = set() 16 | 17 | 18 | def build_criterion(args, task): 19 | return CRITERION_REGISTRY[args.criterion](args, task) 20 | 21 | 22 | def register_criterion(name): 23 | """Decorator to register a new criterion.""" 24 | 25 | def register_criterion_cls(cls): 26 | if name in CRITERION_REGISTRY: 27 | raise ValueError('Cannot register duplicate criterion ({})'.format(name)) 28 | if not issubclass(cls, FairseqCriterion): 29 | raise ValueError('Criterion ({}: {}) must extend FairseqCriterion'.format(name, cls.__name__)) 30 | if cls.__name__ in CRITERION_CLASS_NAMES: 31 | # We use the criterion class name as a unique identifier in 32 | # checkpoints, so all criterions must have unique class names. 33 | raise ValueError('Cannot register criterion with duplicate class name ({})'.format(cls.__name__)) 34 | CRITERION_REGISTRY[name] = cls 35 | CRITERION_CLASS_NAMES.add(cls.__name__) 36 | return cls 37 | 38 | return register_criterion_cls 39 | 40 | 41 | # automatically import any Python files in the criterions/ directory 42 | for file in os.listdir(os.path.dirname(__file__)): 43 | if file.endswith('.py') and not file.startswith('_'): 44 | module = file[:file.find('.py')] 45 | importlib.import_module('fairseq.criterions.' + module) 46 | -------------------------------------------------------------------------------- /encdec/fairseq/criterions/adaptive_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | 9 | import math 10 | import torch.nn.functional as F 11 | 12 | from fairseq import utils 13 | from . 
import FairseqCriterion, register_criterion 14 | 15 | 16 | @register_criterion('adaptive_loss') 17 | class AdaptiveLoss(FairseqCriterion): 18 | """This is an implementation of the loss function accompanying the adaptive softmax approximation for 19 | graphical processing units (GPU), described in the paper "Efficient softmax approximation for GPUs" 20 | (http://arxiv.org/abs/1609.04309).""" 21 | 22 | def __init__(self, args, task): 23 | super().__init__(args, task) 24 | 25 | if args.ddp_backend == 'c10d': 26 | raise Exception( 27 | 'AdaptiveLoss is not compatible with the c10d ' 28 | 'version of DistributedDataParallel. Please use ' 29 | '`--ddp-backend=no_c10d` instead.' 30 | ) 31 | 32 | def forward(self, model, sample, reduce=True): 33 | """Compute the loss for the given sample. 34 | 35 | Returns a tuple with three elements: 36 | 1) the loss 37 | 2) the sample size, which is used as the denominator for the gradient 38 | 3) logging outputs to display while training 39 | """ 40 | 41 | assert hasattr(model.decoder, 'adaptive_softmax') and model.decoder.adaptive_softmax is not None 42 | adaptive_softmax = model.decoder.adaptive_softmax 43 | 44 | net_output = model(**sample['net_input']) 45 | orig_target = model.get_targets(sample, net_output) 46 | 47 | nsentences = orig_target.size(0) 48 | orig_target = orig_target.view(-1) 49 | 50 | bsz = orig_target.size(0) 51 | 52 | logits, target = adaptive_softmax(net_output[0], orig_target) 53 | assert len(target) == len(logits) 54 | 55 | loss = net_output[0].new(1 if reduce else bsz).zero_() 56 | 57 | for i in range(len(target)): 58 | if target[i] is not None: 59 | assert (target[i].min() >= 0 and target[i].max() <= logits[i].size(1)) 60 | loss += F.cross_entropy(logits[i], target[i], size_average=False, ignore_index=self.padding_idx, 61 | reduce=reduce) 62 | 63 | orig = utils.strip_pad(orig_target, self.padding_idx) 64 | ntokens = orig.numel() 65 | sample_size = sample['target'].size(0) if self.args.sentence_avg else ntokens 66 | logging_output = { 67 | 'loss': utils.item(loss.data) if reduce else loss.data, 68 | 'ntokens': ntokens, 69 | 'nsentences': nsentences, 70 | 'sample_size': sample_size, 71 | } 72 | return loss, sample_size, logging_output 73 | 74 | @staticmethod 75 | def aggregate_logging_outputs(logging_outputs): 76 | """Aggregate logging outputs from data parallel training.""" 77 | loss_sum = sum(log.get('loss', 0) for log in logging_outputs) 78 | ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) 79 | nsentences = sum(log.get('nsentences', 0) for log in logging_outputs) 80 | sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) 81 | agg_output = { 82 | 'loss': loss_sum / sample_size / math.log(2), 83 | 'nll_loss': loss_sum / sample_size / math.log(2), 84 | 'ntokens': ntokens, 85 | 'nsentences': nsentences, 86 | 'sample_size': sample_size, 87 | } 88 | if sample_size != ntokens: 89 | agg_output['nll_loss'] = loss_sum / ntokens / math.log(2) 90 | return agg_output 91 | -------------------------------------------------------------------------------- /encdec/fairseq/criterions/cross_entropy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
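The division by math.log(2) in aggregate_logging_outputs above (and repeated in the criterions that follow) means the logged nll_loss is expressed in bits per token, so perplexity is recovered as 2 ** nll_loss. A small sanity check with made-up numbers, not taken from any real run:

import math

loss_sum, ntokens = 460.52, 100                # hypothetical summed NLL (in nats) and token count
nll_bits = loss_sum / ntokens / math.log(2)    # ~= 6.644, the value that would be logged as nll_loss
perplexity = 2 ** nll_bits                     # ~= 100.0, same as math.exp(loss_sum / ntokens)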
7 | 8 | import math 9 | import torch.nn.functional as F 10 | 11 | from fairseq import utils 12 | 13 | from . import FairseqCriterion, register_criterion 14 | 15 | 16 | @register_criterion('cross_entropy') 17 | class CrossEntropyCriterion(FairseqCriterion): 18 | 19 | def __init__(self, args, task): 20 | super().__init__(args, task) 21 | 22 | def forward(self, model, sample, reduce=True): 23 | """Compute the loss for the given sample. 24 | 25 | Returns a tuple with three elements: 26 | 1) the loss 27 | 2) the sample size, which is used as the denominator for the gradient 28 | 3) logging outputs to display while training 29 | """ 30 | net_output = model(**sample['net_input']) 31 | lprobs = model.get_normalized_probs(net_output, log_probs=True) 32 | lprobs = lprobs.view(-1, lprobs.size(-1)) 33 | target = model.get_targets(sample, net_output).view(-1) 34 | loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx, 35 | reduce=reduce) 36 | sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] 37 | logging_output = { 38 | 'loss': utils.item(loss.data) if reduce else loss.data, 39 | 'ntokens': sample['ntokens'], 40 | 'nsentences': sample['target'].size(0), 41 | 'sample_size': sample_size, 42 | } 43 | return loss, sample_size, logging_output 44 | 45 | @staticmethod 46 | def aggregate_logging_outputs(logging_outputs): 47 | """Aggregate logging outputs from data parallel training.""" 48 | loss_sum = sum(log.get('loss', 0) for log in logging_outputs) 49 | ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) 50 | nsentences = sum(log.get('nsentences', 0) for log in logging_outputs) 51 | sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) 52 | agg_output = { 53 | 'loss': loss_sum / sample_size / math.log(2), 54 | 'ntokens': ntokens, 55 | 'nsentences': nsentences, 56 | 'sample_size': sample_size, 57 | } 58 | if sample_size != ntokens: 59 | agg_output['nll_loss'] = loss_sum / ntokens / math.log(2) 60 | return agg_output 61 | -------------------------------------------------------------------------------- /encdec/fairseq/criterions/fairseq_criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from torch.nn.modules.loss import _Loss 9 | 10 | 11 | class FairseqCriterion(_Loss): 12 | 13 | def __init__(self, args, task): 14 | super().__init__() 15 | self.args = args 16 | self.padding_idx = task.target_dictionary.pad() 17 | 18 | @staticmethod 19 | def add_args(parser): 20 | """Add criterion-specific arguments to the parser.""" 21 | pass 22 | 23 | def forward(self, model, sample, reduce=True): 24 | """Compute the loss for the given sample. 
25 | 26 | Returns a tuple with three elements: 27 | 1) the loss 28 | 2) the sample size, which is used as the denominator for the gradient 29 | 3) logging outputs to display while training 30 | """ 31 | raise NotImplementedError 32 | 33 | @staticmethod 34 | def aggregate_logging_outputs(logging_outputs): 35 | """Aggregate logging outputs from data parallel training.""" 36 | raise NotImplementedError 37 | 38 | @staticmethod 39 | def grad_denom(sample_sizes): 40 | """Compute the gradient denominator for a set of sample sizes.""" 41 | return sum(sample_sizes) 42 | -------------------------------------------------------------------------------- /encdec/fairseq/criterions/label_smoothed_cross_entropy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import math 9 | 10 | from fairseq import utils 11 | 12 | from . import FairseqCriterion, register_criterion 13 | 14 | 15 | @register_criterion('label_smoothed_cross_entropy') 16 | class LabelSmoothedCrossEntropyCriterion(FairseqCriterion): 17 | 18 | def __init__(self, args, task): 19 | super().__init__(args, task) 20 | self.eps = args.label_smoothing 21 | 22 | @staticmethod 23 | def add_args(parser): 24 | """Add criterion-specific arguments to the parser.""" 25 | parser.add_argument('--label-smoothing', default=0., type=float, metavar='D', 26 | help='epsilon for label smoothing, 0 means no label smoothing') 27 | 28 | def forward(self, model, sample, reduce=True): 29 | """Compute the loss for the given sample. 30 | 31 | Returns a tuple with three elements: 32 | 1) the loss 33 | 2) the sample size, which is used as the denominator for the gradient 34 | 3) logging outputs to display while training 35 | """ 36 | net_output = model(**sample['net_input']) 37 | loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce) 38 | sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] 39 | logging_output = { 40 | 'loss': utils.item(loss.data) if reduce else loss.data, 41 | 'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data, 42 | 'ntokens': sample['ntokens'], 43 | 'nsentences': sample['target'].size(0), 44 | 'sample_size': sample_size, 45 | } 46 | return loss, sample_size, logging_output 47 | 48 | def compute_loss(self, model, net_output, sample, reduce=True): 49 | lprobs = model.get_normalized_probs(net_output, log_probs=True) 50 | lprobs = lprobs.view(-1, lprobs.size(-1)) 51 | target = model.get_targets(sample, net_output).view(-1, 1) 52 | non_pad_mask = target.ne(self.padding_idx) 53 | nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask] 54 | smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask] 55 | if reduce: 56 | nll_loss = nll_loss.sum() 57 | smooth_loss = smooth_loss.sum() 58 | eps_i = self.eps / lprobs.size(-1) 59 | loss = (1. 
- self.eps) * nll_loss + eps_i * smooth_loss 60 | return loss, nll_loss 61 | 62 | @staticmethod 63 | def aggregate_logging_outputs(logging_outputs): 64 | """Aggregate logging outputs from data parallel training.""" 65 | ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) 66 | nsentences = sum(log.get('nsentences', 0) for log in logging_outputs) 67 | sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) 68 | return { 69 | 'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2), 70 | 'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / ntokens / math.log(2), 71 | 'ntokens': ntokens, 72 | 'nsentences': nsentences, 73 | 'sample_size': sample_size, 74 | } 75 | -------------------------------------------------------------------------------- /encdec/fairseq/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from .dictionary import Dictionary, TruncatedDictionary 9 | from .fairseq_dataset import FairseqDataset 10 | from .concat_dataset import ConcatDataset 11 | from .indexed_dataset import IndexedDataset, IndexedCachedDataset, IndexedInMemoryDataset, IndexedRawTextDataset 12 | from .append_eos_dataset import AppendEosDataset 13 | from .language_pair_dataset import LanguagePairDataset 14 | from .monolingual_dataset import MonolingualDataset 15 | from .token_block_dataset import TokenBlockDataset 16 | 17 | from .iterators import ( 18 | CountingIterator, 19 | EpochBatchIterator, 20 | GroupedIterator, 21 | ShardedIterator, 22 | ) 23 | 24 | __all__ = [ 25 | 'AppendEosDataset', 26 | 'ConcatDataset', 27 | 'CountingIterator', 28 | 'Dictionary', 29 | 'EpochBatchIterator', 30 | 'FairseqDataset', 31 | 'GroupedIterator', 32 | 'IndexedCachedDataset', 33 | 'IndexedDataset', 34 | 'IndexedInMemoryDataset', 35 | 'IndexedRawTextDataset', 36 | 'LanguagePairDataset', 37 | 'MonolingualDataset', 38 | 'ShardedIterator', 39 | 'TokenBlockDataset', 40 | ] 41 | -------------------------------------------------------------------------------- /encdec/fairseq/data/append_eos_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | 10 | 11 | class AppendEosDataset(torch.utils.data.Dataset): 12 | """A dataset wrapper that appends EOS to each item.""" 13 | 14 | def __init__(self, dataset, eos): 15 | self.dataset = dataset 16 | self.eos = eos 17 | 18 | def __getitem__(self, index): 19 | item = torch.cat([self.dataset[index], torch.LongTensor([self.eos])]) 20 | print(item) 21 | return item 22 | 23 | def __len__(self): 24 | return len(self.dataset) 25 | -------------------------------------------------------------------------------- /encdec/fairseq/data/concat_dataset.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | 3 | from . 
import FairseqDataset 4 | 5 | 6 | class ConcatDataset(FairseqDataset): 7 | 8 | @staticmethod 9 | def cumsum(sequence): 10 | r, s = [], 0 11 | for e in sequence: 12 | l = len(e) 13 | r.append(l + s) 14 | s += l 15 | return r 16 | 17 | def __init__(self, datasets): 18 | super(ConcatDataset, self).__init__() 19 | assert len(datasets) > 0, 'datasets should not be an empty iterable' 20 | self.datasets = list(datasets) 21 | self.cummulative_sizes = self.cumsum(self.datasets) 22 | 23 | def __len__(self): 24 | return self.cummulative_sizes[-1] 25 | 26 | def __getitem__(self, idx): 27 | dataset_idx = bisect.bisect_right(self.cummulative_sizes, idx) 28 | if dataset_idx == 0: 29 | sample_idx = idx 30 | else: 31 | sample_idx = idx - self.cummulative_sizes[dataset_idx - 1] 32 | return self.datasets[dataset_idx][sample_idx] 33 | 34 | @property 35 | def supports_prefetch(self): 36 | return all([d.supports_prefetch for d in self.datasets]) 37 | 38 | def prefetch(self, indices): 39 | frm = 0 40 | for to, ds in zip(self.cummulative_sizes, self.datasets): 41 | ds.prefetch([i - frm for i in indices if frm <= i < to]) 42 | frm = to 43 | -------------------------------------------------------------------------------- /encdec/fairseq/data/fairseq_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.utils.data 9 | 10 | from fairseq.data import data_utils 11 | 12 | 13 | class FairseqDataset(torch.utils.data.Dataset): 14 | """A dataset that provides helpers for batching.""" 15 | 16 | def __getitem__(self, index): 17 | raise NotImplementedError 18 | 19 | def __len__(self): 20 | raise NotImplementedError 21 | 22 | def collater(self, samples): 23 | """Merge a list of samples to form a mini-batch. 24 | 25 | Args: 26 | samples (List[int]): sample indices to collate 27 | 28 | Returns: 29 | dict: a mini-batch suitable for forwarding with a Model 30 | """ 31 | raise NotImplementedError 32 | 33 | def get_dummy_batch(self, num_tokens, max_positions): 34 | """Return a dummy batch with a given number of tokens.""" 35 | raise NotImplementedError 36 | 37 | def num_tokens(self, index): 38 | """Return the number of tokens in a sample. This value is used to 39 | enforce ``--max-tokens`` during batching.""" 40 | raise NotImplementedError 41 | 42 | def size(self, index): 43 | """Return an example's size as a float or tuple. This value is used when 44 | filtering a dataset with ``--max-positions``.""" 45 | raise NotImplementedError 46 | 47 | def ordered_indices(self): 48 | """Return an ordered list of indices. Batches will be constructed based 49 | on this order.""" 50 | raise NotImplementedError 51 | 52 | @property 53 | def supports_prefetch(self): 54 | return False 55 | 56 | def prefetch(self, indices): 57 | raise NotImplementedError 58 | -------------------------------------------------------------------------------- /encdec/fairseq/data/token_block_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. 
An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import math 9 | 10 | import numpy as np 11 | import torch 12 | 13 | 14 | class TokenBlockDataset(torch.utils.data.Dataset): 15 | """Break a 1d tensor of tokens into blocks. 16 | 17 | The blocks are fetched from the original tensor so no additional memory is allocated. 18 | 19 | Args: 20 | tokens: 1d tensor of tokens to break into blocks 21 | sizes: sentence lengths (required for 'complete' and 'eos') 22 | block_size: maximum block size (ignored in 'eos' break mode) 23 | break_mode: Mode used for breaking tokens. Values can be one of: 24 | - 'none': break tokens into equally sized blocks (up to block_size) 25 | - 'complete': break tokens into blocks (up to block_size) such that 26 | blocks contains complete sentences, although block_size may be 27 | exceeded if some sentences exceed block_size 28 | - 'eos': each block contains one sentence (block_size is ignored) 29 | include_targets: return next tokens as targets 30 | """ 31 | 32 | def __init__(self, tokens, sizes, block_size, pad, eos, break_mode=None, include_targets=False): 33 | super().__init__() 34 | 35 | self.tokens = tokens 36 | self.total_size = len(tokens) 37 | self.pad = pad 38 | self.eos = eos 39 | self.include_targets = include_targets 40 | self.slice_indices = [] 41 | 42 | if break_mode is None or break_mode == 'none': 43 | length = math.ceil(len(tokens) / block_size) 44 | 45 | def block_at(i): 46 | start = i * block_size 47 | end = min(start + block_size, len(tokens)) 48 | return (start, end) 49 | 50 | self.slice_indices = [block_at(i) for i in range(length)] 51 | elif break_mode == 'complete': 52 | assert sizes is not None and sum(sizes) == len(tokens), '{} != {}'.format(sum(sizes), len(tokens)) 53 | tok_idx = 0 54 | sz_idx = 0 55 | curr_size = 0 56 | while sz_idx < len(sizes): 57 | if curr_size + sizes[sz_idx] <= block_size or curr_size == 0: 58 | curr_size += sizes[sz_idx] 59 | sz_idx += 1 60 | else: 61 | self.slice_indices.append((tok_idx, tok_idx + curr_size)) 62 | tok_idx += curr_size 63 | curr_size = 0 64 | if curr_size > 0: 65 | self.slice_indices.append((tok_idx, tok_idx + curr_size)) 66 | elif break_mode == 'eos': 67 | assert sizes is not None and sum(sizes) == len(tokens), '{} != {}'.format(sum(sizes), len(tokens)) 68 | curr = 0 69 | for sz in sizes: 70 | # skip samples with just 1 example (which would be just the eos token) 71 | if sz > 1: 72 | self.slice_indices.append((curr, curr + sz)) 73 | curr += sz 74 | else: 75 | raise ValueError('Invalid break_mode: ' + break_mode) 76 | 77 | self.sizes = np.array([e - s for s, e in self.slice_indices]) 78 | 79 | def __getitem__(self, index): 80 | s, e = self.slice_indices[index] 81 | 82 | item = torch.LongTensor(self.tokens[s:e]) 83 | 84 | if self.include_targets: 85 | # target is the sentence, for source, rotate item one token to the left (would start with eos) 86 | # past target is rotated to the left by 2 (padded if its first) 87 | if s == 0: 88 | source = np.concatenate([[self.eos], self.tokens[0:e - 1]]) 89 | past_target = np.concatenate([[self.pad, self.eos], self.tokens[0:e - 2]]) 90 | else: 91 | source = self.tokens[s - 1:e - 1] 92 | if s == 1: 93 | past_target = np.concatenate([[self.eos], self.tokens[0:e - 2]]) 94 | else: 95 | past_target = self.tokens[s - 2:e - 2] 96 | 97 | return torch.LongTensor(source), item, torch.LongTensor(past_target) 98 | return item 99 | 100 | def __len__(self): 101 | return len(self.slice_indices) 102 | 
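The three break modes described in the docstring above are easiest to see on a toy token stream. The sketch below is illustrative only: the pad/eos ids and sentence contents are made up, and only the resulting slice_indices are inspected.

from fairseq.data import TokenBlockDataset

tokens = [4, 5, 2, 6, 7, 8, 2]        # two sentences, each ending in eos=2
sizes = [3, 4]                        # per-sentence lengths, summing to len(tokens)

none_ds = TokenBlockDataset(tokens, sizes, block_size=3, pad=1, eos=2, break_mode='none')
# none_ds.slice_indices     -> [(0, 3), (3, 6), (6, 7)]   fixed-size blocks, sentence boundaries ignored

complete_ds = TokenBlockDataset(tokens, sizes, block_size=5, pad=1, eos=2, break_mode='complete')
# complete_ds.slice_indices -> [(0, 3), (3, 7)]            whole sentences packed up to block_size

eos_ds = TokenBlockDataset(tokens, sizes, block_size=0, pad=1, eos=2, break_mode='eos')
# eos_ds.slice_indices      -> [(0, 3), (3, 7)]            exactly one sentence per block (block_size ignored)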
-------------------------------------------------------------------------------- /encdec/fairseq/meters.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import time 9 | 10 | 11 | class AverageMeter(object): 12 | """Computes and stores the average and current value""" 13 | def __init__(self): 14 | self.reset() 15 | 16 | def reset(self): 17 | self.val = 0 18 | self.avg = 0 19 | self.sum = 0 20 | self.count = 0 21 | 22 | def update(self, val, n=1): 23 | self.val = val 24 | self.sum += val * n 25 | self.count += n 26 | self.avg = self.sum / self.count 27 | 28 | 29 | class TimeMeter(object): 30 | """Computes the average occurrence of some event per second""" 31 | def __init__(self, init=0): 32 | self.reset(init) 33 | 34 | def reset(self, init=0): 35 | self.init = init 36 | self.start = time.time() 37 | self.n = 0 38 | 39 | def update(self, val=1): 40 | self.n += val 41 | 42 | @property 43 | def avg(self): 44 | return self.n / self.elapsed_time 45 | 46 | @property 47 | def elapsed_time(self): 48 | return self.init + (time.time() - self.start) 49 | 50 | 51 | class StopwatchMeter(object): 52 | """Computes the sum/avg duration of some event in seconds""" 53 | def __init__(self): 54 | self.reset() 55 | 56 | def start(self): 57 | self.start_time = time.time() 58 | 59 | def stop(self, n=1): 60 | if self.start_time is not None: 61 | delta = time.time() - self.start_time 62 | self.sum += delta 63 | self.n += n 64 | self.start_time = None 65 | 66 | def reset(self): 67 | self.sum = 0 68 | self.n = 0 69 | self.start_time = None 70 | 71 | @property 72 | def avg(self): 73 | return self.sum / self.n 74 | -------------------------------------------------------------------------------- /encdec/fairseq/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import argparse 9 | import importlib 10 | import os 11 | 12 | from .fairseq_decoder import FairseqDecoder # noqa: F401 13 | from .fairseq_encoder import FairseqEncoder # noqa: F401 14 | from .fairseq_incremental_decoder import FairseqIncrementalDecoder # noqa: F401 15 | from .fairseq_model import BaseFairseqModel, FairseqModel, FairseqLanguageModel # noqa: F401 16 | 17 | from .composite_encoder import CompositeEncoder # noqa: F401 18 | from .distributed_fairseq_model import DistributedFairseqModel # noqa: F401 19 | 20 | 21 | MODEL_REGISTRY = {} 22 | ARCH_MODEL_REGISTRY = {} 23 | ARCH_MODEL_INV_REGISTRY = {} 24 | ARCH_CONFIG_REGISTRY = {} 25 | 26 | 27 | def build_model(args, task): 28 | return ARCH_MODEL_REGISTRY[args.arch].build_model(args, task) 29 | 30 | 31 | def register_model(name): 32 | """ 33 | New model types can be added to fairseq with the :func:`register_model` 34 | function decorator. 35 | 36 | For example:: 37 | 38 | @register_model('lstm') 39 | class LSTM(FairseqModel): 40 | (...) 41 | 42 | .. 
note:: All models must implement the :class:`BaseFairseqModel` interface. 43 | Typically you will extend :class:`FairseqModel` for sequence-to-sequence 44 | tasks or :class:`FairseqLanguageModel` for language modeling tasks. 45 | 46 | Args: 47 | name (str): the name of the model 48 | """ 49 | 50 | def register_model_cls(cls): 51 | if name in MODEL_REGISTRY: 52 | raise ValueError('Cannot register duplicate model ({})'.format(name)) 53 | if not issubclass(cls, BaseFairseqModel): 54 | raise ValueError('Model ({}: {}) must extend BaseFairseqModel'.format(name, cls.__name__)) 55 | MODEL_REGISTRY[name] = cls 56 | return cls 57 | 58 | return register_model_cls 59 | 60 | 61 | def register_model_architecture(model_name, arch_name): 62 | """ 63 | New model architectures can be added to fairseq with the 64 | :func:`register_model_architecture` function decorator. After registration, 65 | model architectures can be selected with the ``--arch`` command-line 66 | argument. 67 | 68 | For example:: 69 | 70 | @register_model_architecture('lstm', 'lstm_luong_wmt_en_de') 71 | def lstm_luong_wmt_en_de(args): 72 | args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1000) 73 | (...) 74 | 75 | The decorated function should take a single argument *args*, which is a 76 | :class:`argparse.Namespace` of arguments parsed from the command-line. The 77 | decorated function should modify these arguments in-place to match the 78 | desired architecture. 79 | 80 | Args: 81 | model_name (str): the name of the Model (Model must already be 82 | registered) 83 | arch_name (str): the name of the model architecture (``--arch``) 84 | """ 85 | 86 | def register_model_arch_fn(fn): 87 | if model_name not in MODEL_REGISTRY: 88 | raise ValueError('Cannot register model architecture for unknown model type ({})'.format(model_name)) 89 | if arch_name in ARCH_MODEL_REGISTRY: 90 | raise ValueError('Cannot register duplicate model architecture ({})'.format(arch_name)) 91 | if not callable(fn): 92 | raise ValueError('Model architecture must be callable ({})'.format(arch_name)) 93 | ARCH_MODEL_REGISTRY[arch_name] = MODEL_REGISTRY[model_name] 94 | ARCH_MODEL_INV_REGISTRY.setdefault(model_name, []).append(arch_name) 95 | ARCH_CONFIG_REGISTRY[arch_name] = fn 96 | return fn 97 | 98 | return register_model_arch_fn 99 | 100 | 101 | # automatically import any Python files in the models/ directory 102 | for file in os.listdir(os.path.dirname(__file__)): 103 | if file.endswith('.py') and not file.startswith('_'): 104 | model_name = file[:file.find('.py')] 105 | module = importlib.import_module('fairseq.models.' + model_name) 106 | 107 | # extra `model_parser` for sphinx 108 | if model_name in MODEL_REGISTRY: 109 | parser = argparse.ArgumentParser(add_help=False) 110 | group_archs = parser.add_argument_group('Named architectures') 111 | group_archs.add_argument('--arch', choices=ARCH_MODEL_INV_REGISTRY[model_name]) 112 | group_args = parser.add_argument_group('Additional command-line arguments') 113 | MODEL_REGISTRY[model_name].add_args(group_args) 114 | globals()[model_name + '_parser'] = parser 115 | -------------------------------------------------------------------------------- /encdec/fairseq/models/composite_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. 
An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from . import FairseqEncoder 9 | 10 | 11 | class CompositeEncoder(FairseqEncoder): 12 | """ 13 | A wrapper around a dictionary of :class:`FairseqEncoder` objects. 14 | 15 | We run forward on each encoder and return a dictionary of outputs. The first 16 | encoder's dictionary is used for initialization. 17 | 18 | Args: 19 | encoders (dict): a dictionary of :class:`FairseqEncoder` objects. 20 | """ 21 | 22 | def __init__(self, encoders): 23 | super().__init__(next(iter(encoders.values())).dictionary) 24 | self.encoders = encoders 25 | for key in self.encoders: 26 | self.add_module(key, self.encoders[key]) 27 | 28 | def forward(self, src_tokens, src_lengths): 29 | """ 30 | Args: 31 | src_tokens (LongTensor): tokens in the source language of shape 32 | `(batch, src_len)` 33 | src_lengths (LongTensor): lengths of each source sentence of shape 34 | `(batch)` 35 | 36 | Returns: 37 | dict: 38 | the outputs from each Encoder 39 | """ 40 | encoder_out = {} 41 | for key in self.encoders: 42 | encoder_out[key] = self.encoders[key](src_tokens, src_lengths) 43 | return encoder_out 44 | 45 | def reorder_encoder_out(self, encoder_out, new_order): 46 | """Reorder encoder output according to new_order.""" 47 | for key in self.encoders: 48 | encoder_out[key] = self.encoders[key].reorder_encoder_out(encoder_out[key], new_order) 49 | return encoder_out 50 | 51 | def max_positions(self): 52 | return min([self.encoders[key].max_positions() for key in self.encoders]) 53 | 54 | def upgrade_state_dict(self, state_dict): 55 | for key in self.encoders: 56 | self.encoders[key].upgrade_state_dict(state_dict) 57 | return state_dict 58 | -------------------------------------------------------------------------------- /encdec/fairseq/models/distributed_fairseq_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from torch.nn import parallel 9 | 10 | from fairseq.distributed_utils import c10d_status 11 | 12 | from . import BaseFairseqModel 13 | 14 | 15 | def DistributedFairseqModel(args, model): 16 | """ 17 | Wrap a *model* to support distributed data parallel training. 18 | 19 | This is similar to the built-in DistributedDataParallel, but allows 20 | additional configuration of the DistributedDataParallel class to 21 | use, and also provides easier access to the wrapped model by 22 | forwarding requests for missing attributes to the wrapped model. 23 | 24 | Args: 25 | args (argparse.Namespace): fairseq args 26 | model (BaseFairseqModel): model to wrap 27 | """ 28 | 29 | # determine which DDP class to extend 30 | assert isinstance(model, BaseFairseqModel) 31 | if args.ddp_backend == 'c10d': 32 | if c10d_status.is_default: 33 | ddp_class = parallel.DistributedDataParallel 34 | elif c10d_status.has_c10d: 35 | ddp_class = parallel._DistributedDataParallelC10d 36 | else: 37 | raise Exception( 38 | 'Can\'t find c10d version of DistributedDataParallel. ' 39 | 'Please update PyTorch.' 
40 | ) 41 | init_kwargs = dict( 42 | module=model, 43 | device_ids=[args.device_id], 44 | output_device=args.device_id, 45 | broadcast_buffers=False, 46 | bucket_cap_mb=args.bucket_cap_mb, 47 | ) 48 | elif args.ddp_backend == 'no_c10d': 49 | if c10d_status.is_default: 50 | ddp_class = parallel.deprecated.DistributedDataParallel 51 | else: 52 | ddp_class = parallel.DistributedDataParallel 53 | init_kwargs = dict( 54 | module=model, 55 | device_ids=[args.device_id], 56 | output_device=args.device_id, 57 | broadcast_buffers=False, 58 | ) 59 | else: 60 | raise ValueError('Unknown --ddp-backend: ' + args.ddp_backend) 61 | 62 | class _DistributedFairseqModel(ddp_class): 63 | """Extend DistributedDataParallel to check for missing 64 | attributes in the wrapped module.""" 65 | 66 | def __init__(self, *args, **kwargs): 67 | super().__init__(*args, **kwargs) 68 | 69 | def __getattr__(self, name): 70 | wrapped_module = super().__getattr__('module') 71 | if hasattr(wrapped_module, name): 72 | return getattr(wrapped_module, name) 73 | return super().__getattr__(name) 74 | 75 | return _DistributedFairseqModel(**init_kwargs) 76 | -------------------------------------------------------------------------------- /encdec/fairseq/models/fairseq_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class FairseqDecoder(nn.Module): 13 | """Base class for decoders.""" 14 | 15 | def __init__(self, dictionary): 16 | super().__init__() 17 | self.dictionary = dictionary 18 | 19 | def forward(self, prev_output_tokens, encoder_out): 20 | """ 21 | Args: 22 | prev_output_tokens (LongTensor): previous decoder outputs of shape 23 | `(batch, tgt_len)`, for input feeding/teacher forcing 24 | encoder_out (Tensor, optional): output from the encoder, used for 25 | encoder-side attention 26 | 27 | Returns: 28 | tuple: 29 | - the last decoder layer's output of shape 30 | `(batch, tgt_len, vocab)` 31 | - the last decoder layer's attention weights of shape 32 | `(batch, tgt_len, src_len)` 33 | """ 34 | raise NotImplementedError 35 | 36 | def get_normalized_probs(self, net_output, log_probs, sample): 37 | """Get normalized probabilities (or log probs) from a net's output.""" 38 | 39 | if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None: 40 | assert sample is not None and 'target' in sample 41 | out = self.adaptive_softmax.get_log_prob(net_output[0], sample['target']) 42 | return out.exp_() if not log_probs else out 43 | 44 | logits = net_output[0].float() 45 | if log_probs: 46 | return F.log_softmax(logits, dim=-1) 47 | else: 48 | return F.softmax(logits, dim=-1) 49 | 50 | def max_positions(self): 51 | """Maximum input length supported by the decoder.""" 52 | return 1e6 # an arbitrary large number 53 | 54 | def upgrade_state_dict(self, state_dict): 55 | """Upgrade a (possibly old) state dict for new versions of fairseq.""" 56 | return state_dict 57 | -------------------------------------------------------------------------------- /encdec/fairseq/models/fairseq_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, 
Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.nn as nn 9 | 10 | 11 | class FairseqEncoder(nn.Module): 12 | """Base class for encoders.""" 13 | 14 | def __init__(self, dictionary): 15 | super().__init__() 16 | self.dictionary = dictionary 17 | 18 | def forward(self, src_tokens, src_lengths): 19 | """ 20 | Args: 21 | src_tokens (LongTensor): tokens in the source language of shape 22 | `(batch, src_len)` 23 | src_lengths (LongTensor): lengths of each source sentence of shape 24 | `(batch)` 25 | """ 26 | raise NotImplementedError 27 | 28 | def reorder_encoder_out(self, encoder_out, new_order): 29 | """ 30 | Reorder encoder output according to `new_order`. 31 | 32 | Args: 33 | encoder_out: output from the ``forward()`` method 34 | new_order (LongTensor): desired order 35 | 36 | Returns: 37 | `encoder_out` rearranged according to `new_order` 38 | """ 39 | raise NotImplementedError 40 | 41 | def max_positions(self): 42 | """Maximum input length supported by the encoder.""" 43 | return 1e6 # an arbitrary large number 44 | 45 | def upgrade_state_dict(self, state_dict): 46 | """Upgrade a (possibly old) state dict for new versions of fairseq.""" 47 | return state_dict 48 | -------------------------------------------------------------------------------- /encdec/fairseq/models/fairseq_incremental_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from . import FairseqDecoder 9 | 10 | 11 | class FairseqIncrementalDecoder(FairseqDecoder): 12 | """Base class for incremental decoders. 13 | 14 | Incremental decoding is a special mode at inference time where the Model 15 | only receives a single timestep of input corresponding to the immediately 16 | previous output token (for input feeding) and must produce the next output 17 | *incrementally*. Thus the model must cache any long-term state that is 18 | needed about the sequence, e.g., hidden states, convolutional states, etc. 19 | 20 | Compared to the standard :class:`FairseqDecoder` interface, the incremental 21 | decoder interface allows :func:`forward` functions to take an extra keyword 22 | argument (*incremental_state*) that can be used to cache state across 23 | time-steps. 24 | 25 | The :class:`FairseqIncrementalDecoder` interface also defines the 26 | :func:`reorder_incremental_state` method, which is used during beam search 27 | to select and reorder the incremental state based on the selection of beams. 
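    A typical caching pattern inside :func:`forward` looks like the following
    sketch, where ``_step`` is a hypothetical helper rather than part of this
    interface and ``utils`` refers to ``fairseq.utils``::

        prev_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
        out, new_state = self._step(prev_output_tokens[:, -1:], encoder_out, prev_state)
        utils.set_incremental_state(self, incremental_state, 'cached_state', new_state)
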
28 | """ 29 | 30 | def __init__(self, dictionary): 31 | super().__init__(dictionary) 32 | 33 | def forward(self, prev_output_tokens, encoder_out, incremental_state=None): 34 | """ 35 | Args: 36 | prev_output_tokens (LongTensor): previous decoder outputs of shape 37 | `(batch, tgt_len)`, for input feeding/teacher forcing 38 | encoder_out (Tensor, optional): output from the encoder, used for 39 | encoder-side attention 40 | incremental_state (dict): dictionary used for storing state during 41 | :ref:`Incremental decoding` 42 | 43 | Returns: 44 | tuple: 45 | - the last decoder layer's output of shape `(batch, tgt_len, 46 | vocab)` 47 | - the last decoder layer's attention weights of shape `(batch, 48 | tgt_len, src_len)` 49 | """ 50 | raise NotImplementedError 51 | 52 | def reorder_incremental_state(self, incremental_state, new_order): 53 | """Reorder incremental state. 54 | 55 | This should be called when the order of the input has changed from the 56 | previous time step. A typical use case is beam search, where the input 57 | order changes between time steps based on the selection of beams. 58 | """ 59 | def apply_reorder_incremental_state(module): 60 | if module != self and hasattr(module, 'reorder_incremental_state'): 61 | module.reorder_incremental_state( 62 | incremental_state, 63 | new_order, 64 | ) 65 | self.apply(apply_reorder_incremental_state) 66 | 67 | def set_beam_size(self, beam_size): 68 | """Sets the beam size in the decoder and all children.""" 69 | if getattr(self, '_beam_size', -1) != beam_size: 70 | def apply_set_beam_size(module): 71 | if module != self and hasattr(module, 'set_beam_size'): 72 | module.set_beam_size(beam_size) 73 | self.apply(apply_set_beam_size) 74 | self._beam_size = beam_size 75 | -------------------------------------------------------------------------------- /encdec/fairseq/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from .adaptive_softmax import AdaptiveSoftmax 9 | from .beamable_mm import BeamableMM 10 | from .character_token_embedder import CharacterTokenEmbedder 11 | from .conv_tbc import ConvTBC 12 | from .downsampled_multihead_attention import DownsampledMultiHeadAttention 13 | from .grad_multiply import GradMultiply 14 | from .highway import Highway 15 | from .learned_positional_embedding import LearnedPositionalEmbedding 16 | from .linearized_convolution import LinearizedConvolution 17 | from .multihead_attention import MultiheadAttention 18 | from .scalar_bias import ScalarBias 19 | from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding 20 | 21 | __all__ = [ 22 | 'AdaptiveSoftmax', 23 | 'BeamableMM', 24 | 'CharacterTokenEmbedder', 25 | 'ConvTBC', 26 | 'DownsampledMultiHeadAttention', 27 | 'GradMultiply', 28 | 'Highway', 29 | 'LearnedPositionalEmbedding', 30 | 'LinearizedConvolution', 31 | 'MultiheadAttention', 32 | 'ScalarBias', 33 | 'SinusoidalPositionalEmbedding', 34 | ] 35 | -------------------------------------------------------------------------------- /encdec/fairseq/modules/beamable_mm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | class BeamableMM(nn.Module): 13 | """This module provides an optimized MM for beam decoding with attention. 14 | 15 | It leverage the fact that the source-side of the input is replicated beam 16 | times and the target-side of the input is of width one. This layer speeds up 17 | inference by replacing the inputs {(bsz x 1 x nhu), (bsz x sz2 x nhu)} 18 | with smaller inputs {(bsz/beam x beam x nhu), (bsz/beam x sz2 x nhu)}. 19 | """ 20 | def __init__(self, beam_size=None): 21 | super(BeamableMM, self).__init__() 22 | self.beam_size = beam_size 23 | 24 | def forward(self, input1, input2): 25 | if ( 26 | not self.training and # test mode 27 | self.beam_size is not None and # beam size is set 28 | input1.dim() == 3 and # only support batched input 29 | input1.size(1) == 1 # single time step update 30 | ): 31 | bsz, beam = input1.size(0), self.beam_size 32 | 33 | # bsz x 1 x nhu --> bsz/beam x beam x nhu 34 | input1 = input1[:, 0, :].unfold(0, beam, beam).transpose(2, 1) 35 | 36 | # bsz x sz2 x nhu --> bsz/beam x sz2 x nhu 37 | input2 = input2.unfold(0, beam, beam)[:, :, :, 0] 38 | 39 | # use non batched operation if bsz = beam 40 | if input1.size(0) == 1: 41 | output = torch.mm(input1[0, :, :], input2[0, :, :]) 42 | else: 43 | output = input1.bmm(input2) 44 | return output.view(bsz, 1, -1) 45 | else: 46 | return input1.bmm(input2) 47 | 48 | def set_beam_size(self, beam_size): 49 | self.beam_size = beam_size 50 | -------------------------------------------------------------------------------- /encdec/fairseq/modules/character_token_embedder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
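The reshaping performed by BeamableMM above is easiest to verify on concrete shapes. The following is an illustrative sketch with made-up sizes (two source sentences, beam size 5) that reproduces only the decoder-side folding from the forward pass:

import torch

bsz, beam, nhu = 10, 5, 4                 # 10 hypotheses = 2 source sentences x 5 beams (illustrative)
x = torch.randn(bsz, 1, nhu)              # single-step decoder query, one row per hypothesis
folded = x[:, 0, :].unfold(0, beam, beam).transpose(2, 1)
print(folded.shape)                       # torch.Size([2, 5, 4]): one bmm per source sentence instead of ten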
7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | from torch import nn 13 | from torch.nn.utils.rnn import pad_sequence 14 | 15 | from typing import List, Tuple 16 | 17 | from .highway import Highway 18 | from fairseq.data import Dictionary 19 | 20 | 21 | class CharacterTokenEmbedder(torch.nn.Module): 22 | def __init__( 23 | self, 24 | vocab: Dictionary, 25 | filters: List[Tuple[int, int]], 26 | char_embed_dim: int, 27 | word_embed_dim: int, 28 | highway_layers: int, 29 | max_char_len: int = 50, 30 | ): 31 | super(CharacterTokenEmbedder, self).__init__() 32 | 33 | self.embedding_dim = word_embed_dim 34 | self.char_embeddings = nn.Embedding(257, char_embed_dim, padding_idx=0) 35 | self.symbol_embeddings = nn.Parameter(torch.FloatTensor(2, word_embed_dim)) 36 | self.eos_idx, self.unk_idx = 0, 1 37 | 38 | self.convolutions = nn.ModuleList() 39 | for width, out_c in filters: 40 | self.convolutions.append( 41 | nn.Conv1d(char_embed_dim, out_c, kernel_size=width) 42 | ) 43 | 44 | final_dim = sum(f[1] for f in filters) 45 | 46 | self.highway = Highway(final_dim, highway_layers) 47 | self.projection = nn.Linear(final_dim, word_embed_dim) 48 | 49 | self.set_vocab(vocab, max_char_len) 50 | self.reset_parameters() 51 | 52 | def set_vocab(self, vocab, max_char_len): 53 | word_to_char = torch.LongTensor(len(vocab), max_char_len) 54 | 55 | truncated = 0 56 | for i in range(len(vocab)): 57 | if i < vocab.nspecial: 58 | char_idxs = [0] * max_char_len 59 | else: 60 | chars = vocab[i].encode() 61 | # +1 for padding 62 | char_idxs = [c + 1 for c in chars] + [0] * (max_char_len - len(chars)) 63 | if len(char_idxs) > max_char_len: 64 | truncated += 1 65 | char_idxs = char_idxs[:max_char_len] 66 | word_to_char[i] = torch.LongTensor(char_idxs) 67 | 68 | if truncated > 0: 69 | print('Truncated {} words longer than {} characters'.format(truncated, max_char_len)) 70 | 71 | self.vocab = vocab 72 | self.word_to_char = word_to_char 73 | 74 | @property 75 | def padding_idx(self): 76 | return self.vocab.pad() 77 | 78 | def reset_parameters(self): 79 | nn.init.xavier_normal_(self.char_embeddings.weight) 80 | nn.init.xavier_normal_(self.symbol_embeddings) 81 | nn.init.xavier_normal_(self.projection.weight) 82 | nn.init.constant_(self.char_embeddings.weight[self.char_embeddings.padding_idx], 0.) 83 | nn.init.constant_(self.projection.bias, 0.) 
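    # forward() looks up each word's character ids via word_to_char, encodes
    # them with _convolve(), then overwrites the rows for pad, eos and unk with
    # zeros or the learned symbol embeddings. _convolve() embeds the characters,
    # applies each Conv1d filter, max-pools over time, applies ReLU, concatenates
    # the filter outputs, and passes them through the Highway layers and the
    # final Linear projection down to word_embed_dim.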
84 | 85 | def forward( 86 | self, 87 | words: torch.Tensor, 88 | ): 89 | self.word_to_char = self.word_to_char.type_as(words) 90 | 91 | flat_words = words.view(-1) 92 | word_embs = self._convolve(self.word_to_char[flat_words]) 93 | 94 | pads = flat_words.eq(self.vocab.pad()) 95 | if pads.any(): 96 | word_embs[pads] = 0 97 | 98 | eos = flat_words.eq(self.vocab.eos()) 99 | if eos.any(): 100 | word_embs[eos] = self.symbol_embeddings[self.eos_idx] 101 | 102 | unk = flat_words.eq(self.vocab.unk()) 103 | if unk.any(): 104 | word_embs[unk] = self.symbol_embeddings[self.unk_idx] 105 | 106 | return word_embs.view(words.size() + (-1,)) 107 | 108 | def _convolve( 109 | self, 110 | char_idxs: torch.Tensor, 111 | ): 112 | char_embs = self.char_embeddings(char_idxs) 113 | char_embs = char_embs.transpose(1, 2) # BTC -> BCT 114 | 115 | conv_result = [] 116 | 117 | for i, conv in enumerate(self.convolutions): 118 | x = conv(char_embs) 119 | x, _ = torch.max(x, -1) 120 | x = F.relu(x) 121 | conv_result.append(x) 122 | 123 | conv_result = torch.cat(conv_result, dim=-1) 124 | conv_result = self.highway(conv_result) 125 | 126 | return self.projection(conv_result) 127 | -------------------------------------------------------------------------------- /encdec/fairseq/modules/conv_tbc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | from torch.nn.modules.utils import _single 10 | 11 | 12 | class ConvTBC(torch.nn.Module): 13 | """1D convolution over an input of shape (time x batch x channel) 14 | 15 | The implementation uses gemm to perform the convolution. This implementation 16 | is faster than cuDNN for small kernel sizes. 17 | """ 18 | def __init__(self, in_channels, out_channels, kernel_size, padding=0): 19 | super(ConvTBC, self).__init__() 20 | self.in_channels = in_channels 21 | self.out_channels = out_channels 22 | self.kernel_size = _single(kernel_size) 23 | self.padding = _single(padding) 24 | 25 | self.weight = torch.nn.Parameter(torch.Tensor( 26 | self.kernel_size[0], in_channels, out_channels)) 27 | self.bias = torch.nn.Parameter(torch.Tensor(out_channels)) 28 | 29 | def forward(self, input): 30 | return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding[0]) 31 | 32 | def __repr__(self): 33 | s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}' 34 | ', padding={padding}') 35 | if self.bias is None: 36 | s += ', bias=False' 37 | s += ')' 38 | return s.format(name=self.__class__.__name__, **self.__dict__) 39 | -------------------------------------------------------------------------------- /encdec/fairseq/modules/grad_multiply.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
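ConvTBC above expects the time-first (time x batch x channel) layout rather than the usual batch-first one. A minimal shape check with illustrative sizes; the module's weights are not initialized here, so only the output shape is meaningful:

import torch
from fairseq.modules import ConvTBC

conv = ConvTBC(in_channels=8, out_channels=16, kernel_size=3, padding=1)
x = torch.randn(20, 4, 8)     # (time=20, batch=4, channels=8), note the TBC layout
y = conv(x)
print(y.shape)                # torch.Size([20, 4, 16]); time length is preserved since padding=1, kernel_size=3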
7 | 8 | import torch 9 | 10 | 11 | class GradMultiply(torch.autograd.Function): 12 | @staticmethod 13 | def forward(ctx, x, scale): 14 | ctx.scale = scale 15 | res = x.new(x) 16 | return res 17 | 18 | @staticmethod 19 | def backward(ctx, grad): 20 | return grad * ctx.scale, None 21 | -------------------------------------------------------------------------------- /encdec/fairseq/modules/highway.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | from torch import nn 12 | 13 | 14 | class Highway(torch.nn.Module): 15 | """ 16 | A `Highway layer <https://arxiv.org/abs/1505.00387>`_. 17 | Adopted from the AllenNLP implementation. 18 | """ 19 | 20 | def __init__( 21 | self, 22 | input_dim: int, 23 | num_layers: int = 1 24 | ): 25 | super(Highway, self).__init__() 26 | self.input_dim = input_dim 27 | self.layers = nn.ModuleList([nn.Linear(input_dim, input_dim * 2) 28 | for _ in range(num_layers)]) 29 | self.activation = nn.ReLU() 30 | 31 | self.reset_parameters() 32 | 33 | def reset_parameters(self): 34 | for layer in self.layers: 35 | # As per comment in AllenNLP: 36 | # We should bias the highway layer to just carry its input forward. We do that by 37 | # setting the bias on `B(x)` to be positive, because that means `g` will be biased to 38 | # be high, so we will carry the input forward. The bias on `B(x)` is the second half 39 | # of the bias vector in each Linear layer. 40 | nn.init.constant_(layer.bias[self.input_dim:], 1) 41 | 42 | nn.init.constant_(layer.bias[:self.input_dim], 0) 43 | nn.init.xavier_normal_(layer.weight) 44 | 45 | def forward( 46 | self, 47 | x: torch.Tensor 48 | ): 49 | for layer in self.layers: 50 | projection = layer(x) 51 | proj_x, gate = projection.chunk(2, dim=-1) 52 | proj_x = self.activation(proj_x) 53 | gate = F.sigmoid(gate) 54 | x = gate * x + (1 - gate) * proj_x 55 | return x 56 | -------------------------------------------------------------------------------- /encdec/fairseq/modules/learned_positional_embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.nn as nn 9 | 10 | from fairseq import utils 11 | 12 | 13 | class LearnedPositionalEmbedding(nn.Embedding): 14 | """This module learns positional embeddings up to a fixed maximum size. 15 | 16 | Padding symbols are ignored, but it is necessary to specify whether padding 17 | is added on the left side (left_pad=True) or right side (left_pad=False). 
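
    For example (an illustrative case with right padding, i.e. left_pad=False):
    with padding_idx=1, an input batch [[7, 8, 1, 1]] is mapped to positions
    [[2, 3, 1, 1]] before the embedding lookup, so real tokens receive
    consecutive positions starting at padding_idx + 1 while padding keeps
    padding_idx (whose embedding row stays zero).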
18 | """ 19 | 20 | def __init__(self, num_embeddings, embedding_dim, padding_idx, left_pad): 21 | super().__init__(num_embeddings, embedding_dim, padding_idx) 22 | self.left_pad = left_pad 23 | 24 | def forward(self, input, incremental_state=None): 25 | """Input is expected to be of size [bsz x seqlen].""" 26 | if incremental_state is not None: 27 | # positions is the same for every token when decoding a single step 28 | positions = input.data.new(1, 1).fill_(self.padding_idx + input.size(1)) 29 | else: 30 | positions = utils.make_positions(input.data, self.padding_idx, self.left_pad) 31 | return super().forward(positions) 32 | 33 | def max_positions(self): 34 | """Maximum number of supported positions.""" 35 | return self.num_embeddings - self.padding_idx - 1 36 | -------------------------------------------------------------------------------- /encdec/fairseq/modules/linearized_convolution.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | from fairseq import utils 12 | 13 | from .conv_tbc import ConvTBC 14 | 15 | 16 | class LinearizedConvolution(ConvTBC): 17 | """An optimized version of nn.Conv1d. 18 | 19 | At training time, this module uses ConvTBC, which is an optimized version 20 | of Conv1d. At inference time, it optimizes incremental generation (i.e., 21 | one time step at a time) by replacing the convolutions with linear layers. 22 | Note that the input order changes from training to inference. 23 | """ 24 | 25 | def __init__(self, in_channels, out_channels, kernel_size, **kwargs): 26 | super().__init__(in_channels, out_channels, kernel_size, **kwargs) 27 | self._linearized_weight = None 28 | self.register_backward_hook(self._clear_linearized_weight) 29 | 30 | def forward(self, input, incremental_state=None): 31 | """ 32 | Args: 33 | incremental_state: Used to buffer signal; if not None, then input is 34 | expected to contain a single frame. If the input order changes 35 | between time steps, call reorder_incremental_state. 
36 | Input: 37 | Time x Batch x Channel during training 38 | Batch x Time x Channel during inference 39 | """ 40 | if incremental_state is None: 41 | output = super().forward(input) 42 | if self.kernel_size[0] > 1 and self.padding[0] > 0: 43 | # remove future timesteps added by padding 44 | output = output[:-self.padding[0], :, :] 45 | return output 46 | 47 | # reshape weight 48 | weight = self._get_linearized_weight() 49 | kw = self.kernel_size[0] 50 | 51 | bsz = input.size(0) # input: bsz x len x dim 52 | if kw > 1: 53 | input = input.data 54 | input_buffer = self._get_input_buffer(incremental_state) 55 | if input_buffer is None: 56 | input_buffer = input.new(bsz, kw, input.size(2)).zero_() 57 | self._set_input_buffer(incremental_state, input_buffer) 58 | else: 59 | # shift buffer 60 | input_buffer[:, :-1, :] = input_buffer[:, 1:, :].clone() 61 | # append next input 62 | input_buffer[:, -1, :] = input[:, -1, :] 63 | input = input_buffer 64 | with torch.no_grad(): 65 | output = F.linear(input.view(bsz, -1), weight, self.bias) 66 | return output.view(bsz, 1, -1) 67 | 68 | def reorder_incremental_state(self, incremental_state, new_order): 69 | input_buffer = self._get_input_buffer(incremental_state) 70 | if input_buffer is not None: 71 | input_buffer = input_buffer.index_select(0, new_order) 72 | self._set_input_buffer(incremental_state, input_buffer) 73 | 74 | def _get_input_buffer(self, incremental_state): 75 | return utils.get_incremental_state(self, incremental_state, 'input_buffer') 76 | 77 | def _set_input_buffer(self, incremental_state, new_buffer): 78 | return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer) 79 | 80 | def _get_linearized_weight(self): 81 | if self._linearized_weight is None: 82 | kw = self.kernel_size[0] 83 | weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous() 84 | assert weight.size() == (self.out_channels, kw, self.in_channels) 85 | self._linearized_weight = weight.view(self.out_channels, -1) 86 | return self._linearized_weight 87 | 88 | def _clear_linearized_weight(self, *args): 89 | self._linearized_weight = None 90 | -------------------------------------------------------------------------------- /encdec/fairseq/modules/scalar_bias.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
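# A sketch of the two input layouts the LinearizedConvolution module above
# accepts: time x batch x channel during training, and a single frame
# (batch x 1 x channel) per call during incremental inference, where the
# incremental_state dict buffers the previous kernel_size - 1 frames. The
# sizes below are arbitrary, and the explicit init is only there because
# ConvTBC does not initialize its own parameters.
import torch
from fairseq.modules.linearized_convolution import LinearizedConvolution

conv = LinearizedConvolution(in_channels=8, out_channels=8, kernel_size=3, padding=2)
torch.nn.init.normal_(conv.weight, mean=0.0, std=0.1)
torch.nn.init.constant_(conv.bias, 0.0)

train_out = conv(torch.randn(5, 2, 8))        # TBC input -> TBC output of the same length

incremental_state = {}
for _ in range(5):
    step_out = conv(torch.randn(2, 1, 8), incremental_state)   # batch x 1 x out_channels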
7 | # 8 | 9 | import torch 10 | 11 | 12 | class ScalarBias(torch.autograd.Function): 13 | """ 14 | Adds a vector of scalars, used in self-attention mechanism to allow 15 | the model to optionally attend to this vector instead of the past 16 | """ 17 | 18 | @staticmethod 19 | def forward(ctx, input, dim, bias_init): 20 | size = list(input.size()) 21 | size[dim] += 1 22 | output = input.new(*size).fill_(bias_init) 23 | output.narrow(dim, 1, size[dim] - 1).copy_(input) 24 | ctx.dim = dim 25 | return output 26 | 27 | @staticmethod 28 | def backward(ctx, grad): 29 | return grad.narrow(ctx.dim, 1, grad.size(ctx.dim) - 1), None, None 30 | 31 | 32 | def scalar_bias(input, dim, bias_init=0): 33 | return ScalarBias.apply(input, dim, bias_init) 34 | -------------------------------------------------------------------------------- /encdec/fairseq/multiprocessing_pdb.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import multiprocessing 9 | import os 10 | import pdb 11 | import sys 12 | 13 | 14 | class MultiprocessingPdb(pdb.Pdb): 15 | """A Pdb wrapper that works in a multiprocessing environment. 16 | 17 | Usage: `from fairseq import pdb; pdb.set_trace()` 18 | """ 19 | 20 | _stdin_fd = sys.stdin.fileno() 21 | _stdin = None 22 | _stdin_lock = multiprocessing.Lock() 23 | 24 | def __init__(self): 25 | pdb.Pdb.__init__(self, nosigint=True) 26 | 27 | def _cmdloop(self): 28 | stdin_bak = sys.stdin 29 | with self._stdin_lock: 30 | try: 31 | if not self._stdin: 32 | self._stdin = os.fdopen(self._stdin_fd) 33 | sys.stdin = self._stdin 34 | self.cmdloop() 35 | finally: 36 | sys.stdin = stdin_bak 37 | 38 | 39 | pdb = MultiprocessingPdb() 40 | -------------------------------------------------------------------------------- /encdec/fairseq/optim/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import importlib 9 | import os 10 | 11 | from .fairseq_optimizer import FairseqOptimizer 12 | from .fp16_optimizer import FP16Optimizer 13 | 14 | 15 | OPTIMIZER_REGISTRY = {} 16 | OPTIMIZER_CLASS_NAMES = set() 17 | 18 | 19 | def build_optimizer(args, params): 20 | params = list(filter(lambda p: p.requires_grad, params)) 21 | return OPTIMIZER_REGISTRY[args.optimizer](args, params) 22 | 23 | 24 | def register_optimizer(name): 25 | """Decorator to register a new optimizer.""" 26 | 27 | def register_optimizer_cls(cls): 28 | if name in OPTIMIZER_REGISTRY: 29 | raise ValueError('Cannot register duplicate optimizer ({})'.format(name)) 30 | if not issubclass(cls, FairseqOptimizer): 31 | raise ValueError('Optimizer ({}: {}) must extend FairseqOptimizer'.format(name, cls.__name__)) 32 | if cls.__name__ in OPTIMIZER_CLASS_NAMES: 33 | # We use the optimizer class name as a unique identifier in 34 | # checkpoints, so all optimizer must have unique class names. 
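# A small sketch of what the scalar_bias helper above does to a tensor: it
# prepends one extra, constant entry along the chosen dimension (a slot the
# attention can point at instead of any real position). The values are arbitrary.
import torch
from fairseq.modules.scalar_bias import scalar_bias

scores = torch.randn(2, 4)           # e.g. attention scores: batch x source_len
biased = scalar_bias(scores, dim=1)  # batch x (source_len + 1); biased[:, 0] is 0, biased[:, 1:] equals scores
print(scores.size(), biased.size())  # torch.Size([2, 4]) torch.Size([2, 5])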
35 | raise ValueError('Cannot register optimizer with duplicate class name ({})'.format(cls.__name__)) 36 | OPTIMIZER_REGISTRY[name] = cls 37 | OPTIMIZER_CLASS_NAMES.add(cls.__name__) 38 | return cls 39 | 40 | return register_optimizer_cls 41 | 42 | 43 | # automatically import any Python files in the optim/ directory 44 | for file in os.listdir(os.path.dirname(__file__)): 45 | if file.endswith('.py') and not file.startswith('_'): 46 | module = file[:file.find('.py')] 47 | importlib.import_module('fairseq.optim.' + module) 48 | -------------------------------------------------------------------------------- /encdec/fairseq/optim/adagrad.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.optim 9 | 10 | from . import FairseqOptimizer, register_optimizer 11 | 12 | 13 | @register_optimizer('adagrad') 14 | class Adagrad(FairseqOptimizer): 15 | def __init__(self, args, params): 16 | super().__init__(args, params) 17 | self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config) 18 | 19 | @property 20 | def optimizer_config(self): 21 | """ 22 | Return a kwarg dictionary that will be used to override optimizer 23 | args stored in checkpoints. This allows us to load a checkpoint and 24 | resume training using a different set of optimizer args, e.g., with a 25 | different learning rate. 26 | """ 27 | return { 28 | 'lr': self.args.lr[0], 29 | 'weight_decay': self.args.weight_decay, 30 | } 31 | -------------------------------------------------------------------------------- /encdec/fairseq/optim/fairseq_optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import math 9 | 10 | import torch 11 | 12 | 13 | class FairseqOptimizer(object): 14 | 15 | def __init__(self, args, params): 16 | super().__init__() 17 | self.args = args 18 | self.params = list(params) 19 | 20 | @staticmethod 21 | def add_args(parser): 22 | """Add optimizer-specific arguments to the parser.""" 23 | pass 24 | 25 | @property 26 | def optimizer(self): 27 | """Return a torch.optim.optimizer.Optimizer instance.""" 28 | if not hasattr(self, '_optimizer'): 29 | raise NotImplementedError 30 | if not isinstance(self._optimizer, torch.optim.Optimizer): 31 | raise ValueError('_optimizer must be an instance of torch.optim.Optimizer') 32 | return self._optimizer 33 | 34 | @property 35 | def optimizer_config(self): 36 | """ 37 | Return a kwarg dictionary that will be used to override optimizer 38 | args stored in checkpoints. This allows us to load a checkpoint and 39 | resume training using a different set of optimizer args, e.g., with a 40 | different learning rate. 
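# A sketch of how the optimizer registry above is meant to be used: decorate a
# FairseqOptimizer subclass with @register_optimizer and it becomes selectable
# through --optimizer. The name 'my_sgd' and the class below are hypothetical,
# modelled on the bundled SGD/Adagrad wrappers.
import torch.optim

from fairseq.optim import FairseqOptimizer, register_optimizer


@register_optimizer('my_sgd')
class MySGD(FairseqOptimizer):
    def __init__(self, args, params):
        super().__init__(args, params)
        self._optimizer = torch.optim.SGD(params, **self.optimizer_config)

    @property
    def optimizer_config(self):
        # values that may be overridden when resuming from a checkpoint
        return {'lr': self.args.lr[0], 'momentum': self.args.momentum}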
41 | """ 42 | raise NotImplementedError 43 | 44 | def get_lr(self): 45 | """Return the current learning rate.""" 46 | return self.optimizer.param_groups[0]['lr'] 47 | 48 | def set_lr(self, lr): 49 | """Set the learning rate.""" 50 | for param_group in self.optimizer.param_groups: 51 | param_group['lr'] = lr 52 | 53 | def state_dict(self): 54 | """Return the optimizer's state dict.""" 55 | return self.optimizer.state_dict() 56 | 57 | def load_state_dict(self, state_dict, optimizer_overrides=None): 58 | """Load an optimizer state dict. 59 | 60 | In general we should prefer the configuration of the existing optimizer 61 | instance (e.g., learning rate) over that found in the state_dict. This 62 | allows us to resume training from a checkpoint using a new set of 63 | optimizer args. 64 | """ 65 | self.optimizer.load_state_dict(state_dict) 66 | 67 | if optimizer_overrides is not None and len(optimizer_overrides) > 0: 68 | # override learning rate, momentum, etc. with latest values 69 | for group in self.optimizer.param_groups: 70 | group.update(optimizer_overrides) 71 | 72 | def backward(self, loss): 73 | loss.backward() 74 | 75 | def multiply_grads(self, c): 76 | """Multiplies grads by a constant ``c``.""" 77 | for p in self.params: 78 | p.grad.data.mul_(c) 79 | 80 | def clip_grad_norm(self, max_norm): 81 | """Clips gradient norm.""" 82 | if max_norm > 0: 83 | return torch.nn.utils.clip_grad_norm_(self.params, max_norm) 84 | else: 85 | return math.sqrt(sum(p.grad.data.norm()**2 for p in self.params)) 86 | 87 | def step(self, closure=None): 88 | """Performs a single optimization step.""" 89 | self.optimizer.step(closure) 90 | 91 | def zero_grad(self): 92 | """Clears the gradients of all optimized parameters.""" 93 | self.optimizer.zero_grad() 94 | -------------------------------------------------------------------------------- /encdec/fairseq/optim/lr_scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import importlib 9 | import os 10 | 11 | from .fairseq_lr_scheduler import FairseqLRScheduler 12 | 13 | 14 | LR_SCHEDULER_REGISTRY = {} 15 | 16 | 17 | def build_lr_scheduler(args, optimizer): 18 | return LR_SCHEDULER_REGISTRY[args.lr_scheduler](args, optimizer) 19 | 20 | 21 | def register_lr_scheduler(name): 22 | """Decorator to register a new LR scheduler.""" 23 | 24 | def register_lr_scheduler_cls(cls): 25 | if name in LR_SCHEDULER_REGISTRY: 26 | raise ValueError('Cannot register duplicate LR scheduler ({})'.format(name)) 27 | if not issubclass(cls, FairseqLRScheduler): 28 | raise ValueError('LR Scheduler ({}: {}) must extend FairseqLRScheduler'.format(name, cls.__name__)) 29 | LR_SCHEDULER_REGISTRY[name] = cls 30 | return cls 31 | 32 | return register_lr_scheduler_cls 33 | 34 | 35 | # automatically import any Python files in the optim/lr_scheduler/ directory 36 | for file in os.listdir(os.path.dirname(__file__)): 37 | if file.endswith('.py') and not file.startswith('_'): 38 | module = file[:file.find('.py')] 39 | importlib.import_module('fairseq.optim.lr_scheduler.' 
+ module) 40 | -------------------------------------------------------------------------------- /encdec/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import math 9 | 10 | from . import FairseqLRScheduler, register_lr_scheduler 11 | 12 | 13 | @register_lr_scheduler('cosine') 14 | class CosineSchedule(FairseqLRScheduler): 15 | """Assign LR based on a cyclical schedule that follows the cosine function. 16 | See https://arxiv.org/pdf/1608.03983.pdf for details 17 | We also support a warmup phase where we linearly increase the learning rate 18 | from some initial learning rate (`--warmup-init-lr`) until the configured 19 | learning rate (`--lr`). 20 | During warmup: 21 | lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates) 22 | lr = lrs[update_num] 23 | After warmup: 24 | lr = lr_min + 0.5*(lr_max - lr_min)*(1 + cos(t_curr / t_i)) 25 | where 26 | t_curr is current percentage of updates within the current period range 27 | t_i is the current period range, which is scaled by t_mul after every iteration 28 | """ 29 | 30 | def __init__(self, args, optimizer): 31 | super().__init__(args, optimizer) 32 | if len(args.lr) > 1: 33 | raise ValueError( 34 | 'Cannot use a fixed learning rate schedule with cosine.' 35 | ' Consider --lr-scheduler=fixed instead.' 36 | ) 37 | 38 | warmup_end_lr = args.max_lr 39 | if args.warmup_init_lr < 0: 40 | args.warmup_init_lr = args.lr[0] 41 | 42 | self.min_lr = args.lr[0] 43 | self.max_lr = args.max_lr 44 | 45 | assert self.max_lr > self.min_lr, 'max_lr must be more than lr' 46 | 47 | self.t_mult = args.t_mult 48 | self.period = args.lr_period_updates 49 | 50 | if args.warmup_updates > 0: 51 | # linearly warmup for the first args.warmup_updates 52 | self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates 53 | else: 54 | self.lr_step = 1 55 | 56 | self.warmup_updates = args.warmup_updates 57 | self.lr_shrink = args.lr_shrink 58 | 59 | # initial learning rate 60 | self.lr = args.warmup_init_lr 61 | self.optimizer.set_lr(self.lr) 62 | 63 | @staticmethod 64 | def add_args(parser): 65 | """Add arguments to the parser for this LR scheduler.""" 66 | parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', 67 | help='warmup the learning rate linearly for the first N updates') 68 | parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR', 69 | help='initial learning rate during warmup phase; default is args.lr') 70 | parser.add_argument('--max-lr', required=True, type=float, metavar='LR', 71 | help='max learning rate, must be more than args.lr') 72 | parser.add_argument('--t-mult', default=1, type=float, metavar='LR', 73 | help='factor to grow the length of each period') 74 | parser.add_argument('--lr-period-updates', default=5000, type=float, metavar='LR', 75 | help='initial number of updates per period') 76 | 77 | def step(self, epoch, val_loss=None): 78 | """Update the learning rate at the end of the given epoch.""" 79 | super().step(epoch, val_loss) 80 | # we don't change the learning rate at epoch boundaries 81 | return self.optimizer.get_lr() 82 | 83 | def step_update(self, num_updates): 84 | """Update 
the learning rate after each update.""" 85 | if num_updates < self.args.warmup_updates: 86 | self.lr = self.args.warmup_init_lr + num_updates * self.lr_step 87 | else: 88 | curr_updates = num_updates - self.args.warmup_updates 89 | if self.t_mult != 1: 90 | i = math.floor(math.log(1 - curr_updates / self.period * (1 - self.t_mult), self.t_mult)) 91 | t_i = self.t_mult ** i * self.period 92 | t_curr = curr_updates - (1 - self.t_mult ** i) / (1 - self.t_mult) * self.period 93 | else: 94 | i = math.floor(curr_updates / self.period) 95 | t_i = self.period 96 | t_curr = curr_updates - (self.period * i) 97 | 98 | lr_shrink = self.lr_shrink ** i 99 | min_lr = self.min_lr * lr_shrink 100 | max_lr = self.max_lr * lr_shrink 101 | 102 | self.lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_curr / t_i)) 103 | 104 | self.optimizer.set_lr(self.lr) 105 | return self.lr -------------------------------------------------------------------------------- /encdec/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from .. import FairseqOptimizer 9 | 10 | 11 | class FairseqLRScheduler(object): 12 | 13 | def __init__(self, args, optimizer): 14 | super().__init__() 15 | if not isinstance(optimizer, FairseqOptimizer): 16 | raise ValueError('optimizer must be an instance of FairseqOptimizer') 17 | self.args = args 18 | self.optimizer = optimizer 19 | self.best = None 20 | 21 | @staticmethod 22 | def add_args(parser): 23 | """Add arguments to the parser for this LR scheduler.""" 24 | pass 25 | 26 | def state_dict(self): 27 | """Return the LR scheduler state dict.""" 28 | return {'best': self.best} 29 | 30 | def load_state_dict(self, state_dict): 31 | """Load an LR scheduler state dict.""" 32 | self.best = state_dict['best'] 33 | 34 | def step(self, epoch, val_loss=None): 35 | """Update the learning rate at the end of the given epoch.""" 36 | if val_loss is not None: 37 | if self.best is None: 38 | self.best = val_loss 39 | else: 40 | self.best = min(self.best, val_loss) 41 | 42 | def step_update(self, num_updates): 43 | """Update the learning rate after each update.""" 44 | return self.optimizer.get_lr() 45 | -------------------------------------------------------------------------------- /encdec/fairseq/optim/lr_scheduler/fixed_schedule.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from . import FairseqLRScheduler, register_lr_scheduler 9 | 10 | 11 | @register_lr_scheduler('fixed') 12 | class FixedSchedule(FairseqLRScheduler): 13 | """Decay the LR on a fixed schedule.""" 14 | 15 | def __init__(self, args, optimizer): 16 | super().__init__(args, optimizer) 17 | 18 | # set defaults 19 | args.warmup_updates = getattr(args, 'warmup_updates', 0) or 0 20 | 21 | self.lr = args.lr[0] 22 | if args.warmup_updates > 0: 23 | self.warmup_factor = 1. 
/ args.warmup_updates 24 | else: 25 | self.warmup_factor = 1 26 | 27 | @staticmethod 28 | def add_args(parser): 29 | """Add arguments to the parser for this LR scheduler.""" 30 | parser.add_argument('--force-anneal', '--fa', type=int, metavar='N', 31 | help='force annealing at specified epoch') 32 | parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', 33 | help='warmup the learning rate linearly for the first N updates') 34 | 35 | def get_next_lr(self, epoch): 36 | lrs = self.args.lr 37 | if self.args.force_anneal is None or epoch < self.args.force_anneal: 38 | # use fixed LR schedule 39 | next_lr = lrs[min(epoch, len(lrs) - 1)] 40 | else: 41 | # annneal based on lr_shrink 42 | next_lr = lrs[-1] * self.args.lr_shrink ** (epoch + 1 - self.args.force_anneal) 43 | return next_lr 44 | 45 | def step(self, epoch, val_loss=None): 46 | """Update the learning rate at the end of the given epoch.""" 47 | super().step(epoch, val_loss) 48 | self.lr = self.get_next_lr(epoch) 49 | self.optimizer.set_lr(self.warmup_factor * self.lr) 50 | return self.optimizer.get_lr() 51 | 52 | def step_update(self, num_updates): 53 | """Update the learning rate after each update.""" 54 | if self.args.warmup_updates > 0 and num_updates <= self.args.warmup_updates: 55 | self.warmup_factor = num_updates / float(self.args.warmup_updates) 56 | self.optimizer.set_lr(self.warmup_factor * self.lr) 57 | return self.optimizer.get_lr() 58 | -------------------------------------------------------------------------------- /encdec/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from . import FairseqLRScheduler, register_lr_scheduler 9 | 10 | 11 | @register_lr_scheduler('inverse_sqrt') 12 | class InverseSquareRootSchedule(FairseqLRScheduler): 13 | """Decay the LR based on the inverse square root of the update number. 14 | 15 | We also support a warmup phase where we linearly increase the learning rate 16 | from some initial learning rate (`--warmup-init-lr`) until the configured 17 | learning rate (`--lr`). Thereafter we decay proportional to the number of 18 | updates, with a decay factor set to align with the configured learning rate. 19 | 20 | During warmup: 21 | 22 | lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates) 23 | lr = lrs[update_num] 24 | 25 | After warmup: 26 | 27 | lr = decay_factor / sqrt(update_num) 28 | 29 | where 30 | 31 | decay_factor = args.lr * sqrt(args.warmup_updates) 32 | """ 33 | 34 | def __init__(self, args, optimizer): 35 | super().__init__(args, optimizer) 36 | if len(args.lr) > 1: 37 | raise ValueError( 38 | 'Cannot use a fixed learning rate schedule with inverse_sqrt.' 39 | ' Consider --lr-scheduler=fixed instead.' 40 | ) 41 | warmup_end_lr = args.lr[0] 42 | if args.warmup_init_lr < 0: 43 | args.warmup_init_lr = warmup_end_lr 44 | 45 | # linearly warmup for the first args.warmup_updates 46 | self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates 47 | 48 | # then, decay prop. 
to the inverse square root of the update number 49 | self.decay_factor = warmup_end_lr * args.warmup_updates**0.5 50 | 51 | # initial learning rate 52 | self.lr = args.warmup_init_lr 53 | self.optimizer.set_lr(self.lr) 54 | 55 | @staticmethod 56 | def add_args(parser): 57 | """Add arguments to the parser for this LR scheduler.""" 58 | parser.add_argument('--warmup-updates', default=4000, type=int, metavar='N', 59 | help='warmup the learning rate linearly for the first N updates') 60 | parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR', 61 | help='initial learning rate during warmup phase; default is args.lr') 62 | 63 | def step(self, epoch, val_loss=None): 64 | """Update the learning rate at the end of the given epoch.""" 65 | super().step(epoch, val_loss) 66 | # we don't change the learning rate at epoch boundaries 67 | return self.optimizer.get_lr() 68 | 69 | def step_update(self, num_updates): 70 | """Update the learning rate after each update.""" 71 | if num_updates < self.args.warmup_updates: 72 | self.lr = self.args.warmup_init_lr + num_updates*self.lr_step 73 | else: 74 | self.lr = self.decay_factor * num_updates**-0.5 75 | self.optimizer.set_lr(self.lr) 76 | return self.lr 77 | -------------------------------------------------------------------------------- /encdec/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.optim.lr_scheduler 9 | 10 | from . import FairseqLRScheduler, register_lr_scheduler 11 | 12 | 13 | @register_lr_scheduler('reduce_lr_on_plateau') 14 | class ReduceLROnPlateau(FairseqLRScheduler): 15 | """Decay the LR by a factor every time the validation loss plateaus.""" 16 | 17 | def __init__(self, args, optimizer): 18 | super().__init__(args, optimizer) 19 | if len(args.lr) > 1: 20 | raise ValueError( 21 | 'Cannot use a fixed learning rate schedule with reduce_lr_on_plateau.' 22 | ' Consider --lr-scheduler=fixed instead.' 23 | ) 24 | self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 25 | self.optimizer.optimizer, patience=0, factor=args.lr_shrink) 26 | 27 | def state_dict(self): 28 | """Return the LR scheduler state dict.""" 29 | return { 30 | 'best': self.lr_scheduler.best, 31 | 'last_epoch': self.lr_scheduler.last_epoch, 32 | } 33 | 34 | def load_state_dict(self, state_dict): 35 | """Load an LR scheduler state dict.""" 36 | self.lr_scheduler.best = state_dict['best'] 37 | if 'last_epoch' in state_dict: 38 | self.lr_scheduler.last_epoch = state_dict['last_epoch'] 39 | 40 | def step(self, epoch, val_loss=None): 41 | """Update the learning rate at the end of the given epoch.""" 42 | if val_loss is not None: 43 | self.lr_scheduler.step(val_loss, epoch) 44 | else: 45 | self.lr_scheduler.last_epoch = epoch 46 | return self.optimizer.get_lr() 47 | -------------------------------------------------------------------------------- /encdec/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 
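# A tiny numeric sketch of the inverse_sqrt schedule above, using made-up
# settings (lr=5e-4, warmup_init_lr=1e-7, warmup_updates=4000): the rate climbs
# linearly during warmup and then decays with the inverse square root of the
# update number, so quadrupling the number of updates halves the learning rate.
warmup_init_lr, lr, warmup_updates = 1e-7, 5e-4, 4000
lr_step = (lr - warmup_init_lr) / warmup_updates
decay_factor = lr * warmup_updates ** 0.5

def lr_at(num_updates):
    if num_updates < warmup_updates:
        return warmup_init_lr + num_updates * lr_step   # linear warmup
    return decay_factor * num_updates ** -0.5           # inverse square-root decay

print(lr_at(2000))    # roughly 2.5e-4, half-way through warmup
print(lr_at(4000))    # 5e-4 at the end of warmup
print(lr_at(16000))   # 2.5e-4, i.e. halved after 4x more updates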
3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import math 9 | 10 | from . import FairseqLRScheduler, register_lr_scheduler 11 | 12 | 13 | @register_lr_scheduler('triangular') 14 | class TriangularSchedule(FairseqLRScheduler): 15 | """Assign LR based on a triangular cyclical schedule. 16 | 17 | See https://arxiv.org/pdf/1506.01186.pdf for details 18 | 19 | """ 20 | 21 | def __init__(self, args, optimizer): 22 | super().__init__(args, optimizer) 23 | if len(args.lr) > 1: 24 | raise ValueError( 25 | 'Cannot use a fixed learning rate schedule with triangular.' 26 | ' Consider --lr-scheduler=fixed instead.' 27 | ) 28 | 29 | lr = args.lr[0] 30 | 31 | assert args.max_lr > lr, 'max_lr must be more than lr' 32 | self.min_lr = lr 33 | self.max_lr = args.max_lr 34 | self.stepsize = args.lr_period_updates // 2 35 | self.lr_shrink = args.lr_shrink 36 | self.shrink_min = args.shrink_min 37 | 38 | # initial learning rate 39 | self.lr = self.min_lr 40 | self.optimizer.set_lr(self.lr) 41 | 42 | @staticmethod 43 | def add_args(parser): 44 | """Add arguments to the parser for this LR scheduler.""" 45 | parser.add_argument('--max-lr', required=True, type=float, metavar='LR', 46 | help='max learning rate, must be more than args.lr') 47 | parser.add_argument('--lr-period-updates', default=5000, type=float, metavar='LR', 48 | help='initial number of updates per period (cycle length)') 49 | parser.add_argument('--shrink-min', action='store_true', 50 | help='if set, also shrinks min lr') 51 | 52 | def step(self, epoch, val_loss=None): 53 | """Update the learning rate at the end of the given epoch.""" 54 | super().step(epoch, val_loss) 55 | # we don't change the learning rate at epoch boundaries 56 | return self.optimizer.get_lr() 57 | 58 | def step_update(self, num_updates): 59 | """Update the learning rate after each update.""" 60 | cycle = math.floor(num_updates / (2 * self.stepsize)) 61 | 62 | lr_shrink = self.lr_shrink ** cycle 63 | max_lr = self.max_lr * lr_shrink 64 | if self.shrink_min: 65 | min_lr = self.min_lr * lr_shrink 66 | else: 67 | min_lr = self.min_lr 68 | 69 | x = abs(num_updates / self.stepsize - 2 * (cycle + 1) + 1) 70 | self.lr = min_lr + (max_lr - min_lr) * max(0, (1 - x)) 71 | 72 | self.optimizer.set_lr(self.lr) 73 | return self.lr 74 | -------------------------------------------------------------------------------- /encdec/fairseq/optim/nag.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from torch.optim.optimizer import Optimizer, required 9 | 10 | from . import FairseqOptimizer, register_optimizer 11 | 12 | 13 | @register_optimizer('nag') 14 | class FairseqNAG(FairseqOptimizer): 15 | def __init__(self, args, params): 16 | super().__init__(args, params) 17 | self._optimizer = NAG(params, **self.optimizer_config) 18 | 19 | @property 20 | def optimizer_config(self): 21 | """ 22 | Return a kwarg dictionary that will be used to override optimizer 23 | args stored in checkpoints. 
This allows us to load a checkpoint and 24 | resume training using a different set of optimizer args, e.g., with a 25 | different learning rate. 26 | """ 27 | return { 28 | 'lr': self.args.lr[0], 29 | 'momentum': self.args.momentum, 30 | 'weight_decay': self.args.weight_decay, 31 | } 32 | 33 | 34 | class NAG(Optimizer): 35 | def __init__(self, params, lr=required, momentum=0, weight_decay=0): 36 | defaults = dict(lr=lr, lr_old=lr, momentum=momentum, weight_decay=weight_decay) 37 | super(NAG, self).__init__(params, defaults) 38 | 39 | def step(self, closure=None): 40 | """Performs a single optimization step. 41 | 42 | Arguments: 43 | closure (callable, optional): A closure that reevaluates the model 44 | and returns the loss. 45 | """ 46 | loss = None 47 | if closure is not None: 48 | loss = closure() 49 | 50 | for group in self.param_groups: 51 | weight_decay = group['weight_decay'] 52 | momentum = group['momentum'] 53 | lr = group['lr'] 54 | lr_old = group.get('lr_old', lr) 55 | lr_correct = lr / lr_old 56 | 57 | for p in group['params']: 58 | if p.grad is None: 59 | continue 60 | 61 | d_p = p.grad.data 62 | param_state = self.state[p] 63 | if 'momentum_buffer' not in param_state: 64 | param_state['momentum_buffer'] = d_p.clone().zero_() 65 | 66 | buf = param_state['momentum_buffer'] 67 | 68 | if weight_decay != 0: 69 | p.data.mul_(1 - lr * weight_decay) 70 | p.data.add_(momentum * momentum * lr_correct, buf) 71 | p.data.add_(-(1 + momentum) * lr, d_p) 72 | 73 | buf.mul_(momentum * lr_correct).add_(-lr, d_p) 74 | 75 | group['lr_old'] = lr 76 | 77 | return loss 78 | -------------------------------------------------------------------------------- /encdec/fairseq/optim/sgd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.optim 9 | 10 | from . import FairseqOptimizer, register_optimizer 11 | 12 | 13 | @register_optimizer('sgd') 14 | class SGD(FairseqOptimizer): 15 | def __init__(self, args, params): 16 | super().__init__(args, params) 17 | self._optimizer = torch.optim.SGD(params, **self.optimizer_config) 18 | 19 | @property 20 | def optimizer_config(self): 21 | """ 22 | Return a kwarg dictionary that will be used to override optimizer 23 | args stored in checkpoints. This allows us to load a checkpoint and 24 | resume training using a different set of optimizer args, e.g., with a 25 | different learning rate. 26 | """ 27 | return { 28 | 'lr': self.args.lr[0], 29 | 'momentum': self.args.momentum, 30 | 'weight_decay': self.args.weight_decay, 31 | } 32 | -------------------------------------------------------------------------------- /encdec/fairseq/sequence_scorer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
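# A sketch of building one of the registered optimizers above from an argparse
# namespace, roughly the way the trainer does; the model and hyper-parameter
# values are arbitrary assumptions for illustration.
import argparse
import torch

from fairseq.optim import build_optimizer

model = torch.nn.Linear(4, 4)
args = argparse.Namespace(optimizer='nag', lr=[0.25], momentum=0.99, weight_decay=0.0)
optimizer = build_optimizer(args, model.parameters())

loss = model(torch.randn(2, 4)).sum()
optimizer.backward(loss)        # FairseqOptimizer simply calls loss.backward()
optimizer.clip_grad_norm(0.1)
optimizer.step()
optimizer.zero_grad()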
7 | 8 | import torch 9 | 10 | from fairseq import utils 11 | 12 | 13 | class SequenceScorer(object): 14 | """Scores the target for a given source sentence.""" 15 | 16 | def __init__(self, models, tgt_dict): 17 | self.models = models 18 | self.pad = tgt_dict.pad() 19 | 20 | def cuda(self): 21 | for model in self.models: 22 | model.cuda() 23 | return self 24 | 25 | def score_batched_itr(self, data_itr, cuda=False, timer=None): 26 | """Iterate over a batched dataset and yield scored translations.""" 27 | for sample in data_itr: 28 | s = utils.move_to_cuda(sample) if cuda else sample 29 | if timer is not None: 30 | timer.start() 31 | pos_scores, attn = self.score(s) 32 | for i, id in enumerate(s['id'].data): 33 | # remove padding from ref 34 | src = utils.strip_pad(s['net_input']['src_tokens'].data[i, :], self.pad) 35 | ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None 36 | tgt_len = ref.numel() 37 | pos_scores_i = pos_scores[i][:tgt_len] 38 | score_i = pos_scores_i.sum() / tgt_len 39 | if attn is not None: 40 | attn_i = attn[i] 41 | _, alignment = attn_i.max(dim=0) 42 | else: 43 | attn_i = alignment = None 44 | hypos = [{ 45 | 'tokens': ref, 46 | 'score': score_i, 47 | 'attention': attn_i, 48 | 'alignment': alignment, 49 | 'positional_scores': pos_scores_i, 50 | }] 51 | if timer is not None: 52 | timer.stop(s['ntokens']) 53 | # return results in the same format as SequenceGenerator 54 | yield id, src, ref, hypos 55 | 56 | def score(self, sample): 57 | """Score a batch of translations.""" 58 | net_input = sample['net_input'] 59 | 60 | # compute scores for each model in the ensemble 61 | avg_probs = None 62 | avg_attn = None 63 | for model in self.models: 64 | with torch.no_grad(): 65 | model.eval() 66 | decoder_out = model.forward(**net_input) 67 | attn = decoder_out[1] 68 | 69 | probs = model.get_normalized_probs(decoder_out, log_probs=False, sample=sample).data 70 | if avg_probs is None: 71 | avg_probs = probs 72 | else: 73 | avg_probs.add_(probs) 74 | if attn is not None and torch.is_tensor(attn): 75 | attn = attn.data 76 | if avg_attn is None: 77 | avg_attn = attn 78 | else: 79 | avg_attn.add_(attn) 80 | avg_probs.div_(len(self.models)) 81 | avg_probs.log_() 82 | if avg_attn is not None: 83 | avg_attn.div_(len(self.models)) 84 | avg_probs = avg_probs.gather( 85 | dim=2, 86 | index=sample['target'].data.unsqueeze(-1), 87 | ) 88 | return avg_probs.squeeze(2), avg_attn 89 | -------------------------------------------------------------------------------- /encdec/fairseq/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in # the root directory of this source tree. An additional grant of patent rights 5 | # can be found in the PATENTS file in the same directory. 6 | 7 | import argparse 8 | import importlib 9 | import os 10 | 11 | from .fairseq_task import FairseqTask 12 | 13 | 14 | TASK_REGISTRY = {} 15 | TASK_CLASS_NAMES = set() 16 | 17 | 18 | def setup_task(args): 19 | return TASK_REGISTRY[args.task].setup_task(args) 20 | 21 | 22 | def register_task(name): 23 | """ 24 | New tasks can be added to fairseq with the 25 | :func:`~fairseq.tasks.register_task` function decorator. 26 | 27 | For example:: 28 | 29 | @register_task('classification') 30 | class ClassificationTask(FairseqTask): 31 | (...) 32 | 33 | .. 
note:: 34 | 35 | All Tasks must implement the :class:`~fairseq.tasks.FairseqTask` 36 | interface. 37 | 38 | Please see the 39 | 40 | Args: 41 | name (str): the name of the task 42 | """ 43 | 44 | def register_task_cls(cls): 45 | if name in TASK_REGISTRY: 46 | raise ValueError('Cannot register duplicate task ({})'.format(name)) 47 | if not issubclass(cls, FairseqTask): 48 | raise ValueError('Task ({}: {}) must extend FairseqTask'.format(name, cls.__name__)) 49 | if cls.__name__ in TASK_CLASS_NAMES: 50 | raise ValueError('Cannot register task with duplicate class name ({})'.format(cls.__name__)) 51 | TASK_REGISTRY[name] = cls 52 | TASK_CLASS_NAMES.add(cls.__name__) 53 | return cls 54 | 55 | return register_task_cls 56 | 57 | 58 | # automatically import any Python files in the tasks/ directory 59 | for file in os.listdir(os.path.dirname(__file__)): 60 | if file.endswith('.py') and not file.startswith('_'): 61 | task_name = file[:file.find('.py')] 62 | importlib.import_module('fairseq.tasks.' + task_name) 63 | 64 | # expose `task_parser` for sphinx 65 | if task_name in TASK_REGISTRY: 66 | parser = argparse.ArgumentParser(add_help=False) 67 | group_task = parser.add_argument_group('Task name') 68 | group_task.add_argument( 69 | '--task', metavar=task_name, 70 | help='Enable this task with: ``--task=' + task_name + '``' 71 | ) 72 | group_args = parser.add_argument_group('Additional command-line arguments') 73 | TASK_REGISTRY[task_name].add_args(group_args) 74 | globals()[task_name + '_parser'] = parser 75 | -------------------------------------------------------------------------------- /encdec/multiprocessing_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 8 | 9 | import os 10 | import random 11 | import signal 12 | import torch 13 | 14 | from fairseq import distributed_utils, options 15 | 16 | from train import main as single_process_main 17 | 18 | 19 | def main(args): 20 | # Set distributed training parameters for a single node. 21 | args.distributed_world_size = torch.cuda.device_count() 22 | port = random.randint(10000, 20000) 23 | args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port) 24 | args.distributed_init_host = 'localhost' 25 | args.distributed_port = port + 1 26 | 27 | mp = torch.multiprocessing.get_context('spawn') 28 | 29 | # Create a thread to listen for errors in the child processes. 30 | error_queue = mp.SimpleQueue() 31 | error_handler = ErrorHandler(error_queue) 32 | 33 | # Train with multiprocessing. 
34 | procs = [] 35 | for i in range(args.distributed_world_size): 36 | args.distributed_rank = i 37 | args.device_id = i 38 | procs.append(mp.Process(target=run, args=(args, error_queue, ), daemon=True)) 39 | procs[i].start() 40 | error_handler.add_child(procs[i].pid) 41 | for p in procs: 42 | p.join() 43 | 44 | 45 | def run(args, error_queue): 46 | try: 47 | args.distributed_rank = distributed_utils.distributed_init(args) 48 | single_process_main(args) 49 | except KeyboardInterrupt: 50 | pass # killed by parent, do nothing 51 | except Exception: 52 | # propagate exception to parent process, keeping original traceback 53 | import traceback 54 | error_queue.put((args.distributed_rank, traceback.format_exc())) 55 | 56 | 57 | class ErrorHandler(object): 58 | """A class that listens for exceptions in children processes and propagates 59 | the tracebacks to the parent process.""" 60 | 61 | def __init__(self, error_queue): 62 | import signal 63 | import threading 64 | self.error_queue = error_queue 65 | self.children_pids = [] 66 | self.error_thread = threading.Thread(target=self.error_listener, daemon=True) 67 | self.error_thread.start() 68 | signal.signal(signal.SIGUSR1, self.signal_handler) 69 | 70 | def add_child(self, pid): 71 | self.children_pids.append(pid) 72 | 73 | def error_listener(self): 74 | (rank, original_trace) = self.error_queue.get() 75 | self.error_queue.put((rank, original_trace)) 76 | os.kill(os.getpid(), signal.SIGUSR1) 77 | 78 | def signal_handler(self, signalnum, stackframe): 79 | for pid in self.children_pids: 80 | os.kill(pid, signal.SIGINT) # kill children processes 81 | (rank, original_trace) = self.error_queue.get() 82 | msg = "\n\n-- Tracebacks above this line can probably be ignored --\n\n" 83 | msg += original_trace 84 | raise Exception(msg) 85 | 86 | 87 | if __name__ == '__main__': 88 | parser = options.get_training_parser() 89 | args = options.parse_args_and_arch(parser) 90 | main(args) 91 | -------------------------------------------------------------------------------- /encdec/requirements.txt: -------------------------------------------------------------------------------- 1 | cffi 2 | numpy 3 | torch 4 | tqdm 5 | -------------------------------------------------------------------------------- /encdec/rerank.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | #extract a sentence which contains the most source tokens 4 | 5 | import sys 6 | import re 7 | import collections 8 | import argparse 9 | from nltk.stem.porter import PorterStemmer 10 | 11 | alphabet = re.compile('[a-z]') 12 | 13 | 14 | def extract_source_word(filename, stemmer=None): 15 | d = {} 16 | for i, line in enumerate(open(filename)): 17 | line = line.strip() 18 | words = line.split() 19 | if stemmer is None: 20 | wordset = set([w for w in words if alphabet.search(w)]) 21 | else: 22 | wordset = set([stemmer.stem(w) for w in words if alphabet.search(w)]) 23 | d[i] = wordset 24 | return d 25 | 26 | 27 | def extract_output(filename, space_symbol): 28 | d = collections.defaultdict(list) 29 | for line in open(filename): 30 | line = line.strip() 31 | if not line.startswith('H-'): 32 | continue 33 | index, prob, cand = line.split('\t') 34 | index = int(index.replace('H-', '')) 35 | d[index].append(cand.replace(' ', '').replace(space_symbol, ' ')) 36 | return d 37 | 38 | 39 | def main(args): 40 | if args.m: 41 | stemmer = PorterStemmer() 42 | else: 43 | stemmer = None 44 | source_word_dict = extract_source_word(args.source, stemmer) 45 
| output_dict = extract_output(args.cand, args.space_symbol) 46 | for index in range(len(output_dict)): 47 | wordset = source_word_dict[index] 48 | out = None 49 | for cand in output_dict[index]: 50 | if stemmer is None: 51 | word = cand.split() 52 | else: 53 | word = [stemmer.stem(w) for w in cand.split()] 54 | num = sum([1 for w in word if w in wordset]) 55 | if out is None: 56 | out = cand 57 | maxnum = num 58 | elif num > maxnum: 59 | out = cand 60 | maxnum = num 61 | print(out) 62 | 63 | 64 | if __name__ == "__main__": 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument('--cand', required=True, 67 | help='specify the candidate file') 68 | parser.add_argument('--source', required=True, 69 | help='specify the source file') 70 | parser.add_argument('-m', default=False, action='store_true', 71 | help='stemming (by porter stemmer) or not') 72 | parser.add_argument('--space', dest='space_symbol', default='@@@@', 73 | help='symbol to represent a space in the candidate file') 74 | args = parser.parse_args() 75 | main(args) 76 | -------------------------------------------------------------------------------- /encdec/score.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 8 | """ 9 | BLEU scoring of generated translations against reference translations. 10 | """ 11 | 12 | import argparse 13 | import os 14 | import sys 15 | 16 | from fairseq import bleu, tokenizer 17 | from fairseq.data import dictionary 18 | 19 | 20 | def get_parser(): 21 | parser = argparse.ArgumentParser(description='Command-line script for BLEU scoring.') 22 | parser.add_argument('-s', '--sys', default='-', help='system output') 23 | parser.add_argument('-r', '--ref', required=True, help='references') 24 | parser.add_argument('-o', '--order', default=4, metavar='N', 25 | type=int, help='consider ngrams up to this order') 26 | parser.add_argument('--ignore-case', action='store_true', 27 | help='case-insensitive scoring') 28 | return parser 29 | 30 | 31 | def main(): 32 | parser = get_parser() 33 | args = parser.parse_args() 34 | print(args) 35 | 36 | assert args.sys == '-' or os.path.exists(args.sys), \ 37 | "System output file {} does not exist".format(args.sys) 38 | assert os.path.exists(args.ref), \ 39 | "Reference file {} does not exist".format(args.ref) 40 | 41 | dict = dictionary.Dictionary() 42 | 43 | def readlines(fd): 44 | for line in fd.readlines(): 45 | if args.ignore_case: 46 | yield line.lower() 47 | yield line 48 | 49 | def score(fdsys): 50 | with open(args.ref) as fdref: 51 | scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk()) 52 | for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)): 53 | sys_tok = tokenizer.Tokenizer.tokenize(sys_tok, dict) 54 | ref_tok = tokenizer.Tokenizer.tokenize(ref_tok, dict) 55 | scorer.add(ref_tok, sys_tok) 56 | print(scorer.result_string(args.order)) 57 | 58 | if args.sys == '-': 59 | score(sys.stdin) 60 | else: 61 | with open(args.sys, 'r') as f: 62 | score(f) 63 | 64 | 65 | if __name__ == '__main__': 66 | main() 67 | -------------------------------------------------------------------------------- /encdec/scripts/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/takase/control-length/c9eb632e112f07a156a5892065969cda788253a0/encdec/scripts/__init__.py -------------------------------------------------------------------------------- /encdec/scripts/average_checkpoints.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import collections 5 | import torch 6 | import os 7 | import re 8 | 9 | 10 | def average_checkpoints(inputs): 11 | """Loads checkpoints from inputs and returns a model with averaged weights. 12 | 13 | Args: 14 | inputs: An iterable of string paths of checkpoints to load from. 15 | 16 | Returns: 17 | A dict of string keys mapping to various values. The 'model' key 18 | from the returned dict should correspond to an OrderedDict mapping 19 | string parameter names to torch Tensors. 20 | """ 21 | params_dict = collections.OrderedDict() 22 | params_keys = None 23 | new_state = None 24 | for f in inputs: 25 | state = torch.load( 26 | f, 27 | map_location=( 28 | lambda s, _: torch.serialization.default_restore_location(s, 'cpu') 29 | ), 30 | ) 31 | # Copies over the settings from the first checkpoint 32 | if new_state is None: 33 | new_state = state 34 | 35 | model_params = state['model'] 36 | 37 | model_params_keys = list(model_params.keys()) 38 | if params_keys is None: 39 | params_keys = model_params_keys 40 | elif params_keys != model_params_keys: 41 | raise KeyError( 42 | 'For checkpoint {}, expected list of params: {}, ' 43 | 'but found: {}'.format(f, params_keys, model_params_keys) 44 | ) 45 | 46 | for k in params_keys: 47 | if k not in params_dict: 48 | params_dict[k] = [] 49 | p = model_params[k] 50 | if isinstance(p, torch.HalfTensor): 51 | p = p.float() 52 | params_dict[k].append(p) 53 | 54 | averaged_params = collections.OrderedDict() 55 | # v should be a list of torch Tensor. 
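# A tiny numeric sketch of the element-wise parameter averaging performed just
# below: each parameter in the output checkpoint is the mean of that parameter
# across the input checkpoints. The values are arbitrary.
import torch

v = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]   # the same parameter from two checkpoints
summed_v = None
for x in v:
    summed_v = summed_v + x if summed_v is not None else x
print(summed_v / len(v))   # tensor([2., 3.])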
56 | for k, v in params_dict.items(): 57 | summed_v = None 58 | for x in v: 59 | summed_v = summed_v + x if summed_v is not None else x 60 | averaged_params[k] = summed_v / len(v) 61 | new_state['model'] = averaged_params 62 | return new_state 63 | 64 | 65 | def last_n_checkpoints(paths, n, update_based): 66 | assert len(paths) == 1 67 | path = paths[0] 68 | if update_based: 69 | pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt') 70 | else: 71 | pt_regexp = re.compile(r'checkpoint(\d+)\.pt') 72 | files = os.listdir(path) 73 | 74 | entries = [] 75 | for f in files: 76 | m = pt_regexp.fullmatch(f) 77 | if m is not None: 78 | entries.append((int(m.group(1)), m.group(0))) 79 | if len(entries) < n: 80 | raise Exception('Found {} checkpoint files but need at least {}', len(entries), n) 81 | return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]] 82 | 83 | 84 | def main(): 85 | parser = argparse.ArgumentParser( 86 | description='Tool to average the params of input checkpoints to ' 87 | 'produce a new checkpoint', 88 | ) 89 | 90 | parser.add_argument( 91 | '--inputs', 92 | required=True, 93 | nargs='+', 94 | help='Input checkpoint file paths.', 95 | ) 96 | parser.add_argument( 97 | '--output', 98 | required=True, 99 | metavar='FILE', 100 | help='Write the new checkpoint containing the averaged weights to this ' 101 | 'path.', 102 | ) 103 | num_group = parser.add_mutually_exclusive_group() 104 | num_group.add_argument( 105 | '--num-epoch-checkpoints', 106 | type=int, 107 | help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, ' 108 | 'and average last this many of them.', 109 | ) 110 | num_group.add_argument( 111 | '--num-update-checkpoints', 112 | type=int, 113 | help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by input, ' 114 | 'and average last this many of them.', 115 | ) 116 | args = parser.parse_args() 117 | print(args) 118 | 119 | num = None 120 | is_update_based = False 121 | if args.num_update_checkpoints is not None: 122 | num = args.num_update_checkpoints 123 | is_update_based = True 124 | elif args.num_epoch_checkpoints is not None: 125 | num = args.num_epoch_checkpoints 126 | 127 | if num is not None: 128 | args.inputs = last_n_checkpoints(args.inputs, num, is_update_based) 129 | print('averaging checkpoints: ', args.inputs) 130 | 131 | new_state = average_checkpoints(args.inputs) 132 | torch.save(new_state, args.output) 133 | print('Finished writing averaged checkpoint to {}.'.format(args.output)) 134 | 135 | 136 | if __name__ == '__main__': 137 | main() 138 | -------------------------------------------------------------------------------- /encdec/scripts/build_sym_alignment.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | # 8 | 9 | """ 10 | Use this script in order to build symmetric alignments for your translation 11 | dataset. 12 | This script depends on fast_align and mosesdecoder tools. You will need to 13 | build those before running the script. 
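# A sketch of calling the helper from scripts/average_checkpoints.py above
# directly from Python, assuming the working directory is the encdec/ root so
# that the scripts package is importable; the checkpoint paths are hypothetical.
import torch
from scripts.average_checkpoints import average_checkpoints

new_state = average_checkpoints(['checkpoints/checkpoint98.pt',
                                 'checkpoints/checkpoint99.pt',
                                 'checkpoints/checkpoint100.pt'])
torch.save(new_state, 'checkpoints/checkpoint.avg3.pt')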
14 | fast_align: 15 | github: http://github.com/clab/fast_align 16 | instructions: follow the instructions in README.md 17 | mosesdecoder: 18 | github: http://github.com/moses-smt/mosesdecoder 19 | instructions: http://www.statmt.org/moses/?n=Development.GetStarted 20 | The script produces the following files under --output_dir: 21 | text.joined - concatenation of lines from the source_file and the 22 | target_file. 23 | align.forward - forward pass of fast_align. 24 | align.backward - backward pass of fast_align. 25 | aligned.sym_heuristic - symmetrized alignment. 26 | """ 27 | 28 | import argparse 29 | import os 30 | from itertools import zip_longest 31 | 32 | 33 | def main(): 34 | parser = argparse.ArgumentParser(description='symmetric alignment builer') 35 | parser.add_argument('--fast_align_dir', 36 | help='path to fast_align build directory') 37 | parser.add_argument('--mosesdecoder_dir', 38 | help='path to mosesdecoder root directory') 39 | parser.add_argument('--sym_heuristic', 40 | help='heuristic to use for symmetrization', 41 | default='grow-diag-final-and') 42 | parser.add_argument('--source_file', 43 | help='path to a file with sentences ' 44 | 'in the source language') 45 | parser.add_argument('--target_file', 46 | help='path to a file with sentences ' 47 | 'in the target language') 48 | parser.add_argument('--output_dir', 49 | help='output directory') 50 | args = parser.parse_args() 51 | 52 | fast_align_bin = os.path.join(args.fast_align_dir, 'fast_align') 53 | symal_bin = os.path.join(args.mosesdecoder_dir, 'bin', 'symal') 54 | sym_fast_align_bin = os.path.join( 55 | args.mosesdecoder_dir, 'scripts', 'ems', 56 | 'support', 'symmetrize-fast-align.perl') 57 | 58 | # create joined file 59 | joined_file = os.path.join(args.output_dir, 'text.joined') 60 | with open(args.source_file, 'r') as src, open(args.target_file, 'r') as tgt: 61 | with open(joined_file, 'w') as joined: 62 | for s, t in zip_longest(src, tgt): 63 | print('{} ||| {}'.format(s.strip(), t.strip()), file=joined) 64 | 65 | bwd_align_file = os.path.join(args.output_dir, 'align.backward') 66 | 67 | # run forward alignment 68 | fwd_align_file = os.path.join(args.output_dir, 'align.forward') 69 | fwd_fast_align_cmd = '{FASTALIGN} -i {JOINED} -d -o -v > {FWD}'.format( 70 | FASTALIGN=fast_align_bin, 71 | JOINED=joined_file, 72 | FWD=fwd_align_file) 73 | assert os.system(fwd_fast_align_cmd) == 0 74 | 75 | # run backward alignment 76 | bwd_align_file = os.path.join(args.output_dir, 'align.backward') 77 | bwd_fast_align_cmd = '{FASTALIGN} -i {JOINED} -d -o -v -r > {BWD}'.format( 78 | FASTALIGN=fast_align_bin, 79 | JOINED=joined_file, 80 | BWD=bwd_align_file) 81 | assert os.system(bwd_fast_align_cmd) == 0 82 | 83 | # run symmetrization 84 | sym_out_file = os.path.join(args.output_dir, 'aligned') 85 | sym_cmd = '{SYMFASTALIGN} {FWD} {BWD} {SRC} {TGT} {OUT} {HEURISTIC} {SYMAL}'.format( 86 | SYMFASTALIGN=sym_fast_align_bin, 87 | FWD=fwd_align_file, 88 | BWD=bwd_align_file, 89 | SRC=args.source_file, 90 | TGT=args.target_file, 91 | OUT=sym_out_file, 92 | HEURISTIC=args.sym_heuristic, 93 | SYMAL=symal_bin 94 | ) 95 | assert os.system(sym_cmd) == 0 96 | 97 | 98 | if __name__ == '__main__': 99 | main() 100 | -------------------------------------------------------------------------------- /encdec/scripts/convert_dictionary.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 
3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | -- Usage: convert_dictionary.lua <dict.th7> 9 | require 'fairseq' 10 | require 'torch' 11 | require 'paths' 12 | 13 | if #arg < 1 then 14 | print('usage: convert_dictionary.lua <dict.th7>') 15 | os.exit(1) 16 | end 17 | if not paths.filep(arg[1]) then 18 | print('error: file does not exit: ' .. arg[1]) 19 | os.exit(1) 20 | end 21 | 22 | dict = torch.load(arg[1]) 23 | dst = paths.basename(arg[1]):gsub('.th7', '.txt') 24 | assert(dst:match('.txt$')) 25 | 26 | f = io.open(dst, 'w') 27 | for idx, symbol in ipairs(dict.index_to_symbol) do 28 | if idx > dict.cutoff then 29 | break 30 | end 31 | f:write(symbol) 32 | f:write(' ') 33 | f:write(dict.index_to_freq[idx]) 34 | f:write('\n') 35 | end 36 | f:close() 37 | -------------------------------------------------------------------------------- /encdec/scripts/convert_model.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) 2017-present, Facebook, Inc. 2 | -- All rights reserved. 3 | -- 4 | -- This source code is licensed under the license found in the LICENSE file in 5 | -- the root directory of this source tree. An additional grant of patent rights 6 | -- can be found in the PATENTS file in the same directory. 7 | -- 8 | -- Usage: convert_model.lua <model_epoch1.th7> 9 | require 'torch' 10 | local fairseq = require 'fairseq' 11 | 12 | model = torch.load(arg[1]) 13 | 14 | function find_weight_norm(container, module) 15 | for _, wn in ipairs(container:listModules()) do 16 | if torch.type(wn) == 'nn.WeightNorm' and wn.modules[1] == module then 17 | return wn 18 | end 19 | end 20 | end 21 | 22 | function push_state(dict, key, module) 23 | if torch.type(module) == 'nn.Linear' then 24 | local wn = find_weight_norm(model.module, module) 25 | assert(wn) 26 | dict[key .. '.weight_v'] = wn.v:float() 27 | dict[key .. '.weight_g'] = wn.g:float() 28 | elseif torch.type(module) == 'nn.TemporalConvolutionTBC' then 29 | local wn = find_weight_norm(model.module, module) 30 | assert(wn) 31 | local v = wn.v:float():view(wn.viewOut):transpose(2, 3) 32 | dict[key .. '.weight_v'] = v 33 | dict[key .. '.weight_g'] = wn.g:float():view(module.weight:size(3), 1, 1) 34 | else 35 | dict[key .. '.weight'] = module.weight:float() 36 | end 37 | if module.bias then 38 | dict[key .. '.bias'] = module.bias:float() 39 | end 40 | end 41 | 42 | encoder_dict = {} 43 | decoder_dict = {} 44 | combined_dict = {} 45 | 46 | function encoder_state(encoder) 47 | luts = encoder:findModules('nn.LookupTable') 48 | push_state(encoder_dict, 'embed_tokens', luts[1]) 49 | push_state(encoder_dict, 'embed_positions', luts[2]) 50 | 51 | fcs = encoder:findModules('nn.Linear') 52 | assert(#fcs >= 2) 53 | local nInputPlane = fcs[1].weight:size(1) 54 | push_state(encoder_dict, 'fc1', table.remove(fcs, 1)) 55 | push_state(encoder_dict, 'fc2', table.remove(fcs, #fcs)) 56 | 57 | for i, module in ipairs(encoder:findModules('nn.TemporalConvolutionTBC')) do 58 | push_state(encoder_dict, 'convolutions.' .. tostring(i - 1), module) 59 | if nInputPlane ~= module.weight:size(3) / 2 then 60 | push_state(encoder_dict, 'projections.' .. 
tostring(i - 1), table.remove(fcs, 1)) 61 | end 62 | nInputPlane = module.weight:size(3) / 2 63 | end 64 | assert(#fcs == 0) 65 | end 66 | 67 | function decoder_state(decoder) 68 | luts = decoder:findModules('nn.LookupTable') 69 | push_state(decoder_dict, 'embed_tokens', luts[1]) 70 | push_state(decoder_dict, 'embed_positions', luts[2]) 71 | 72 | fcs = decoder:findModules('nn.Linear') 73 | local nInputPlane = fcs[1].weight:size(1) 74 | push_state(decoder_dict, 'fc1', table.remove(fcs, 1)) 75 | push_state(decoder_dict, 'fc2', fcs[#fcs - 1]) 76 | push_state(decoder_dict, 'fc3', fcs[#fcs]) 77 | 78 | table.remove(fcs, #fcs) 79 | table.remove(fcs, #fcs) 80 | 81 | for i, module in ipairs(decoder:findModules('nn.TemporalConvolutionTBC')) do 82 | if nInputPlane ~= module.weight:size(3) / 2 then 83 | push_state(decoder_dict, 'projections.' .. tostring(i - 1), table.remove(fcs, 1)) 84 | end 85 | nInputPlane = module.weight:size(3) / 2 86 | 87 | local prefix = 'attention.' .. tostring(i - 1) 88 | push_state(decoder_dict, prefix .. '.in_projection', table.remove(fcs, 1)) 89 | push_state(decoder_dict, prefix .. '.out_projection', table.remove(fcs, 1)) 90 | push_state(decoder_dict, 'convolutions.' .. tostring(i - 1), module) 91 | end 92 | assert(#fcs == 0) 93 | end 94 | 95 | 96 | _encoder = model.module.modules[2] 97 | _decoder = model.module.modules[3] 98 | 99 | encoder_state(_encoder) 100 | decoder_state(_decoder) 101 | 102 | for k, v in pairs(encoder_dict) do 103 | combined_dict['encoder.' .. k] = v 104 | end 105 | for k, v in pairs(decoder_dict) do 106 | combined_dict['decoder.' .. k] = v 107 | end 108 | 109 | 110 | torch.save('state_dict.t7', combined_dict) 111 | -------------------------------------------------------------------------------- /encdec/scripts/read_binarized.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 8 | # 9 | 10 | import argparse 11 | 12 | from fairseq.data import dictionary 13 | from fairseq.data import IndexedDataset 14 | 15 | 16 | def get_parser(): 17 | parser = argparse.ArgumentParser( 18 | description='writes text from binarized file to stdout') 19 | parser.add_argument('--dict', metavar='FP', required=True, help='dictionary containing known words') 20 | parser.add_argument('--input', metavar='FP', required=True, help='binarized file to read') 21 | 22 | return parser 23 | 24 | 25 | def main(args): 26 | dict = dictionary.Dictionary.load(args.dict) 27 | ds = IndexedDataset(args.input, fix_lua_indexing=True) 28 | for tensor_line in ds: 29 | print(dict.string(tensor_line)) 30 | 31 | 32 | if __name__ == '__main__': 33 | parser = get_parser() 34 | args = parser.parse_args() 35 | main(args) 36 | -------------------------------------------------------------------------------- /encdec/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 
8 | 9 | from setuptools import setup, find_packages, Extension 10 | import sys 11 | 12 | 13 | if sys.version_info < (3,): 14 | sys.exit('Sorry, Python3 is required for fairseq.') 15 | 16 | with open('README.md') as f: 17 | readme = f.read() 18 | 19 | with open('LICENSE') as f: 20 | license = f.read() 21 | 22 | with open('requirements.txt') as f: 23 | reqs = f.read() 24 | 25 | 26 | bleu = Extension( 27 | 'fairseq.libbleu', 28 | sources=[ 29 | 'fairseq/clib/libbleu/libbleu.cpp', 30 | 'fairseq/clib/libbleu/module.cpp', 31 | ], 32 | extra_compile_args=['-std=c++11'], 33 | ) 34 | 35 | 36 | setup( 37 | name='fairseq', 38 | version='0.6.0', 39 | description='Facebook AI Research Sequence-to-Sequence Toolkit', 40 | long_description=readme, 41 | license=license, 42 | install_requires=reqs.strip().split('\n'), 43 | packages=find_packages(), 44 | ext_modules=[bleu], 45 | test_suite='tests', 46 | ) 47 | -------------------------------------------------------------------------------- /encdec/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takase/control-length/c9eb632e112f07a156a5892065969cda788253a0/encdec/tests/__init__.py -------------------------------------------------------------------------------- /encdec/tests/test_average_checkpoints.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import collections 9 | import os 10 | import tempfile 11 | import unittest 12 | 13 | import numpy as np 14 | import torch 15 | 16 | from scripts.average_checkpoints import average_checkpoints 17 | 18 | 19 | class TestAverageCheckpoints(unittest.TestCase): 20 | def test_average_checkpoints(self): 21 | params_0 = collections.OrderedDict( 22 | [ 23 | ('a', torch.DoubleTensor([100.0])), 24 | ('b', torch.FloatTensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])), 25 | ('c', torch.IntTensor([7, 8, 9])), 26 | ] 27 | ) 28 | params_1 = collections.OrderedDict( 29 | [ 30 | ('a', torch.DoubleTensor([1.0])), 31 | ('b', torch.FloatTensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])), 32 | ('c', torch.IntTensor([2, 2, 2])), 33 | ] 34 | ) 35 | params_avg = collections.OrderedDict( 36 | [ 37 | ('a', torch.DoubleTensor([50.5])), 38 | ('b', torch.FloatTensor([[1.0, 1.5, 2.0], [2.5, 3.0, 3.5]])), 39 | # We expect truncation for integer division 40 | ('c', torch.IntTensor([4, 5, 5])), 41 | ] 42 | ) 43 | 44 | fd_0, path_0 = tempfile.mkstemp() 45 | fd_1, path_1 = tempfile.mkstemp() 46 | torch.save(collections.OrderedDict([('model', params_0)]), path_0) 47 | torch.save(collections.OrderedDict([('model', params_1)]), path_1) 48 | 49 | output = average_checkpoints([path_0, path_1])['model'] 50 | 51 | os.close(fd_0) 52 | os.remove(path_0) 53 | os.close(fd_1) 54 | os.remove(path_1) 55 | 56 | for (k_expected, v_expected), (k_out, v_out) in zip( 57 | params_avg.items(), output.items()): 58 | self.assertEqual( 59 | k_expected, k_out, 'Key mismatch - expected {} but found {}. 
' 60 | '(Expected list of keys: {} vs actual list of keys: {})'.format( 61 | k_expected, k_out, params_avg.keys(), output.keys() 62 | ) 63 | ) 64 | np.testing.assert_allclose( 65 | v_expected.numpy(), 66 | v_out.numpy(), 67 | err_msg='Tensor value mismatch for key {}'.format(k_expected) 68 | ) 69 | 70 | 71 | if __name__ == '__main__': 72 | unittest.main() 73 | -------------------------------------------------------------------------------- /encdec/tests/test_backtranslation_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import unittest 9 | 10 | import tests.utils as test_utils 11 | import torch 12 | from fairseq.data.backtranslation_dataset import BacktranslationDataset 13 | from fairseq import sequence_generator 14 | 15 | 16 | class TestBacktranslationDataset(unittest.TestCase): 17 | def setUp(self): 18 | self.tgt_dict, self.w1, self.w2, self.src_tokens, self.src_lengths, self.model = ( 19 | test_utils.sequence_generator_setup() 20 | ) 21 | 22 | dummy_src_samples = self.src_tokens 23 | 24 | self.tgt_dataset = test_utils.TestDataset(data=dummy_src_samples) 25 | 26 | def _backtranslation_dataset_helper(self, remove_eos_at_src): 27 | """ 28 | SequenceGenerator kwargs are same as defaults from fairseq/options.py 29 | """ 30 | backtranslation_dataset = BacktranslationDataset( 31 | tgt_dataset=self.tgt_dataset, 32 | tgt_dict=self.tgt_dict, 33 | backtranslation_model=self.model, 34 | max_len_a=0, 35 | max_len_b=200, 36 | beam_size=2, 37 | unk_penalty=0, 38 | sampling=False, 39 | remove_eos_at_src=remove_eos_at_src, 40 | generator_class=sequence_generator.SequenceGenerator, 41 | ) 42 | dataloader = torch.utils.data.DataLoader( 43 | backtranslation_dataset, 44 | batch_size=2, 45 | collate_fn=backtranslation_dataset.collater, 46 | ) 47 | backtranslation_batch_result = next(iter(dataloader)) 48 | 49 | eos, pad, w1, w2 = self.tgt_dict.eos(), self.tgt_dict.pad(), self.w1, self.w2 50 | 51 | # Note that we sort by src_lengths and add left padding, so actually 52 | # ids will look like: [1, 0] 53 | expected_src = torch.LongTensor([[w1, w2, w1, eos], [pad, pad, w1, eos]]) 54 | if remove_eos_at_src: 55 | expected_src = expected_src[:, :-1] 56 | expected_tgt = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]]) 57 | generated_src = backtranslation_batch_result["net_input"]["src_tokens"] 58 | tgt_tokens = backtranslation_batch_result["target"] 59 | 60 | self.assertTensorEqual(expected_src, generated_src) 61 | self.assertTensorEqual(expected_tgt, tgt_tokens) 62 | 63 | def test_backtranslation_dataset_no_eos_at_src(self): 64 | self._backtranslation_dataset_helper(remove_eos_at_src=True) 65 | 66 | def test_backtranslation_dataset_with_eos_at_src(self): 67 | self._backtranslation_dataset_helper(remove_eos_at_src=False) 68 | 69 | def assertTensorEqual(self, t1, t2): 70 | self.assertEqual(t1.size(), t2.size(), "size mismatch") 71 | self.assertEqual(t1.ne(t2).long().sum(), 0) 72 | 73 | 74 | if __name__ == "__main__": 75 | unittest.main() 76 | -------------------------------------------------------------------------------- /encdec/tests/test_character_token_embedder.py: -------------------------------------------------------------------------------- 1 | # 
Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | import unittest 10 | 11 | from fairseq.data import Dictionary 12 | from fairseq.modules import CharacterTokenEmbedder 13 | 14 | 15 | class TestCharacterTokenEmbedder(unittest.TestCase): 16 | def test_character_token_embedder(self): 17 | vocab = Dictionary() 18 | vocab.add_symbol('hello') 19 | vocab.add_symbol('there') 20 | 21 | embedder = CharacterTokenEmbedder(vocab, [(2, 16), (4, 32), (8, 64), (16, 2)], 64, 5, 2) 22 | 23 | test_sents = [['hello', 'unk', 'there'], ['there'], ['hello', 'there']] 24 | max_len = max(len(s) for s in test_sents) 25 | input = torch.LongTensor(len(test_sents), max_len + 2).fill_(vocab.pad()) 26 | for i in range(len(test_sents)): 27 | input[i][0] = vocab.eos() 28 | for j in range(len(test_sents[i])): 29 | input[i][j + 1] = vocab.index(test_sents[i][j]) 30 | input[i][j + 2] = vocab.eos() 31 | embs = embedder(input) 32 | 33 | assert embs.size() == (len(test_sents), max_len + 2, 5) 34 | self.assertAlmostEqual(embs[0][0], embs[1][0]) 35 | self.assertAlmostEqual(embs[0][0], embs[0][-1]) 36 | self.assertAlmostEqual(embs[0][1], embs[2][1]) 37 | self.assertAlmostEqual(embs[0][3], embs[1][1]) 38 | 39 | embs.sum().backward() 40 | assert embedder.char_embeddings.weight.grad is not None 41 | 42 | def assertAlmostEqual(self, t1, t2): 43 | self.assertEqual(t1.size(), t2.size(), "size mismatch") 44 | self.assertLess((t1 - t2).abs().max(), 1e-6) 45 | 46 | 47 | if __name__ == '__main__': 48 | unittest.main() 49 | -------------------------------------------------------------------------------- /encdec/tests/test_convtbc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
7 | 8 | import torch 9 | import unittest 10 | from fairseq.modules import ConvTBC 11 | import torch.nn as nn 12 | 13 | 14 | class TestConvTBC(unittest.TestCase): 15 | 16 | def test_convtbc(self): 17 | # ksz, in_channels, out_channels 18 | conv_tbc = ConvTBC(4, 5, kernel_size=3, padding=1) 19 | # out_channels, in_channels, ksz 20 | conv1d = nn.Conv1d(4, 5, kernel_size=3, padding=1) 21 | 22 | conv_tbc.weight.data.copy_(conv1d.weight.data.transpose(0, 2)) 23 | conv_tbc.bias.data.copy_(conv1d.bias.data) 24 | 25 | input_tbc = torch.randn(7, 2, 4, requires_grad=True) 26 | input1d = input_tbc.data.transpose(0, 1).transpose(1, 2) 27 | input1d.requires_grad = True 28 | 29 | output_tbc = conv_tbc(input_tbc) 30 | output1d = conv1d(input1d) 31 | 32 | self.assertAlmostEqual(output_tbc.data.transpose(0, 1).transpose(1, 2), output1d.data) 33 | 34 | grad_tbc = torch.randn(output_tbc.size()) 35 | grad1d = grad_tbc.transpose(0, 1).transpose(1, 2).contiguous() 36 | 37 | output_tbc.backward(grad_tbc) 38 | output1d.backward(grad1d) 39 | 40 | self.assertAlmostEqual(conv_tbc.weight.grad.data.transpose(0, 2), conv1d.weight.grad.data) 41 | self.assertAlmostEqual(conv_tbc.bias.grad.data, conv1d.bias.grad.data) 42 | self.assertAlmostEqual(input_tbc.grad.data.transpose(0, 1).transpose(1, 2), input1d.grad.data) 43 | 44 | def assertAlmostEqual(self, t1, t2): 45 | self.assertEqual(t1.size(), t2.size(), "size mismatch") 46 | self.assertLess((t1 - t2).abs().max(), 1e-4) 47 | 48 | 49 | if __name__ == '__main__': 50 | unittest.main() 51 | -------------------------------------------------------------------------------- /encdec/tests/test_dictionary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
7 | 8 | import tempfile 9 | import unittest 10 | 11 | import torch 12 | 13 | from fairseq.data import Dictionary 14 | from fairseq.tokenizer import Tokenizer 15 | 16 | 17 | class TestDictionary(unittest.TestCase): 18 | 19 | def test_finalize(self): 20 | txt = [ 21 | 'A B C D', 22 | 'B C D', 23 | 'C D', 24 | 'D', 25 | ] 26 | ref_ids1 = list(map(torch.IntTensor, [ 27 | [4, 5, 6, 7, 2], 28 | [5, 6, 7, 2], 29 | [6, 7, 2], 30 | [7, 2], 31 | ])) 32 | ref_ids2 = list(map(torch.IntTensor, [ 33 | [7, 6, 5, 4, 2], 34 | [6, 5, 4, 2], 35 | [5, 4, 2], 36 | [4, 2], 37 | ])) 38 | 39 | # build dictionary 40 | d = Dictionary() 41 | for line in txt: 42 | Tokenizer.tokenize(line, d, add_if_not_exist=True) 43 | 44 | def get_ids(dictionary): 45 | ids = [] 46 | for line in txt: 47 | ids.append(Tokenizer.tokenize(line, dictionary, add_if_not_exist=False)) 48 | return ids 49 | 50 | def assertMatch(ids, ref_ids): 51 | for toks, ref_toks in zip(ids, ref_ids): 52 | self.assertEqual(toks.size(), ref_toks.size()) 53 | self.assertEqual(0, (toks != ref_toks).sum().item()) 54 | 55 | ids = get_ids(d) 56 | assertMatch(ids, ref_ids1) 57 | 58 | # check finalized dictionary 59 | d.finalize() 60 | finalized_ids = get_ids(d) 61 | assertMatch(finalized_ids, ref_ids2) 62 | 63 | # write to disk and reload 64 | with tempfile.NamedTemporaryFile(mode='w') as tmp_dict: 65 | d.save(tmp_dict.name) 66 | d = Dictionary.load(tmp_dict.name) 67 | reload_ids = get_ids(d) 68 | assertMatch(reload_ids, ref_ids2) 69 | assertMatch(finalized_ids, reload_ids) 70 | 71 | 72 | if __name__ == '__main__': 73 | unittest.main() 74 | -------------------------------------------------------------------------------- /encdec/tests/test_iterators.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import unittest 9 | 10 | from fairseq.data import iterators 11 | 12 | 13 | class TestIterators(unittest.TestCase): 14 | 15 | def test_counting_iterator(self): 16 | x = list(range(10)) 17 | itr = iterators.CountingIterator(x) 18 | self.assertTrue(itr.has_next()) 19 | self.assertEqual(next(itr), 0) 20 | self.assertEqual(next(itr), 1) 21 | itr.skip(3) 22 | self.assertEqual(next(itr), 5) 23 | itr.skip(3) 24 | self.assertEqual(next(itr), 9) 25 | self.assertFalse(itr.has_next()) 26 | 27 | 28 | if __name__ == '__main__': 29 | unittest.main() 30 | -------------------------------------------------------------------------------- /encdec/tests/test_label_smoothing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
7 | 8 | import argparse 9 | import copy 10 | import unittest 11 | 12 | import torch 13 | 14 | from fairseq.criterions.cross_entropy import CrossEntropyCriterion 15 | from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion 16 | 17 | import tests.utils as test_utils 18 | 19 | 20 | class TestLabelSmoothing(unittest.TestCase): 21 | 22 | def setUp(self): 23 | # build dictionary 24 | self.d = test_utils.dummy_dictionary(3) 25 | vocab = len(self.d) 26 | self.assertEqual(vocab, 4 + 3) # 4 special + 3 tokens 27 | self.assertEqual(self.d.pad(), 1) 28 | self.assertEqual(self.d.eos(), 2) 29 | self.assertEqual(self.d.unk(), 3) 30 | pad, eos, unk, w1, w2, w3 = 1, 2, 3, 4, 5, 6 # noqa: F841 31 | 32 | # build dataset 33 | self.data = [ 34 | # the first batch item has padding 35 | {'source': torch.LongTensor([w1, eos]), 'target': torch.LongTensor([w1, eos])}, 36 | {'source': torch.LongTensor([w1, eos]), 'target': torch.LongTensor([w1, w1, eos])}, 37 | ] 38 | self.sample = next(test_utils.dummy_dataloader(self.data)) 39 | 40 | # build model 41 | self.args = argparse.Namespace() 42 | self.args.sentence_avg = False 43 | self.args.probs = torch.FloatTensor([ 44 | # pad eos unk w1 w2 w3 45 | [0.05, 0.05, 0.1, 0.05, 0.3, 0.4, 0.05], 46 | [0.05, 0.10, 0.2, 0.05, 0.2, 0.3, 0.10], 47 | [0.05, 0.15, 0.3, 0.05, 0.1, 0.2, 0.15], 48 | ]).unsqueeze(0).expand(2, 3, 7) # add batch dimension 49 | self.task = test_utils.TestTranslationTask.setup_task(self.args, self.d, self.d) 50 | self.model = self.task.build_model(self.args) 51 | 52 | def test_nll_loss(self): 53 | self.args.label_smoothing = 0.1 54 | nll_crit = CrossEntropyCriterion(self.args, self.task) 55 | smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task) 56 | nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample) 57 | smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample) 58 | self.assertLess(abs(nll_loss - nll_logging_output['loss']), 1e-6) 59 | self.assertLess(abs(nll_loss - smooth_logging_output['nll_loss']), 1e-6) 60 | 61 | def test_padding(self): 62 | self.args.label_smoothing = 0.1 63 | crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task) 64 | loss, _, logging_output = crit(self.model, self.sample) 65 | 66 | def get_one_no_padding(idx): 67 | # create a new sample with just a single batch item so that there's 68 | # no padding 69 | sample1 = next(test_utils.dummy_dataloader([self.data[idx]])) 70 | args1 = copy.copy(self.args) 71 | args1.probs = args1.probs[idx, :, :].unsqueeze(0) 72 | model1 = self.task.build_model(args1) 73 | loss1, _, _ = crit(model1, sample1) 74 | return loss1 75 | 76 | loss1 = get_one_no_padding(0) 77 | loss2 = get_one_no_padding(1) 78 | self.assertAlmostEqual(loss, loss1 + loss2) 79 | 80 | def test_reduction(self): 81 | self.args.label_smoothing = 0.1 82 | crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task) 83 | loss, _, logging_output = crit(self.model, self.sample, reduce=True) 84 | unreduced_loss, _, _ = crit(self.model, self.sample, reduce=False) 85 | self.assertAlmostEqual(loss, unreduced_loss.sum()) 86 | 87 | def test_zero_eps(self): 88 | self.args.label_smoothing = 0.0 89 | nll_crit = CrossEntropyCriterion(self.args, self.task) 90 | smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task) 91 | nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample) 92 | smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample) 93 
| self.assertAlmostEqual(nll_loss, smooth_loss) 94 | 95 | def assertAlmostEqual(self, t1, t2): 96 | self.assertEqual(t1.size(), t2.size(), "size mismatch") 97 | self.assertLess((t1 - t2).abs().max(), 1e-6) 98 | 99 | 100 | if __name__ == '__main__': 101 | unittest.main() 102 | -------------------------------------------------------------------------------- /encdec/tests/test_reproducibility.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import contextlib 9 | from io import StringIO 10 | import json 11 | import os 12 | import tempfile 13 | import unittest 14 | 15 | import torch 16 | 17 | from fairseq import options 18 | 19 | from . import test_binaries 20 | 21 | 22 | class TestReproducibility(unittest.TestCase): 23 | 24 | def _test_reproducibility(self, name, extra_flags=None): 25 | if extra_flags is None: 26 | extra_flags = [] 27 | 28 | with tempfile.TemporaryDirectory(name) as data_dir: 29 | with contextlib.redirect_stdout(StringIO()): 30 | test_binaries.create_dummy_data(data_dir) 31 | test_binaries.preprocess_translation_data(data_dir) 32 | 33 | # train epochs 1 and 2 together 34 | stdout = StringIO() 35 | with contextlib.redirect_stdout(stdout): 36 | test_binaries.train_translation_model( 37 | data_dir, 'fconv_iwslt_de_en', [ 38 | '--dropout', '0.0', 39 | '--log-format', 'json', 40 | '--log-interval', '1', 41 | '--max-epoch', '3', 42 | ] + extra_flags, 43 | ) 44 | stdout = stdout.getvalue() 45 | train_log, valid_log = map(json.loads, stdout.split('\n')[-4:-2]) 46 | 47 | # train epoch 2, resuming from previous checkpoint 1 48 | os.rename( 49 | os.path.join(data_dir, 'checkpoint1.pt'), 50 | os.path.join(data_dir, 'checkpoint_last.pt'), 51 | ) 52 | stdout = StringIO() 53 | with contextlib.redirect_stdout(stdout): 54 | test_binaries.train_translation_model( 55 | data_dir, 'fconv_iwslt_de_en', [ 56 | '--dropout', '0.0', 57 | '--log-format', 'json', 58 | '--log-interval', '1', 59 | '--max-epoch', '3', 60 | ] + extra_flags, 61 | ) 62 | stdout = stdout.getvalue() 63 | train_res_log, valid_res_log = map(json.loads, stdout.split('\n')[-4:-2]) 64 | 65 | def cast(s): 66 | return round(float(s), 3) 67 | 68 | for k in ['loss', 'ppl', 'num_updates', 'gnorm']: 69 | self.assertEqual(cast(train_log[k]), cast(train_res_log[k])) 70 | for k in ['valid_loss', 'valid_ppl', 'num_updates', 'best']: 71 | self.assertEqual(cast(valid_log[k]), cast(valid_res_log[k])) 72 | 73 | def test_reproducibility(self): 74 | self._test_reproducibility('test_reproducibility') 75 | 76 | def test_reproducibility_fp16(self): 77 | self._test_reproducibility('test_reproducibility_fp16', [ 78 | '--fp16', 79 | '--fp16-init-scale', '4096', 80 | ]) 81 | 82 | 83 | if __name__ == '__main__': 84 | unittest.main() 85 | -------------------------------------------------------------------------------- /encdec/tests/test_sequence_scorer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. 
An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import argparse 9 | import unittest 10 | 11 | import torch 12 | 13 | from fairseq.sequence_scorer import SequenceScorer 14 | 15 | import tests.utils as test_utils 16 | 17 | 18 | class TestSequenceScorer(unittest.TestCase): 19 | 20 | def test_sequence_scorer(self): 21 | # construct dummy dictionary 22 | d = test_utils.dummy_dictionary(vocab_size=2) 23 | self.assertEqual(d.pad(), 1) 24 | self.assertEqual(d.eos(), 2) 25 | self.assertEqual(d.unk(), 3) 26 | eos = d.eos() 27 | w1 = 4 28 | w2 = 5 29 | 30 | # construct dataloader 31 | data = [ 32 | { 33 | 'source': torch.LongTensor([w1, w2, eos]), 34 | 'target': torch.LongTensor([w1, w2, w1, eos]), 35 | }, 36 | { 37 | 'source': torch.LongTensor([w2, eos]), 38 | 'target': torch.LongTensor([w2, w1, eos]), 39 | }, 40 | { 41 | 'source': torch.LongTensor([w2, eos]), 42 | 'target': torch.LongTensor([w2, eos]), 43 | }, 44 | ] 45 | data_itr = test_utils.dummy_dataloader(data) 46 | 47 | # specify expected output probabilities 48 | args = argparse.Namespace() 49 | unk = 0. 50 | args.beam_probs = [ 51 | # step 0: 52 | torch.FloatTensor([ 53 | # eos w1 w2 54 | [0.0, unk, 0.6, 0.4], # sentence 1 55 | [0.0, unk, 0.4, 0.6], # sentence 2 56 | [0.0, unk, 0.7, 0.3], # sentence 3 57 | ]), 58 | # step 1: 59 | torch.FloatTensor([ 60 | # eos w1 w2 61 | [0.0, unk, 0.2, 0.7], # sentence 1 62 | [0.0, unk, 0.8, 0.2], # sentence 2 63 | [0.7, unk, 0.1, 0.2], # sentence 3 64 | ]), 65 | # step 2: 66 | torch.FloatTensor([ 67 | # eos w1 w2 68 | [0.10, unk, 0.50, 0.4], # sentence 1 69 | [0.15, unk, 0.15, 0.7], # sentence 2 70 | [0.00, unk, 0.00, 0.0], # sentence 3 71 | ]), 72 | # step 3: 73 | torch.FloatTensor([ 74 | # eos w1 w2 75 | [0.9, unk, 0.05, 0.05], # sentence 1 76 | [0.0, unk, 0.00, 0.0], # sentence 2 77 | [0.0, unk, 0.00, 0.0], # sentence 3 78 | ]), 79 | ] 80 | expected_scores = [ 81 | [0.6, 0.7, 0.5, 0.9], # sentence 1 82 | [0.6, 0.8, 0.15], # sentence 2 83 | [0.3, 0.7], # sentence 3 84 | ] 85 | 86 | task = test_utils.TestTranslationTask.setup_task(args, d, d) 87 | model = task.build_model(args) 88 | scorer = SequenceScorer([model], task.target_dictionary) 89 | for id, _src, _ref, hypos in scorer.score_batched_itr(data_itr): 90 | self.assertHypoTokens(hypos[0], data[id]['target']) 91 | self.assertHypoScore(hypos[0], expected_scores[id]) 92 | 93 | def assertHypoTokens(self, hypo, tokens): 94 | self.assertTensorEqual(hypo['tokens'], torch.LongTensor(tokens)) 95 | 96 | def assertHypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.): 97 | pos_scores = torch.FloatTensor(pos_probs).log() 98 | self.assertAlmostEqual(hypo['positional_scores'], pos_scores) 99 | self.assertEqual(pos_scores.numel(), hypo['tokens'].numel()) 100 | score = pos_scores.sum() 101 | if normalized: 102 | score /= pos_scores.numel()**lenpen 103 | self.assertLess(abs(score - hypo['score']), 1e-6) 104 | 105 | def assertAlmostEqual(self, t1, t2): 106 | self.assertEqual(t1.size(), t2.size(), "size mismatch") 107 | self.assertLess((t1 - t2).abs().max(), 1e-4) 108 | 109 | def assertTensorEqual(self, t1, t2): 110 | self.assertEqual(t1.size(), t2.size(), "size mismatch") 111 | self.assertEqual(t1.ne(t2).long().sum(), 0) 112 | 113 | 114 | if __name__ == '__main__': 115 | unittest.main() 116 | -------------------------------------------------------------------------------- /encdec/tests/test_train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 
2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import contextlib 9 | from io import StringIO 10 | import unittest 11 | from unittest.mock import MagicMock, patch 12 | 13 | import torch 14 | 15 | from fairseq import data 16 | 17 | import train 18 | 19 | 20 | def mock_trainer(epoch, num_updates, iterations_in_epoch): 21 | trainer = MagicMock() 22 | trainer.load_checkpoint.return_value = { 23 | 'train_iterator': { 24 | 'epoch': epoch, 25 | 'iterations_in_epoch': iterations_in_epoch, 26 | 'shuffle': False, 27 | }, 28 | } 29 | trainer.get_num_updates.return_value = num_updates 30 | return trainer 31 | 32 | 33 | def mock_dict(): 34 | d = MagicMock() 35 | d.pad.return_value = 1 36 | d.eos.return_value = 2 37 | d.unk.return_value = 3 38 | return d 39 | 40 | 41 | def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch): 42 | tokens = torch.LongTensor(list(range(epoch_size))) 43 | tokens_ds = data.TokenBlockDataset(tokens, sizes=[len(tokens)], block_size=1, pad=0, eos=1, include_targets=False) 44 | trainer = mock_trainer(epoch, num_updates, iterations_in_epoch) 45 | dataset = data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False) 46 | epoch_itr = data.EpochBatchIterator( 47 | dataset=dataset, 48 | collate_fn=dataset.collater, 49 | batch_sampler=[[i] for i in range(epoch_size)], 50 | ) 51 | return trainer, epoch_itr 52 | 53 | 54 | class TestLoadCheckpoint(unittest.TestCase): 55 | 56 | def setUp(self): 57 | self.args_mock = MagicMock() 58 | self.args_mock.optimizer_overrides = '{}' 59 | self.patches = { 60 | 'os.makedirs': MagicMock(), 61 | 'os.path.join': MagicMock(), 62 | 'os.path.isfile': MagicMock(return_value=True), 63 | } 64 | self.applied_patches = [patch(p, d) for p, d in self.patches.items()] 65 | [p.start() for p in self.applied_patches] 66 | 67 | 68 | def test_load_partial_checkpoint(self): 69 | with contextlib.redirect_stdout(StringIO()): 70 | trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50) 71 | 72 | train.load_checkpoint(self.args_mock, trainer, epoch_itr) 73 | self.assertEqual(epoch_itr.epoch, 2) 74 | self.assertEqual(epoch_itr.iterations_in_epoch, 50) 75 | 76 | itr = epoch_itr.next_epoch_itr(shuffle=False) 77 | self.assertEqual(epoch_itr.epoch, 2) 78 | self.assertEqual(epoch_itr.iterations_in_epoch, 50) 79 | 80 | self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 50) 81 | self.assertEqual(epoch_itr.iterations_in_epoch, 51) 82 | 83 | def test_load_full_checkpoint(self): 84 | with contextlib.redirect_stdout(StringIO()): 85 | trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150) 86 | 87 | train.load_checkpoint(self.args_mock, trainer, epoch_itr) 88 | itr = epoch_itr.next_epoch_itr(shuffle=False) 89 | 90 | self.assertEqual(epoch_itr.epoch, 3) 91 | self.assertEqual(epoch_itr.iterations_in_epoch, 0) 92 | self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0) 93 | 94 | def test_load_no_checkpoint(self): 95 | with contextlib.redirect_stdout(StringIO()): 96 | trainer, epoch_itr = get_trainer_and_epoch_itr(0, 150, 0, 0) 97 | self.patches['os.path.isfile'].return_value = False 98 | 99 | train.load_checkpoint(self.args_mock, trainer, epoch_itr) 100 | itr = epoch_itr.next_epoch_itr(shuffle=False) 101 | 102 | self.assertEqual(epoch_itr.epoch, 1) 103 | 
self.assertEqual(epoch_itr.iterations_in_epoch, 0) 104 | self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0) 105 | 106 | def tearDown(self): 107 | patch.stopall() 108 | 109 | 110 | if __name__ == '__main__': 111 | unittest.main() 112 | -------------------------------------------------------------------------------- /encdec/tests/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import unittest 9 | 10 | import torch 11 | 12 | from fairseq import utils 13 | 14 | 15 | class TestUtils(unittest.TestCase): 16 | 17 | def test_convert_padding_direction(self): 18 | pad = 1 19 | left_pad = torch.LongTensor([ 20 | [2, 3, 4, 5, 6], 21 | [1, 7, 8, 9, 10], 22 | [1, 1, 1, 11, 12], 23 | ]) 24 | right_pad = torch.LongTensor([ 25 | [2, 3, 4, 5, 6], 26 | [7, 8, 9, 10, 1], 27 | [11, 12, 1, 1, 1], 28 | ]) 29 | 30 | self.assertAlmostEqual( 31 | right_pad, 32 | utils.convert_padding_direction( 33 | left_pad, 34 | pad, 35 | left_to_right=True, 36 | ), 37 | ) 38 | self.assertAlmostEqual( 39 | left_pad, 40 | utils.convert_padding_direction( 41 | right_pad, 42 | pad, 43 | right_to_left=True, 44 | ), 45 | ) 46 | 47 | def test_make_positions(self): 48 | pad = 1 49 | left_pad_input = torch.LongTensor([ 50 | [9, 9, 9, 9, 9], 51 | [1, 9, 9, 9, 9], 52 | [1, 1, 1, 9, 9], 53 | ]) 54 | left_pad_output = torch.LongTensor([ 55 | [2, 3, 4, 5, 6], 56 | [1, 2, 3, 4, 5], 57 | [1, 1, 1, 2, 3], 58 | ]) 59 | right_pad_input = torch.LongTensor([ 60 | [9, 9, 9, 9, 9], 61 | [9, 9, 9, 9, 1], 62 | [9, 9, 1, 1, 1], 63 | ]) 64 | right_pad_output = torch.LongTensor([ 65 | [2, 3, 4, 5, 6], 66 | [2, 3, 4, 5, 1], 67 | [2, 3, 1, 1, 1], 68 | ]) 69 | 70 | self.assertAlmostEqual( 71 | left_pad_output, 72 | utils.make_positions(left_pad_input, pad, left_pad=True), 73 | ) 74 | self.assertAlmostEqual( 75 | right_pad_output, 76 | utils.make_positions(right_pad_input, pad, left_pad=False), 77 | ) 78 | 79 | def assertAlmostEqual(self, t1, t2): 80 | self.assertEqual(t1.size(), t2.size(), "size mismatch") 81 | self.assertLess(utils.item((t1 - t2).abs().max()), 1e-4) 82 | 83 | 84 | if __name__ == '__main__': 85 | unittest.main() 86 | -------------------------------------------------------------------------------- /eval/LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For Neural Attention Model for Abstractive Summarization software 4 | 5 | Copyright (c) 2015-present, Facebook, Inc. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 
20 | 
21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 | 
--------------------------------------------------------------------------------
/eval/README.md:
--------------------------------------------------------------------------------
1 | ## Calculate ROUGE
2 | 
3 | - Obtain ROUGE-1.5.5.pl
4 | 
5 | ```sh eval.sh DIR_REF&OUT SOURCE_PART LENGTH THIS_DIR ROUGE_DIR```
6 | 
7 | - DIR_REF&OUT: directory that contains two directories, references and systems
8 |     - references contains the correct outputs. The file name is task1_ref0.txt (if there are multiple references: task1_ref0.txt, task1_ref1.txt, task1_ref2.txt, ...)
9 |     - systems contains the system outputs. The file name is task1_SPECIFIC_NAME_OF_YOUR_FILE.txt
10 | - SOURCE_PART: file that contains the source parts of the references
11 | - LENGTH: desired length. In our paper, we used 30, 50, and 70 in English.
12 | - THIS_DIR: path to this directory
13 | - ROUGE_DIR: directory that contains the ROUGE script, i.e., ROUGE-1.5.5.pl
14 | 
15 | ## Calculate Variance
16 | 
17 | ```python calculate_variance_from_fixlength.py -s SYS_OUTPUT -l LENGTH```
18 | 
19 | - SYS_OUTPUT: file that contains the system outputs
20 | - LENGTH: desired length. In our paper, we used 30, 50, and 70 in English.
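
For reference, the variance reported by calculate_variance_from_fixlength.py (listed below) is the mean squared difference, in characters, between each system output and the desired length; the script also reports the mean absolute difference. A minimal sketch of the same computation, assuming a hypothetical output file name and a target length of 75 (the script's default):

```python
# Sketch of the length-variance metric described above (lengths counted in characters).
# 'system_output.txt' and the target length 75 are placeholder example values.
outputs = [line.strip() for line in open('system_output.txt')]
target_len = 75

n = len(outputs)
variance = sum((len(s) - target_len) ** 2 for s in outputs) / n      # "Variance" in the script
avg_abs_diff = sum(abs(len(s) - target_len) for s in outputs) / n    # "Ave diff" in the script

print('Variance: %.5f' % variance)
print('Ave diff: %.5f' % avg_abs_diff)
```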
-------------------------------------------------------------------------------- /eval/calculate_variance_from_fixlength.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | #calculate variance of length (the number of characters) between system generation and given correct length 4 | #this variance calculation is based on one defined in http://aclweb.org/anthology/D18-1444 5 | 6 | 7 | import sys 8 | import collections 9 | import numpy as np 10 | import argparse 11 | 12 | 13 | def read_file(filename): 14 | return [line.strip() for line in open(filename)] 15 | 16 | 17 | def main(args): 18 | system_out = read_file(args.system_output) 19 | if args.reference: 20 | reference = read_file(args.reference) 21 | reference_len = [len(s) for s in reference] 22 | else: 23 | reference_len = [args.length for _ in range(len(system_out))] 24 | total = 0.0 25 | abs_diff = 0 26 | for reflen, sent in zip(reference_len, system_out): 27 | total += (len(sent) - reflen) ** 2 28 | abs_diff += abs(len(sent) - reflen) 29 | sntnum = len(reference_len) 30 | var = total / sntnum 31 | print('Variance: %.5f'%(var)) 32 | print('Ave diff: %.5f'%(float(abs_diff) / sntnum)) 33 | 34 | 35 | if __name__ == "__main__": 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('-s', '--system', dest='system_output', 38 | required=True, help='specify the system output file name') 39 | parser.add_argument('-r', '--reference', dest='reference', default='', 40 | help='specify the reference file name') 41 | parser.add_argument('-l', '--length', dest='length', type=int, default=75, 42 | help='the number of characters of correct sequence') 43 | args = parser.parse_args() 44 | main(args) 45 | -------------------------------------------------------------------------------- /eval/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #$1: directory contains references and systems 4 | #$2: the source part of test data 5 | #$3: the number of characters for ROUGE calculation (output is trimmed by this number) 6 | #$4: directory contains make_rouge.py and prepare4rouge-simple.pl 7 | #$5: path to ROUGE script (ROUGE-1.5.5.pl) 8 | 9 | export BASEDir=$4 10 | export ROUGE=$5 11 | cd $1 12 | rm -fr $1/tmp_GOLD 13 | rm -fr $1/tmp_SYSTEM 14 | rm -fr $1/tmp_OUTPUT 15 | mkdir -p $1/tmp_GOLD 16 | mkdir -p $1/tmp_SYSTEM 17 | 18 | python2.7 $BASEDir/DUC/make_rouge.py --base $1 --gold tmp_GOLD --system tmp_SYSTEM --input $2 19 | perl $BASEDir/DUC/prepare4rouge-simple.pl tmp_SYSTEM tmp_GOLD tmp_OUTPUT 20 | 21 | cd tmp_OUTPUT 22 | 23 | echo "LIMITED LENGTH" 24 | perl $ROUGE/ROUGE-1.5.5.pl -m -b $3 -n 2 -w 1.2 -e $ROUGE -a settings.xml 25 | -------------------------------------------------------------------------------- /eval/make_rouge.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2015, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. An additional grant 7 | # of patent rights can be found in the PATENTS file in the same directory. 8 | # 9 | # Author: Alexander M Rush <srush@seas.harvard.edu> 10 | # Sumit Chopra <spchopra@fb.com> 11 | # Jason Weston <jase@fb.com> 12 | 13 | """Prep ROUGE eval. 
""" 14 | 15 | import sys 16 | import glob 17 | import os 18 | import argparse 19 | import itertools 20 | #@lint-avoid-python-3-compatibility-imports 21 | 22 | parser = argparse.ArgumentParser(description=__doc__, 23 | formatter_class= 24 | argparse.RawDescriptionHelpFormatter) 25 | parser.add_argument('--base', help="Base directory.", type=str) 26 | parser.add_argument('--gold', help="Base directory.", type=str) 27 | parser.add_argument('--system', help="Base directory.", type=str) 28 | parser.add_argument('--input', help="Input text.", type=str) 29 | 30 | args = parser.parse_args(sys.argv[1:]) 31 | 32 | for f in glob.glob("{0}/references/*".format(args.base)): 33 | task, ref = f.split("/")[-1].split("_") 34 | ref = int(ref.split(".")[0][-1]) 35 | 36 | for i, l in enumerate(open(f)): 37 | os.system("mkdir -p %s/%s%04d"%(args.gold, task, i)) 38 | with open("%s/%s%04d/%s%04d.%04d.gold" % (args.gold, task, i, task, i, ref), "w") as out: 39 | print >>out, l.strip() 40 | 41 | 42 | for f in glob.glob("{0}/system/*".format(args.base)): 43 | task, ref = f.split("/")[-1].split("_", 1) 44 | #if ref.startswith("ducsystem"): continue 45 | system = ref.split(".")[0] 46 | os.system("mkdir -p %s/%s"%(args.system, system)) 47 | for i, (l, input_line) in enumerate(itertools.izip(open(f), open(args.input))): 48 | words = [] 49 | numbers = dict([(len(w), w) for w in input_line.strip().split() if w[0].isdigit()]) 50 | for w in l.strip().split(): 51 | # Replace # with numbers from the input. 52 | if w[0] == "#" and len(w) in numbers: 53 | words.append(numbers[len(w)]) 54 | elif w == "<s>": 55 | continue 56 | else: 57 | words.append(w) 58 | 59 | with open("%s/%s/%s%04d.%s.system" % (args.system, system, task, i, system),"w") as out: 60 | if words: 61 | print >>out, " ".join(words) 62 | else: 63 | print >>out, "fail" 64 | -------------------------------------------------------------------------------- /eval/prepare4rouge-simple.pl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takase/control-length/c9eb632e112f07a156a5892065969cda788253a0/eval/prepare4rouge-simple.pl --------------------------------------------------------------------------------