├── .gitignore
├── LICENSE
├── README.md
├── docs
│   ├── Makefile
│   ├── _static
│   │   └── theme_overrides.css
│   ├── command_line_tools.rst
│   ├── conf.py
│   ├── criterions.rst
│   ├── data.rst
│   ├── docutils.conf
│   ├── getting_started.rst
│   ├── index.rst
│   ├── lr_scheduler.rst
│   ├── make.bat
│   ├── models.rst
│   ├── modules.rst
│   ├── optim.rst
│   ├── overview.rst
│   ├── requirements.txt
│   ├── tasks.rst
│   ├── tutorial_classifying_names.rst
│   └── tutorial_simple_lstm.rst
├── eval_lm.py
├── fairseq
│   ├── __init__.py
│   ├── binarizer.py
│   ├── bleu.py
│   ├── checkpoint_utils.py
│   ├── clib
│   │   ├── libbleu
│   │   │   ├── libbleu.cpp
│   │   │   └── module.cpp
│   │   └── libnat
│   │       └── edit_dist.cpp
│   ├── criterions
│   │   ├── __init__.py
│   │   ├── adaptive_loss.py
│   │   ├── binary_cross_entropy.py
│   │   ├── composite_loss.py
│   │   ├── cross_entropy.py
│   │   ├── fairseq_criterion.py
│   │   ├── label_smoothed_cross_entropy.py
│   │   ├── label_smoothed_cross_entropy_with_alignment.py
│   │   ├── legacy_masked_lm.py
│   │   ├── masked_lm.py
│   │   ├── nat_loss.py
│   │   ├── sentence_prediction.py
│   │   └── sentence_ranking.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── audio
│   │   │   ├── __init__.py
│   │   │   └── raw_audio_dataset.py
│   │   ├── backtranslation_dataset.py
│   │   ├── base_wrapper_dataset.py
│   │   ├── colorize_dataset.py
│   │   ├── concat_dataset.py
│   │   ├── concat_sentences_dataset.py
│   │   ├── data_utils.py
│   │   ├── data_utils_fast.pyx
│   │   ├── dictionary.py
│   │   ├── encoders
│   │   │   ├── __init__.py
│   │   │   ├── fastbpe.py
│   │   │   ├── gpt2_bpe.py
│   │   │   ├── gpt2_bpe_utils.py
│   │   │   ├── hf_bert_bpe.py
│   │   │   ├── moses_tokenizer.py
│   │   │   ├── nltk_tokenizer.py
│   │   │   ├── sentencepiece_bpe.py
│   │   │   ├── space_tokenizer.py
│   │   │   └── subword_nmt_bpe.py
│   │   ├── fairseq_dataset.py
│   │   ├── id_dataset.py
│   │   ├── indexed_dataset.py
│   │   ├── iterators.py
│   │   ├── language_pair_dataset.py
│   │   ├── legacy
│   │   │   ├── __init__.py
│   │   │   ├── block_pair_dataset.py
│   │   │   ├── masked_lm_dataset.py
│   │   │   └── masked_lm_dictionary.py
│   │   ├── list_dataset.py
│   │   ├── lm_context_window_dataset.py
│   │   ├── lru_cache_dataset.py
│   │   ├── mask_tokens_dataset.py
│   │   ├── monolingual_dataset.py
│   │   ├── multi_corpus_sampled_dataset.py
│   │   ├── nested_dictionary_dataset.py
│   │   ├── noising.py
│   │   ├── num_samples_dataset.py
│   │   ├── numel_dataset.py
│   │   ├── offset_tokens_dataset.py
│   │   ├── pad_dataset.py
│   │   ├── plasma_utils.py
│   │   ├── prepend_dataset.py
│   │   ├── prepend_token_dataset.py
│   │   ├── raw_label_dataset.py
│   │   ├── replace_dataset.py
│   │   ├── resampling_dataset.py
│   │   ├── round_robin_zip_datasets.py
│   │   ├── sharded_dataset.py
│   │   ├── sort_dataset.py
│   │   ├── strip_token_dataset.py
│   │   ├── subsample_dataset.py
│   │   ├── token_block_dataset.py
│   │   ├── token_block_utils_fast.pyx
│   │   ├── transform_eos_dataset.py
│   │   ├── transform_eos_lang_pair_dataset.py
│   │   └── truncate_dataset.py
│   ├── distributed_utils.py
│   ├── file_utils.py
│   ├── hub_utils.py
│   ├── iterative_refinement_generator.py
│   ├── legacy_distributed_data_parallel.py
│   ├── meters.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── cmlm_transformer.py
│   │   ├── composite_encoder.py
│   │   ├── distributed_fairseq_model.py
│   │   ├── fairseq_decoder.py
│   │   ├── fairseq_encoder.py
│   │   ├── fairseq_incremental_decoder.py
│   │   ├── fairseq_model.py
│   │   ├── fconv.py
│   │   ├── fconv_lm.py
│   │   ├── fconv_self_att.py
│   │   ├── insertion_transformer.py
│   │   ├── iterative_nonautoregressive_transformer.py
│   │   ├── levenshtein_transformer.py
│   │   ├── lightconv.py
│   │   ├── lightconv_lm.py
│   │   ├── lstm.py
│   │   ├── masked_lm.py
│   │   ├── model_utils.py
│   │   ├── multilingual_transformer.py
│   │   ├── nonautoregressive_ensembles.py
│   │   ├── nonautoregressive_transformer.py
│   │   ├── roberta
│   │   │   ├── __init__.py
│   │   │   ├── alignment_utils.py
│   │   │   ├── hub_interface.py
│   │   │   └── model.py
│   │   ├── transformer.py
│   │   ├── transformer_from_pretrained_xlm.py
│   │   ├── transformer_lm.py
│   │   └── wav2vec.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── adaptive_input.py
│   │   ├── adaptive_softmax.py
│   │   ├── beamable_mm.py
│   │   ├── character_token_embedder.py
│   │   ├── conv_tbc.py
│   │   ├── cuda_utils.cu
│   │   ├── downsampled_multihead_attention.py
│   │   ├── dynamic_convolution.py
│   │   ├── dynamicconv_layer
│   │   │   ├── __init__.py
│   │   │   ├── cuda_function_gen.py
│   │   │   ├── dynamicconv_cuda.cpp
│   │   │   ├── dynamicconv_cuda.cuh
│   │   │   ├── dynamicconv_cuda_kernel.cu
│   │   │   ├── dynamicconv_layer.py
│   │   │   ├── dynamiconv_cpu.cpp
│   │   │   └── setup.py
│   │   ├── gelu.py
│   │   ├── grad_multiply.py
│   │   ├── highway.py
│   │   ├── layer_norm.py
│   │   ├── learned_positional_embedding.py
│   │   ├── lightconv_layer
│   │   │   ├── __init__.py
│   │   │   ├── cuda_function_gen.py
│   │   │   ├── lightconv_cuda.cpp
│   │   │   ├── lightconv_cuda.cuh
│   │   │   ├── lightconv_cuda_kernel.cu
│   │   │   ├── lightconv_layer.py
│   │   │   └── setup.py
│   │   ├── lightweight_convolution.py
│   │   ├── linearized_convolution.py
│   │   ├── logsumexp_moe.py
│   │   ├── mean_pool_gating_network.py
│   │   ├── multihead_attention.py
│   │   ├── positional_embedding.py
│   │   ├── scalar_bias.py
│   │   ├── sinusoidal_positional_embedding.py
│   │   ├── sparse_multihead_attention.py
│   │   ├── sparse_transformer_sentence_encoder.py
│   │   ├── sparse_transformer_sentence_encoder_layer.py
│   │   ├── transformer_layer.py
│   │   ├── transformer_sentence_encoder.py
│   │   ├── transformer_sentence_encoder_layer.py
│   │   ├── unfold.py
│   │   └── vggblock.py
│   ├── optim
│   │   ├── __init__.py
│   │   ├── adadelta.py
│   │   ├── adafactor.py
│   │   ├── adagrad.py
│   │   ├── adam.py
│   │   ├── adamax.py
│   │   ├── bmuf.py
│   │   ├── fairseq_optimizer.py
│   │   ├── fp16_optimizer.py
│   │   ├── lr_scheduler
│   │   │   ├── __init__.py
│   │   │   ├── cosine_lr_scheduler.py
│   │   │   ├── fairseq_lr_scheduler.py
│   │   │   ├── fixed_schedule.py
│   │   │   ├── inverse_square_root_schedule.py
│   │   │   ├── polynomial_decay_schedule.py
│   │   │   ├── reduce_lr_on_plateau.py
│   │   │   ├── tri_stage_lr_scheduler.py
│   │   │   └── triangular_lr_scheduler.py
│   │   ├── nag.py
│   │   └── sgd.py
│   ├── options.py
│   ├── pdb.py
│   ├── progress_bar.py
│   ├── registry.py
│   ├── search.py
│   ├── sequence_generator.py
│   ├── sequence_scorer.py
│   ├── tasks
│   │   ├── __init__.py
│   │   ├── audio_pretraining.py
│   │   ├── cross_lingual_lm.py
│   │   ├── fairseq_task.py
│   │   ├── language_modeling.py
│   │   ├── legacy_masked_lm.py
│   │   ├── masked_lm.py
│   │   ├── multilingual_masked_lm.py
│   │   ├── multilingual_translation.py
│   │   ├── semisupervised_translation.py
│   │   ├── sentence_prediction.py
│   │   ├── sentence_ranking.py
│   │   ├── translation.py
│   │   ├── translation_from_pretrained_xlm.py
│   │   ├── translation_lev.py
│   │   └── translation_moe.py
│   ├── tokenizer.py
│   ├── trainer.py
│   └── utils.py
├── fairseq_cli
│   ├── __init__.py
│   ├── eval_lm.py
│   ├── generate.py
│   ├── interactive.py
│   ├── preprocess.py
│   ├── score.py
│   ├── setup.py
│   └── train.py
├── generate.py
├── hubconf.py
├── importance.py
├── importanceModel.sh
├── importance_mask.py
├── interactive.py
├── multihead_attention.py
├── preprocess.py
├── run.sh
├── score.py
├── scripts
│   ├── __init__.py
│   ├── average_checkpoints.py
│   ├── build_sym_alignment.py
│   ├── compare_namespaces.py
│   ├── compound_split_bleu.sh
│   ├── convert_dictionary.lua
│   ├── convert_model.lua
│   ├── count_docs.py
│   ├── read_binarized.py
│   ├── rm_pt.py
│   ├── sacrebleu_pregen.sh
│   ├── shard_docs.py
│   ├── split_train_valid_docs.py
│   ├── spm_decode.py
│   ├── spm_encode.py
│   ├── spm_train.py
│   ├── wav2vec_featurize.py
│   └── wav2vec_manifest.py
├── setup.py
├── test.sh
├── tests
│   ├── __init__.py
│   ├── speech_recognition
│   │   ├── __init__.py
│   │   ├── asr_test_base.py
│   │   ├── test_collaters.py
│   │   ├── test_cross_entropy.py
│   │   └── test_vggtransformer.py
│   ├── test_average_checkpoints.py
│   ├── test_backtranslation_dataset.py
│   ├── test_binaries.py
│   ├── test_bmuf.py
│   ├── test_character_token_embedder.py
│   ├── test_concat_dataset.py
│   ├── test_convtbc.py
│   ├── test_dictionary.py
│   ├── test_iterators.py
│   ├── test_label_smoothing.py
│   ├── test_memory_efficient_fp16.py
│   ├── test_multi_corpus_sampled_dataset.py
│   ├── test_noising.py
│   ├── test_reproducibility.py
│   ├── test_resampling_dataset.py
│   ├── test_sequence_generator.py
│   ├── test_sequence_scorer.py
│   ├── test_sparse_multihead_attention.py
│   ├── test_token_block_dataset.py
│   ├── test_train.py
│   ├── test_utils.py
│   └── utils.py
├── train.py
└── validate.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # JetBrains PyCharm IDE
2 | .idea/
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # macOS dir files
13 | .DS_Store
14 |
15 | # Distribution / packaging
16 | .Python
17 | env/
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | lib/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 |
34 | # Checkpoints
35 | checkpoints
36 |
37 | # PyInstaller
38 | # Usually these files are written by a python script from a template
39 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
40 | *.manifest
41 | *.spec
42 |
43 | # Installer logs
44 | pip-log.txt
45 | pip-delete-this-directory.txt
46 |
47 | # Unit test / coverage reports
48 | htmlcov/
49 | .tox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | .hypothesis/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # pyenv
83 | .python-version
84 |
85 | # celery beat schedule file
86 | celerybeat-schedule
87 |
88 | # SageMath parsed files
89 | *.sage.py
90 |
91 | # dotenv
92 | .env
93 |
94 | # virtualenv
95 | .venv
96 | venv/
97 | ENV/
98 |
99 | # Spyder project settings
100 | .spyderproject
101 | .spyproject
102 |
103 | # Rope project settings
104 | .ropeproject
105 |
106 | # mkdocs documentation
107 | /site
108 |
109 | # mypy
110 | .mypy_cache/
111 |
112 | # Generated files
113 | /fairseq/temporal_convolution_tbc
114 | /fairseq/modules/*_layer/*_forward.cu
115 | /fairseq/modules/*_layer/*_backward.cu
116 |
117 | # data
118 | data-bin/
119 |
120 | # reranking
121 | /examples/reranking/rerank_data
122 |
123 | # Cython-generated C++ source files
124 | /fairseq/data/data_utils_fast.cpp
125 | /fairseq/data/token_block_utils_fast.cpp
126 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Facebook, Inc. and its affiliates.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = python -msphinx
7 | SPHINXPROJ = fairseq
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/docs/_static/theme_overrides.css:
--------------------------------------------------------------------------------
1 | .wy-table-responsive table td kbd {
2 | white-space: nowrap;
3 | }
4 | .wy-table-responsive table td {
5 | white-space: normal !important;
6 | }
7 | .wy-table-responsive {
8 | overflow: visible !important;
9 | }
10 |
--------------------------------------------------------------------------------
/docs/command_line_tools.rst:
--------------------------------------------------------------------------------
1 | .. _Command-line Tools:
2 |
3 | Command-line Tools
4 | ==================
5 |
6 | Fairseq provides several command-line tools for training and evaluating models:
7 |
8 | - :ref:`fairseq-preprocess`: Data pre-processing: build vocabularies and binarize training data
9 | - :ref:`fairseq-train`: Train a new model on one or multiple GPUs
10 | - :ref:`fairseq-generate`: Translate pre-processed data with a trained model
11 | - :ref:`fairseq-interactive`: Translate raw text with a trained model
12 | - :ref:`fairseq-score`: BLEU scoring of generated translations against reference translations
13 | - :ref:`fairseq-eval-lm`: Language model evaluation
14 |
15 |
16 | .. _fairseq-preprocess:
17 |
18 | fairseq-preprocess
19 | ~~~~~~~~~~~~~~~~~~
20 | .. automodule:: preprocess
21 |
22 | .. argparse::
23 | :module: fairseq.options
24 | :func: get_preprocessing_parser
25 | :prog: fairseq-preprocess
26 |
27 |
28 | .. _fairseq-train:
29 |
30 | fairseq-train
31 | ~~~~~~~~~~~~~
32 | .. automodule:: train
33 |
34 | .. argparse::
35 | :module: fairseq.options
36 | :func: get_training_parser
37 | :prog: fairseq-train
38 |
39 |
40 | .. _fairseq-generate:
41 |
42 | fairseq-generate
43 | ~~~~~~~~~~~~~~~~
44 | .. automodule:: generate
45 |
46 | .. argparse::
47 | :module: fairseq.options
48 | :func: get_generation_parser
49 | :prog: fairseq-generate
50 |
51 |
52 | .. _fairseq-interactive:
53 |
54 | fairseq-interactive
55 | ~~~~~~~~~~~~~~~~~~~
56 | .. automodule:: interactive
57 |
58 | .. argparse::
59 | :module: fairseq.options
60 | :func: get_interactive_generation_parser
61 | :prog: fairseq-interactive
62 |
63 |
64 | .. _fairseq-score:
65 |
66 | fairseq-score
67 | ~~~~~~~~~~~~~
68 | .. automodule:: score
69 |
70 | .. argparse::
71 | :module: fairseq_cli.score
72 | :func: get_parser
73 | :prog: fairseq-score
74 |
75 |
76 | .. _fairseq-eval-lm:
77 |
78 | fairseq-eval-lm
79 | ~~~~~~~~~~~~~~~
80 | .. automodule:: eval_lm
81 |
82 | .. argparse::
83 | :module: fairseq.options
84 | :func: get_eval_lm_parser
85 | :prog: fairseq-eval-lm
86 |
--------------------------------------------------------------------------------
/docs/criterions.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | .. _Criterions:
5 |
6 | Criterions
7 | ==========
8 |
9 | Criterions compute the loss function given the model and batch, roughly::
10 |
11 | loss = criterion(model, batch)
12 |
13 | .. automodule:: fairseq.criterions
14 | :members:
15 |
16 | .. autoclass:: fairseq.criterions.FairseqCriterion
17 | :members:
18 | :undoc-members:
19 |
20 | .. autoclass:: fairseq.criterions.adaptive_loss.AdaptiveLoss
21 | :members:
22 | :undoc-members:
23 | .. autoclass:: fairseq.criterions.composite_loss.CompositeLoss
24 | :members:
25 | :undoc-members:
26 | .. autoclass:: fairseq.criterions.cross_entropy.CrossEntropyCriterion
27 | :members:
28 | :undoc-members:
29 | .. autoclass:: fairseq.criterions.label_smoothed_cross_entropy.LabelSmoothedCrossEntropyCriterion
30 | :members:
31 | :undoc-members:
32 |
--------------------------------------------------------------------------------
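The registry above is how new losses are added in practice. As a hedged sketch (the ``my_cross_entropy`` name is hypothetical; the hooks follow ``FairseqCriterion`` and the criterions registry shown later in ``fairseq/criterions/__init__.py``), a custom criterion looks roughly like::

    import math

    import torch.nn.functional as F

    from fairseq import utils
    from fairseq.criterions import FairseqCriterion, register_criterion


    @register_criterion('my_cross_entropy')
    class MyCrossEntropyCriterion(FairseqCriterion):

        def forward(self, model, sample, reduce=True):
            # run the model and compute a summed token-level NLL loss
            net_output = model(**sample['net_input'])
            lprobs = model.get_normalized_probs(net_output, log_probs=True)
            lprobs = lprobs.view(-1, lprobs.size(-1))
            target = model.get_targets(sample, net_output).view(-1)
            loss = F.nll_loss(lprobs, target, ignore_index=self.padding_idx,
                              reduction='sum' if reduce else 'none')
            sample_size = sample['ntokens']
            logging_output = {
                'loss': utils.item(loss.data) if reduce else loss.data,
                'ntokens': sample['ntokens'],
                'sample_size': sample_size,
            }
            return loss, sample_size, logging_output

        @staticmethod
        def aggregate_logging_outputs(logging_outputs):
            # report a per-token loss in base 2 across data-parallel workers
            loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
            sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
            return {
                'loss': loss_sum / sample_size / math.log(2),
                'sample_size': sample_size,
            }

Once imported, it can be selected with ``--criterion my_cross_entropy``.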
/docs/data.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | .. module:: fairseq.data
5 |
6 | Data Loading and Utilities
7 | ==========================
8 |
9 | .. _datasets:
10 |
11 | Datasets
12 | --------
13 |
14 | **Datasets** define the data format and provide helpers for creating
15 | mini-batches.
16 |
17 | .. autoclass:: fairseq.data.FairseqDataset
18 | :members:
19 | .. autoclass:: fairseq.data.LanguagePairDataset
20 | :members:
21 | .. autoclass:: fairseq.data.MonolingualDataset
22 | :members:
23 |
24 | **Helper Datasets**
25 |
26 | These datasets wrap other :class:`fairseq.data.FairseqDataset` instances and
27 | provide additional functionality:
28 |
29 | .. autoclass:: fairseq.data.BacktranslationDataset
30 | :members:
31 | .. autoclass:: fairseq.data.ConcatDataset
32 | :members:
33 | .. autoclass:: fairseq.data.ResamplingDataset
34 | :members:
35 | .. autoclass:: fairseq.data.RoundRobinZipDatasets
36 | :members:
37 | .. autoclass:: fairseq.data.TransformEosDataset
38 | :members:
39 |
40 |
41 | Dictionary
42 | ----------
43 |
44 | .. autoclass:: fairseq.data.Dictionary
45 | :members:
46 |
47 |
48 | Iterators
49 | ---------
50 |
51 | .. autoclass:: fairseq.data.CountingIterator
52 | :members:
53 | .. autoclass:: fairseq.data.EpochBatchIterator
54 | :members:
55 | .. autoclass:: fairseq.data.GroupedIterator
56 | :members:
57 | .. autoclass:: fairseq.data.ShardedIterator
58 | :members:
59 |
--------------------------------------------------------------------------------
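Helper datasets of this kind are usually built on ``BaseWrapperDataset`` (see ``fairseq/data/base_wrapper_dataset.py`` later in this listing), which forwards sizing, collation and prefetching to the wrapped dataset. A sketch for illustration only, with a made-up ``ReverseTokensDataset``::

    from fairseq.data import BaseWrapperDataset


    class ReverseTokensDataset(BaseWrapperDataset):
        """Return each item of the wrapped dataset with its tokens reversed."""

        def __getitem__(self, index):
            item = self.dataset[index]
            return item.flip(0)  # reverse the 1-D tensor of token ids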
/docs/docutils.conf:
--------------------------------------------------------------------------------
1 | [writers]
2 | option-limit=0
3 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. fairseq documentation master file, created by
2 | sphinx-quickstart on Fri Aug 17 21:45:30 2018.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | :github_url: https://github.com/pytorch/fairseq
7 |
8 |
9 | fairseq documentation
10 | =====================
11 |
12 | Fairseq is a sequence modeling toolkit written in `PyTorch
13 | <https://pytorch.org/>`_ that allows researchers and developers to
14 | train custom models for translation, summarization, language modeling and other
15 | text generation tasks.
16 |
17 | .. toctree::
18 | :maxdepth: 1
19 | :caption: Getting Started
20 |
21 | getting_started
22 | command_line_tools
23 |
24 | .. toctree::
25 | :maxdepth: 1
26 | :caption: Extending Fairseq
27 |
28 | overview
29 | tutorial_simple_lstm
30 | tutorial_classifying_names
31 |
32 | .. toctree::
33 | :maxdepth: 2
34 | :caption: Library Reference
35 |
36 | tasks
37 | models
38 | criterions
39 | optim
40 | lr_scheduler
41 | data
42 | modules
43 |
44 |
45 | Indices and tables
46 | ==================
47 |
48 | * :ref:`genindex`
49 | * :ref:`search`
50 |
--------------------------------------------------------------------------------
/docs/lr_scheduler.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | .. _Learning Rate Schedulers:
5 |
6 | Learning Rate Schedulers
7 | ========================
8 |
9 | Learning Rate Schedulers update the learning rate over the course of training.
10 | Learning rates can be updated after each update via :func:`step_update` or at
11 | epoch boundaries via :func:`step`.
12 |
13 | .. automodule:: fairseq.optim.lr_scheduler
14 | :members:
15 |
16 | .. autoclass:: fairseq.optim.lr_scheduler.FairseqLRScheduler
17 | :members:
18 | :undoc-members:
19 |
20 | .. autoclass:: fairseq.optim.lr_scheduler.cosine_lr_scheduler.CosineSchedule
21 | :members:
22 | :undoc-members:
23 | .. autoclass:: fairseq.optim.lr_scheduler.fixed_schedule.FixedSchedule
24 | :members:
25 | :undoc-members:
26 | .. autoclass:: fairseq.optim.lr_scheduler.inverse_square_root_schedule.InverseSquareRootSchedule
27 | :members:
28 | :undoc-members:
29 | .. autoclass:: fairseq.optim.lr_scheduler.reduce_lr_on_plateau.ReduceLROnPlateau
30 | :members:
31 | :undoc-members:
32 | .. autoclass:: fairseq.optim.lr_scheduler.triangular_lr_scheduler.TriangularSchedule
33 | :members:
34 | :undoc-members:
35 |
--------------------------------------------------------------------------------
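A hedged sketch of a custom scheduler (the ``my_linear_decay`` name and decay rule are made up; the ``step``/``step_update`` hooks and the ``register_lr_scheduler`` decorator follow the classes documented above)::

    from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


    @register_lr_scheduler('my_linear_decay')
    class MyLinearDecaySchedule(FairseqLRScheduler):

        def __init__(self, args, optimizer):
            super().__init__(args, optimizer)
            self.base_lr = args.lr[0]
            self.total_updates = getattr(args, 'max_update', 100000) or 100000

        def step_update(self, num_updates):
            # called after every optimizer update
            lr = self.base_lr * max(0.0, 1.0 - num_updates / self.total_updates)
            self.optimizer.set_lr(lr)
            return lr

        def step(self, epoch, val_loss=None):
            # called at the end of every epoch; decay here is per-update only
            super().step(epoch, val_loss)
            return self.optimizer.get_lr()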
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=python -msphinx
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 | set SPHINXPROJ=fairseq
13 |
14 | if "%1" == "" goto help
15 |
16 | %SPHINXBUILD% >NUL 2>NUL
17 | if errorlevel 9009 (
18 | echo.
19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed,
20 | echo.then set the SPHINXBUILD environment variable to point to the full
21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the
22 | echo.Sphinx directory to PATH.
23 | echo.
24 | echo.If you don't have Sphinx installed, grab it from
25 | echo.http://sphinx-doc.org/
26 | exit /b 1
27 | )
28 |
29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
30 | goto end
31 |
32 | :help
33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
34 |
35 | :end
36 | popd
37 |
--------------------------------------------------------------------------------
/docs/models.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | .. module:: fairseq.models
5 |
6 | .. _Models:
7 |
8 | Models
9 | ======
10 |
11 | A Model defines the neural network's ``forward()`` method and encapsulates all
12 | of the learnable parameters in the network. Each model also provides a set of
13 | named *architectures* that define the precise network configuration (e.g.,
14 | embedding dimension, number of layers, etc.).
15 |
16 | Both the model type and architecture are selected via the ``--arch``
17 | command-line argument. Once selected, a model may expose additional command-line
18 | arguments for further configuration.
19 |
20 | .. note::
21 |
22 | All fairseq Models extend :class:`BaseFairseqModel`, which in turn extends
23 | :class:`torch.nn.Module`. Thus any fairseq Model can be used as a
24 | stand-alone Module in other PyTorch code.
25 |
26 |
27 | Convolutional Neural Networks (CNN)
28 | -----------------------------------
29 |
30 | .. module:: fairseq.models.fconv
31 | .. autoclass:: fairseq.models.fconv.FConvModel
32 | :members:
33 | .. autoclass:: fairseq.models.fconv.FConvEncoder
34 | :members:
35 | :undoc-members:
36 | .. autoclass:: fairseq.models.fconv.FConvDecoder
37 | :members:
38 |
39 |
40 | Long Short-Term Memory (LSTM) networks
41 | --------------------------------------
42 |
43 | .. module:: fairseq.models.lstm
44 | .. autoclass:: fairseq.models.lstm.LSTMModel
45 | :members:
46 | .. autoclass:: fairseq.models.lstm.LSTMEncoder
47 | :members:
48 | .. autoclass:: fairseq.models.lstm.LSTMDecoder
49 | :members:
50 |
51 |
52 | Transformer (self-attention) networks
53 | -------------------------------------
54 |
55 | .. module:: fairseq.models.transformer
56 | .. autoclass:: fairseq.models.transformer.TransformerModel
57 | :members:
58 | .. autoclass:: fairseq.models.transformer.TransformerEncoder
59 | :members:
60 | .. autoclass:: fairseq.models.transformer.TransformerEncoderLayer
61 | :members:
62 | .. autoclass:: fairseq.models.transformer.TransformerDecoder
63 | :members:
64 | .. autoclass:: fairseq.models.transformer.TransformerDecoderLayer
65 | :members:
66 |
67 |
68 | Adding new models
69 | -----------------
70 |
71 | .. currentmodule:: fairseq.models
72 | .. autofunction:: fairseq.models.register_model
73 | .. autofunction:: fairseq.models.register_model_architecture
74 | .. autoclass:: fairseq.models.BaseFairseqModel
75 | :members:
76 | :undoc-members:
77 | .. autoclass:: fairseq.models.FairseqEncoderDecoderModel
78 | :members:
79 | :undoc-members:
80 | .. autoclass:: fairseq.models.FairseqEncoderModel
81 | :members:
82 | :undoc-members:
83 | .. autoclass:: fairseq.models.FairseqLanguageModel
84 | :members:
85 | :undoc-members:
86 | .. autoclass:: fairseq.models.FairseqMultiModel
87 | :members:
88 | :undoc-members:
89 | .. autoclass:: fairseq.models.FairseqEncoder
90 | :members:
91 | .. autoclass:: fairseq.models.CompositeEncoder
92 | :members:
93 | .. autoclass:: fairseq.models.FairseqDecoder
94 | :members:
95 |
96 |
97 | .. _Incremental decoding:
98 |
99 | Incremental decoding
100 | --------------------
101 |
102 | .. autoclass:: fairseq.models.FairseqIncrementalDecoder
103 | :members:
104 | :undoc-members:
105 |
--------------------------------------------------------------------------------
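Named architectures are plain functions that fill in defaults for any ``args`` left unset. A hedged sketch (the ``transformer_tiny_example`` name and sizes are made up, while ``base_architecture`` is the existing ``transformer`` default)::

    from fairseq.models import register_model_architecture
    from fairseq.models.transformer import base_architecture


    @register_model_architecture('transformer', 'transformer_tiny_example')
    def transformer_tiny_example(args):
        # only set values the user did not pass; everything else falls back
        # to the regular transformer defaults via base_architecture()
        args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 64)
        args.encoder_layers = getattr(args, 'encoder_layers', 2)
        args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 64)
        args.decoder_layers = getattr(args, 'decoder_layers', 2)
        base_architecture(args)

The new name then becomes selectable with ``--arch transformer_tiny_example``.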
/docs/modules.rst:
--------------------------------------------------------------------------------
1 | Modules
2 | =======
3 |
4 | Fairseq provides several stand-alone :class:`torch.nn.Module` classes that may
5 | be helpful when implementing a new :class:`~fairseq.models.BaseFairseqModel`.
6 |
7 | .. automodule:: fairseq.modules
8 | :members:
9 | :undoc-members:
10 |
--------------------------------------------------------------------------------
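For example, the attention module can be used on its own (a hedged sketch; constructor keyword arguments vary slightly across fairseq versions)::

    import torch

    from fairseq.modules import MultiheadAttention

    # self-attention over a (seq_len, batch, embed_dim) tensor
    attn = MultiheadAttention(embed_dim=512, num_heads=8, dropout=0.1,
                              self_attention=True)
    x = torch.randn(10, 2, 512)
    out, attn_weights = attn(query=x, key=x, value=x)
    print(out.shape)  # torch.Size([10, 2, 512])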
/docs/optim.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | .. _optimizers:
5 |
6 | Optimizers
7 | ==========
8 |
9 | Optimizers update the Model parameters based on the gradients.
10 |
11 | .. automodule:: fairseq.optim
12 | :members:
13 |
14 | .. autoclass:: fairseq.optim.FairseqOptimizer
15 | :members:
16 | :undoc-members:
17 |
18 | .. autoclass:: fairseq.optim.adadelta.Adadelta
19 | :members:
20 | :undoc-members:
21 | .. autoclass:: fairseq.optim.adagrad.Adagrad
22 | :members:
23 | :undoc-members:
24 | .. autoclass:: fairseq.optim.adafactor.FairseqAdafactor
25 | :members:
26 | :undoc-members:
27 | .. autoclass:: fairseq.optim.adam.FairseqAdam
28 | :members:
29 | :undoc-members:
30 | .. autoclass:: fairseq.optim.fp16_optimizer.FP16Optimizer
31 | :members:
32 | :undoc-members:
33 | .. autoclass:: fairseq.optim.nag.FairseqNAG
34 | :members:
35 | :undoc-members:
36 | .. autoclass:: fairseq.optim.sgd.SGD
37 | :members:
38 | :undoc-members:
39 |
--------------------------------------------------------------------------------
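New optimizers wrap a ``torch.optim`` optimizer and describe their hyper-parameters through ``optimizer_config``. A hedged sketch mirroring the pattern of ``fairseq/optim/sgd.py`` (the ``my_sgd`` name is hypothetical)::

    import torch.optim

    from fairseq.optim import FairseqOptimizer, register_optimizer


    @register_optimizer('my_sgd')
    class MySGD(FairseqOptimizer):

        def __init__(self, args, params):
            super().__init__(args, params)
            self._optimizer = torch.optim.SGD(params, **self.optimizer_config)

        @staticmethod
        def add_args(parser):
            parser.add_argument('--momentum', default=0.0, type=float, metavar='M',
                                help='momentum factor')

        @property
        def optimizer_config(self):
            # maps parsed command-line args onto torch.optim.SGD kwargs
            return {
                'lr': self.args.lr[0],
                'momentum': self.args.momentum,
            }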
/docs/overview.rst:
--------------------------------------------------------------------------------
1 | Overview
2 | ========
3 |
4 | Fairseq can be extended through user-supplied `plug-ins
5 | <https://en.wikipedia.org/wiki/Plug-in_(computing)>`_. We support five kinds of
6 | plug-ins:
7 |
8 | - :ref:`Models` define the neural network architecture and encapsulate all of the
9 | learnable parameters.
10 | - :ref:`Criterions` compute the loss function given the model outputs and targets.
11 | - :ref:`Tasks` store dictionaries and provide helpers for loading/iterating over
12 | Datasets, initializing the Model/Criterion and calculating the loss.
13 | - :ref:`Optimizers` update the Model parameters based on the gradients.
14 | - :ref:`Learning Rate Schedulers` update the learning rate over the course of
15 | training.
16 |
17 | **Training Flow**
18 |
19 | Given a ``model``, ``criterion``, ``task``, ``optimizer`` and ``lr_scheduler``,
20 | fairseq implements the following high-level training flow::
21 |
22 | for epoch in range(num_epochs):
23 | itr = task.get_batch_iterator(task.dataset('train'))
24 | for num_updates, batch in enumerate(itr):
25 | task.train_step(batch, model, criterion, optimizer)
26 | average_and_clip_gradients()
27 | optimizer.step()
28 | lr_scheduler.step_update(num_updates)
29 | lr_scheduler.step(epoch)
30 |
31 | where the default implementation for ``task.train_step`` is roughly::
32 |
33 | def train_step(self, batch, model, criterion, optimizer):
34 | loss = criterion(model, batch)
35 | optimizer.backward(loss)
36 | return loss
37 |
38 | **Registering new plug-ins**
39 |
40 | New plug-ins are *registered* through a set of ``@register`` function
41 | decorators, for example::
42 |
43 | @register_model('my_lstm')
44 | class MyLSTM(FairseqEncoderDecoderModel):
45 | (...)
46 |
47 | Once registered, new plug-ins can be used with the existing :ref:`Command-line
48 | Tools`. See the Tutorial sections for more detailed walkthroughs of how to add
49 | new plug-ins.
50 |
51 | **Loading plug-ins from another directory**
52 |
53 | New plug-ins can be defined in a custom module stored on the user's system. To
54 | import the module and make the plug-in available to *fairseq*, the
55 | command line supports the ``--user-dir`` flag, which specifies a
56 | custom location for additional modules to load into *fairseq*.
57 |
58 | For example, assuming this directory tree::
59 |
60 | /home/user/my-module/
61 | └── __init__.py
62 |
63 | with ``__init__.py``::
64 |
65 | from fairseq.models import register_model_architecture
66 | from fairseq.models.transformer import transformer_vaswani_wmt_en_de_big
67 |
68 | @register_model_architecture('transformer', 'my_transformer')
69 | def transformer_mmt_big(args):
70 | transformer_vaswani_wmt_en_de_big(args)
71 |
72 | it is possible to invoke the :ref:`fairseq-train` script with the new architecture with::
73 |
74 | fairseq-train ... --user-dir /home/user/my-module -a my_transformer --task translation
75 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx<2.0
2 | sphinx-argparse
3 |
--------------------------------------------------------------------------------
/docs/tasks.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | .. module:: fairseq.tasks
5 |
6 | .. _Tasks:
7 |
8 | Tasks
9 | =====
10 |
11 | Tasks store dictionaries and provide helpers for loading/iterating over
12 | Datasets, initializing the Model/Criterion and calculating the loss.
13 |
14 | Tasks can be selected via the ``--task`` command-line argument. Once selected, a
15 | task may expose additional command-line arguments for further configuration.
16 |
17 | Example usage::
18 |
19 | # setup the task (e.g., load dictionaries)
20 | task = fairseq.tasks.setup_task(args)
21 |
22 | # build model and criterion
23 | model = task.build_model(args)
24 | criterion = task.build_criterion(args)
25 |
26 | # load datasets
27 | task.load_dataset('train')
28 | task.load_dataset('valid')
29 |
30 | # iterate over mini-batches of data
31 | batch_itr = task.get_batch_iterator(
32 | task.dataset('train'), max_tokens=4096,
33 | )
34 | for batch in batch_itr:
35 | # compute the loss
36 | loss, sample_size, logging_output = task.get_loss(
37 | model, criterion, batch,
38 | )
39 | loss.backward()
40 |
41 |
42 | Translation
43 | -----------
44 |
45 | .. autoclass:: fairseq.tasks.translation.TranslationTask
46 |
47 | .. _language modeling:
48 |
49 | Language Modeling
50 | -----------------
51 |
52 | .. autoclass:: fairseq.tasks.language_modeling.LanguageModelingTask
53 |
54 |
55 | Adding new tasks
56 | ----------------
57 |
58 | .. autofunction:: fairseq.tasks.register_task
59 | .. autoclass:: fairseq.tasks.FairseqTask
60 | :members:
61 | :undoc-members:
62 |
--------------------------------------------------------------------------------
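For the "Adding new tasks" section above, a hedged sketch of the registration pattern (the names and paths below are made up; real tasks also declare command-line arguments via ``add_args`` and build real datasets in ``load_dataset``)::

    from fairseq.data import Dictionary
    from fairseq.tasks import FairseqTask, register_task


    @register_task('my_translation')
    class MyTranslationTask(FairseqTask):

        @classmethod
        def setup_task(cls, args, **kwargs):
            # load the vocabularies once and hand them to the task instance
            src_dict = Dictionary.load('dict.src.txt')
            tgt_dict = Dictionary.load('dict.tgt.txt')
            return cls(args, src_dict, tgt_dict)

        def __init__(self, args, src_dict, tgt_dict):
            super().__init__(args)
            self.src_dict = src_dict
            self.tgt_dict = tgt_dict

        def load_dataset(self, split, **kwargs):
            # populate self.datasets[split] with a FairseqDataset
            raise NotImplementedError

        @property
        def source_dictionary(self):
            return self.src_dict

        @property
        def target_dictionary(self):
            return self.tgt_dict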
/fairseq/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | __all__ = ['pdb']
7 | __version__ = '0.8.0'
8 |
9 | import fairseq.criterions # noqa
10 | import fairseq.models # noqa
11 | import fairseq.modules # noqa
12 | import fairseq.optim # noqa
13 | import fairseq.optim.lr_scheduler # noqa
14 | import fairseq.pdb # noqa
15 | import fairseq.tasks # noqa
16 |
--------------------------------------------------------------------------------
/fairseq/binarizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from collections import Counter
7 | import os
8 |
9 | from fairseq.tokenizer import tokenize_line
10 |
11 |
12 | def safe_readline(f):
13 | pos = f.tell()
14 | while True:
15 | try:
16 | return f.readline()
17 | except UnicodeDecodeError:
18 | pos -= 1
19 | f.seek(pos) # search where this character begins
20 |
21 |
22 | class Binarizer:
23 |
24 | @staticmethod
25 | def binarize(filename, dict, consumer, tokenize=tokenize_line, append_eos=True, reverse_order=False,
26 | offset=0, end=-1):
27 | nseq, ntok = 0, 0
28 | replaced = Counter()
29 |
30 | def replaced_consumer(word, idx):
31 | if idx == dict.unk_index and word != dict.unk_word:
32 | replaced.update([word])
33 |
34 | with open(filename, 'r', encoding='utf-8') as f:
35 | f.seek(offset)
36 | # next(f) breaks f.tell(), hence readline() must be used
37 | line = safe_readline(f)
38 | while line:
39 | if end > 0 and f.tell() > end:
40 | break
41 | ids = dict.encode_line(
42 | line=line,
43 | line_tokenizer=tokenize,
44 | add_if_not_exist=False,
45 | consumer=replaced_consumer,
46 | append_eos=append_eos,
47 | reverse_order=reverse_order,
48 | )
49 | nseq += 1
50 | ntok += len(ids)
51 | consumer(ids)
52 | line = f.readline()
53 | return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': replaced}
54 |
55 | @staticmethod
56 | def binarize_alignments(filename, alignment_parser, consumer, offset=0, end=-1):
57 | nseq = 0
58 |
59 | with open(filename, 'r') as f:
60 | f.seek(offset)
61 | line = safe_readline(f)
62 | while line:
63 | if end > 0 and f.tell() > end:
64 | break
65 | ids = alignment_parser(line)
66 | nseq += 1
67 | consumer(ids)
68 | line = f.readline()
69 | return {'nseq': nseq}
70 |
71 | @staticmethod
72 | def find_offsets(filename, num_chunks):
73 | with open(filename, 'r', encoding='utf-8') as f:
74 | size = os.fstat(f.fileno()).st_size
75 | chunk_size = size // num_chunks
76 | offsets = [0 for _ in range(num_chunks + 1)]
77 | for i in range(1, num_chunks):
78 | f.seek(chunk_size * i)
79 | safe_readline(f)
80 | offsets[i] = f.tell()
81 | return offsets
82 |
--------------------------------------------------------------------------------
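``find_offsets`` splits a text file into byte ranges so several workers can call ``binarize`` on disjoint chunks; each call turns lines into tensors of token ids via ``dict.encode_line`` and passes them to ``consumer``. A possible single-worker driver (paths are illustrative)::

    from fairseq.binarizer import Binarizer
    from fairseq.data import Dictionary

    vocab = Dictionary.load('dict.txt')
    offsets = Binarizer.find_offsets('train.txt', num_chunks=1)

    ids = []
    stats = Binarizer.binarize(
        'train.txt', vocab, ids.append,
        offset=offsets[0], end=offsets[1],  # end == 0 means "read to EOF"
    )
    print(stats['nseq'], stats['ntok'], stats['nunk'])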
/fairseq/clib/libbleu/module.cpp:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright 2017-present, Facebook, Inc.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #include <Python.h>
10 |
11 |
12 | static PyMethodDef method_def[] = {
13 | {NULL, NULL, 0, NULL}
14 | };
15 |
16 | static struct PyModuleDef module_def = {
17 | PyModuleDef_HEAD_INIT,
18 | "libbleu", /* name of module */
19 | NULL, /* module documentation, may be NULL */
20 | -1, /* size of per-interpreter state of the module,
21 | or -1 if the module keeps state in global variables. */
22 | method_def
23 | };
24 |
25 |
26 | #if PY_MAJOR_VERSION == 2
27 | PyMODINIT_FUNC init_libbleu()
28 | #else
29 | PyMODINIT_FUNC PyInit_libbleu()
30 | #endif
31 | {
32 | PyObject *m = PyModule_Create(&module_def);
33 | if (!m) {
34 | return NULL;
35 | }
36 | return m;
37 | }
38 |
--------------------------------------------------------------------------------
/fairseq/criterions/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import importlib
7 | import os
8 |
9 | from fairseq import registry
10 | from fairseq.criterions.fairseq_criterion import FairseqCriterion
11 |
12 |
13 | build_criterion, register_criterion, CRITERION_REGISTRY = registry.setup_registry(
14 | '--criterion',
15 | base_class=FairseqCriterion,
16 | default='cross_entropy',
17 | )
18 |
19 |
20 | # automatically import any Python files in the criterions/ directory
21 | for file in os.listdir(os.path.dirname(__file__)):
22 | if file.endswith('.py') and not file.startswith('_'):
23 | module = file[:file.find('.py')]
24 | importlib.import_module('fairseq.criterions.' + module)
25 |
--------------------------------------------------------------------------------
/fairseq/criterions/binary_cross_entropy.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import math
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | from fairseq import utils
11 |
12 | from . import FairseqCriterion, register_criterion
13 |
14 |
15 | @register_criterion('binary_cross_entropy')
16 | class BinaryCrossEntropyCriterion(FairseqCriterion):
17 |
18 | def __init__(self, args, task):
19 | super().__init__(args, task)
20 |
21 | def forward(self, model, sample, reduce=True):
22 | """Compute the loss for the given sample.
23 |
24 | Returns a tuple with three elements:
25 | 1) the loss
26 | 2) the sample size, which is used as the denominator for the gradient
27 | 3) logging outputs to display while training
28 | """
29 | net_output = model(**sample['net_input'])
30 | logits = model.get_logits(net_output).float()
31 | target = model.get_targets(sample, net_output, expand_steps=False).float()
32 |
33 | if hasattr(model, 'get_target_weights'):
34 | weights = model.get_target_weights(target, net_output)
35 | if torch.is_tensor(weights):
36 | weights = weights.float()
37 | else:
38 | weights = 1.
39 |
40 | loss = F.binary_cross_entropy_with_logits(logits, target, reduce=False)
41 |
42 | loss = loss * weights
43 |
44 | if reduce:
45 | loss = loss.sum()
46 |
47 | sample_size = target.numel()
48 | logging_output = {
49 | 'loss': utils.item(loss.data) if reduce else loss.data,
50 | 'ntokens': sample_size,
51 | 'nsentences': logits.size(0),
52 | 'sample_size': sample_size,
53 | }
54 | return loss, sample_size, logging_output
55 |
56 | @staticmethod
57 | def aggregate_logging_outputs(logging_outputs):
58 | """Aggregate logging outputs from data parallel training."""
59 | loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
60 | ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
61 | nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
62 | sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
63 | agg_output = {
64 | 'loss': loss_sum / sample_size / math.log(2),
65 | 'ntokens': ntokens,
66 | 'nsentences': nsentences,
67 | 'sample_size': sample_size,
68 | }
69 | if sample_size != ntokens:
70 | agg_output['nll_loss'] = loss_sum / ntokens / math.log(2)
71 | return agg_output
72 |
--------------------------------------------------------------------------------
/fairseq/criterions/cross_entropy.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import math
7 | import torch.nn.functional as F
8 |
9 | from fairseq import utils
10 |
11 | from . import FairseqCriterion, register_criterion
12 |
13 |
14 | @register_criterion('cross_entropy')
15 | class CrossEntropyCriterion(FairseqCriterion):
16 |
17 | def __init__(self, args, task):
18 | super().__init__(args, task)
19 |
20 | def forward(self, model, sample, reduce=True):
21 | """Compute the loss for the given sample.
22 |
23 | Returns a tuple with three elements:
24 | 1) the loss
25 | 2) the sample size, which is used as the denominator for the gradient
26 | 3) logging outputs to display while training
27 | """
28 | net_output = model(**sample['net_input'])
29 | loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce)
30 | sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
31 | logging_output = {
32 | 'loss': utils.item(loss.data) if reduce else loss.data,
33 | 'nll_loss': utils.item(loss.data) if reduce else loss.data,
34 | 'ntokens': sample['ntokens'],
35 | 'nsentences': sample['target'].size(0),
36 | 'sample_size': sample_size,
37 | }
38 | return loss, sample_size, logging_output
39 |
40 | def compute_loss(self, model, net_output, sample, reduce=True):
41 | lprobs = model.get_normalized_probs(net_output, log_probs=True)
42 | lprobs = lprobs.view(-1, lprobs.size(-1))
43 | target = model.get_targets(sample, net_output).view(-1)
44 | loss = F.nll_loss(
45 | lprobs,
46 | target,
47 | ignore_index=self.padding_idx,
48 | reduction='sum' if reduce else 'none',
49 | )
50 | return loss, loss
51 |
52 | @staticmethod
53 | def aggregate_logging_outputs(logging_outputs):
54 | """Aggregate logging outputs from data parallel training."""
55 | loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
56 | ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
57 | nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
58 | sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
59 | agg_output = {
60 | 'loss': loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.,
61 | 'ntokens': ntokens,
62 | 'nsentences': nsentences,
63 | 'sample_size': sample_size,
64 | }
65 | if sample_size != ntokens:
66 | agg_output['nll_loss'] = loss_sum / ntokens / math.log(2)
67 | return agg_output
68 |
--------------------------------------------------------------------------------
/fairseq/criterions/fairseq_criterion.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from torch.nn.modules.loss import _Loss
7 |
8 |
9 | class FairseqCriterion(_Loss):
10 |
11 | def __init__(self, args, task):
12 | super().__init__()
13 | self.args = args
14 | self.task = task
15 | self.padding_idx = task.target_dictionary.pad() if task.target_dictionary is not None else -100
16 |
17 | @staticmethod
18 | def add_args(parser):
19 | """Add criterion-specific arguments to the parser."""
20 | pass
21 |
22 | @classmethod
23 | def build_criterion(cls, args, task):
24 | return cls(args, task)
25 |
26 | def forward(self, model, sample, reduce=True):
27 | """Compute the loss for the given sample.
28 |
29 | Returns a tuple with three elements:
30 | 1) the loss
31 | 2) the sample size, which is used as the denominator for the gradient
32 | 3) logging outputs to display while training
33 | """
34 | raise NotImplementedError
35 |
36 | @staticmethod
37 | def aggregate_logging_outputs(logging_outputs):
38 | """Aggregate logging outputs from data parallel training."""
39 | raise NotImplementedError
40 |
41 | @staticmethod
42 | def grad_denom(sample_sizes):
43 | """Compute the gradient denominator for a set of sample sizes."""
44 | return sum(sample_sizes)
45 |
--------------------------------------------------------------------------------
/fairseq/criterions/masked_lm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import math
7 |
8 | import torch
9 | import torch.nn.functional as F
10 |
11 | from fairseq import utils
12 |
13 | from . import FairseqCriterion, register_criterion
14 |
15 |
16 | @register_criterion('masked_lm')
17 | class MaskedLmLoss(FairseqCriterion):
18 | """
19 | Implementation for the loss used in masked language model (MLM) training.
20 | """
21 |
22 | def __init__(self, args, task):
23 | super().__init__(args, task)
24 |
25 | def forward(self, model, sample, reduce=True):
26 | """Compute the loss for the given sample.
27 | Returns a tuple with three elements:
28 | 1) the loss
29 | 2) the sample size, which is used as the denominator for the gradient
30 | 3) logging outputs to display while training
31 | """
32 | # compute MLM loss
33 | masked_tokens = sample['target'].ne(self.padding_idx)
34 | sample_size = masked_tokens.int().sum().item()
35 |
36 | # (Rare case) When all tokens are masked, the model produces an empty
37 | # tensor and raises a CUDA error.
38 | if sample_size == 0:
39 | masked_tokens = None
40 |
41 | logits = model(**sample['net_input'], masked_tokens=masked_tokens)[0]
42 | targets = model.get_targets(sample, [logits])
43 |
44 | if sample_size != 0:
45 | targets = targets[masked_tokens]
46 |
47 | loss = F.nll_loss(
48 | F.log_softmax(
49 | logits.view(-1, logits.size(-1)),
50 | dim=-1,
51 | dtype=torch.float32,
52 | ),
53 | targets.view(-1),
54 | reduction='sum',
55 | ignore_index=self.padding_idx,
56 | )
57 | logging_output = {
58 | 'loss': utils.item(loss.data) if reduce else loss.data,
59 | 'nll_loss': utils.item(loss.data) if reduce else loss.data,
60 | 'ntokens': sample['ntokens'],
61 | 'nsentences': sample['nsentences'],
62 | 'sample_size': sample_size,
63 | }
64 | return loss, sample_size, logging_output
65 |
66 | @staticmethod
67 | def aggregate_logging_outputs(logging_outputs):
68 | """Aggregate logging outputs from data parallel training."""
69 | loss = sum(log.get('loss', 0) for log in logging_outputs)
70 | ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
71 | nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
72 | sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
73 |
74 | agg_output = {
75 | 'loss': loss / sample_size / math.log(2),
76 | 'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / sample_size / math.log(2) if ntokens > 0 else 0.,
77 | 'ntokens': ntokens,
78 | 'nsentences': nsentences,
79 | 'sample_size': sample_size,
80 | }
81 | return agg_output
82 |
--------------------------------------------------------------------------------
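The key step in the criterion above is restricting the loss to masked positions. A small stand-alone illustration of that boolean indexing (tensors made up)::

    import torch

    pad = 1
    targets = torch.tensor([[5, 1, 7],
                            [1, 9, 1]])        # pad marks unmasked positions
    masked_tokens = targets.ne(pad)            # [[True, False, True], [False, True, False]]
    print(targets[masked_tokens])              # tensor([5, 7, 9])
    print(masked_tokens.int().sum().item())    # sample_size == 3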
/fairseq/data/audio/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/NA-MNMT/120843696d6d9ae24b7ba12fbaf85158512ee714/fairseq/data/audio/__init__.py
--------------------------------------------------------------------------------
/fairseq/data/base_wrapper_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from torch.utils.data.dataloader import default_collate
7 |
8 | from . import FairseqDataset
9 |
10 |
11 | class BaseWrapperDataset(FairseqDataset):
12 |
13 | def __init__(self, dataset):
14 | super().__init__()
15 | self.dataset = dataset
16 |
17 | def __getitem__(self, index):
18 | return self.dataset[index]
19 |
20 | def __len__(self):
21 | return len(self.dataset)
22 |
23 | def collater(self, samples):
24 | if hasattr(self.dataset, 'collater'):
25 | return self.dataset.collater(samples)
26 | else:
27 | return default_collate(samples)
28 |
29 | @property
30 | def sizes(self):
31 | return self.dataset.sizes
32 |
33 | def num_tokens(self, index):
34 | return self.dataset.num_tokens(index)
35 |
36 | def size(self, index):
37 | return self.dataset.size(index)
38 |
39 | def ordered_indices(self):
40 | return self.dataset.ordered_indices()
41 |
42 | @property
43 | def supports_prefetch(self):
44 | return getattr(self.dataset, 'supports_prefetch', False)
45 |
46 | def prefetch(self, indices):
47 | self.dataset.prefetch(indices)
48 |
49 | def set_epoch(self, epoch):
50 | super().set_epoch(epoch)
51 | if hasattr(self.dataset, 'set_epoch'):
52 | self.dataset.set_epoch(epoch)
53 |
--------------------------------------------------------------------------------
/fairseq/data/colorize_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 |
8 | from . import BaseWrapperDataset
9 |
10 |
11 | class ColorizeDataset(BaseWrapperDataset):
12 | """ Adds 'colors' property to net input that is obtained from the provided color getter for use by models """
13 | def __init__(self, dataset, color_getter):
14 | super().__init__(dataset)
15 | self.color_getter = color_getter
16 |
17 | def collater(self, samples):
18 | base_collate = super().collater(samples)
19 | if len(base_collate) > 0:
20 | base_collate["net_input"]["colors"] = torch.tensor(
21 | list(self.color_getter(self.dataset, s["id"]) for s in samples),
22 | dtype=torch.long,
23 | )
24 | return base_collate
25 |
--------------------------------------------------------------------------------
/fairseq/data/concat_sentences_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 |
8 | from . import FairseqDataset
9 |
10 |
11 | class ConcatSentencesDataset(FairseqDataset):
12 |
13 | def __init__(self, *datasets):
14 | super().__init__()
15 | self.datasets = datasets
16 | assert all(len(ds) == len(datasets[0]) for ds in datasets), \
17 | 'datasets must have the same length'
18 |
19 | def __getitem__(self, index):
20 | return torch.cat([ds[index] for ds in self.datasets])
21 |
22 | def __len__(self):
23 | return len(self.datasets[0])
24 |
25 | def collater(self, samples):
26 | return self.datasets[0].collater(samples)
27 |
28 | @property
29 | def sizes(self):
30 | return sum(ds.sizes for ds in self.datasets)
31 |
32 | def num_tokens(self, index):
33 | return sum(ds.num_tokens(index) for ds in self.datasets)
34 |
35 | def size(self, index):
36 | return sum(ds.size(index) for ds in self.datasets)
37 |
38 | def ordered_indices(self):
39 | return self.datasets[0].ordered_indices()
40 |
41 | @property
42 | def supports_prefetch(self):
43 | return any(
44 | getattr(ds, 'supports_prefetch', False) for ds in self.datasets
45 | )
46 |
47 | def prefetch(self, indices):
48 | for ds in self.datasets:
49 | if getattr(ds, 'supports_prefetch', False):
50 | ds.prefetch(indices)
51 |
--------------------------------------------------------------------------------
/fairseq/data/data_utils_fast.pyx:
--------------------------------------------------------------------------------
1 | # cython: language_level=3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 |
9 | cimport cython
10 | cimport numpy as np
11 |
12 | DTYPE = np.int64
13 | ctypedef np.int64_t DTYPE_t
14 |
15 |
16 | cdef _is_batch_full(list batch, long num_tokens, long max_tokens, long max_sentences):
17 | if len(batch) == 0:
18 | return 0
19 | if len(batch) == max_sentences:
20 | return 1
21 | if num_tokens > max_tokens:
22 | return 1
23 | return 0
24 |
25 |
26 | @cython.cdivision(True)
27 | cpdef list batch_by_size_fast(
28 | np.ndarray[DTYPE_t, ndim=1] indices,
29 | num_tokens_fn,
30 | long max_tokens,
31 | long max_sentences,
32 | int bsz_mult,
33 | ):
34 | cdef long sample_len = 0
35 | cdef list sample_lens = []
36 | cdef list batch = []
37 | cdef list batches = []
38 | cdef long mod_len
39 | cdef long i
40 | cdef long idx
41 | cdef long num_tokens
42 | cdef DTYPE_t[:] indices_view = indices
43 |
44 | for i in range(len(indices_view)):
45 | idx = indices_view[i]
46 | num_tokens = num_tokens_fn(idx)
47 | sample_lens.append(num_tokens)
48 | sample_len = max(sample_len, num_tokens)
49 |
50 | assert sample_len <= max_tokens, (
51 | "sentence at index {} of size {} exceeds max_tokens "
52 | "limit of {}!".format(idx, sample_len, max_tokens)
53 | )
54 | num_tokens = (len(batch) + 1) * sample_len
55 |
56 | if _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
57 | mod_len = max(
58 | bsz_mult * (len(batch) // bsz_mult),
59 | len(batch) % bsz_mult,
60 | )
61 | batches.append(batch[:mod_len])
62 | batch = batch[mod_len:]
63 | sample_lens = sample_lens[mod_len:]
64 | sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
65 | batch.append(idx)
66 | if len(batch) > 0:
67 | batches.append(batch)
68 | return batches
69 |
--------------------------------------------------------------------------------
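This compiled routine is normally reached through ``fairseq.data.data_utils.batch_by_size``; a hedged usage sketch (``dataset`` is assumed to be any ``FairseqDataset``)::

    from fairseq.data import data_utils

    indices = dataset.ordered_indices()
    batches = data_utils.batch_by_size(
        indices,
        dataset.num_tokens,              # callable: index -> token count
        max_tokens=4096,                 # cap on tokens per batch
        max_sentences=None,              # optional cap on sentences per batch
        required_batch_size_multiple=8,  # the bsz_mult argument above
    )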
/fairseq/data/encoders/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 |
7 | import importlib
8 | import os
9 |
10 | from fairseq import registry
11 |
12 |
13 | build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY = registry.setup_registry(
14 | '--tokenizer',
15 | default=None,
16 | )
17 |
18 |
19 | build_bpe, register_bpe, BPE_REGISTRY = registry.setup_registry(
20 | '--bpe',
21 | default=None,
22 | )
23 |
24 |
25 | # automatically import any Python files in the encoders/ directory
26 | for file in os.listdir(os.path.dirname(__file__)):
27 | if file.endswith('.py') and not file.startswith('_'):
28 | module = file[:file.find('.py')]
29 | importlib.import_module('fairseq.data.encoders.' + module)
30 |
--------------------------------------------------------------------------------
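These registries are consumed through ``build_tokenizer(args)`` and ``build_bpe(args)`` after argument parsing. A hedged sketch using the GPT-2 BPE registered below (``Namespace`` stands in for parsed fairseq args; the vocab files are downloaded on first use)::

    from argparse import Namespace

    from fairseq.data import encoders

    bpe = encoders.build_bpe(Namespace(bpe='gpt2'))
    ids = bpe.encode('Hello world!')   # space-joined GPT-2 BPE token ids
    print(bpe.decode(ids))             # 'Hello world!'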
/fairseq/data/encoders/fastbpe.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from fairseq import file_utils
7 | from fairseq.data.encoders import register_bpe
8 |
9 |
10 | @register_bpe('fastbpe')
11 | class fastBPE(object):
12 |
13 | @staticmethod
14 | def add_args(parser):
15 | # fmt: off
16 | parser.add_argument('--bpe-codes', type=str,
17 | help='path to fastBPE BPE')
18 | # fmt: on
19 |
20 | def __init__(self, args):
21 | if args.bpe_codes is None:
22 | raise ValueError('--bpe-codes is required for --bpe=fastbpe')
23 | codes = file_utils.cached_path(args.bpe_codes)
24 | try:
25 | import fastBPE
26 | self.bpe = fastBPE.fastBPE(codes)
27 | self.bpe_symbol = "@@ "
28 | except ImportError:
29 | raise ImportError('Please install fastBPE with: pip install fastBPE')
30 |
31 | def encode(self, x: str) -> str:
32 | return self.bpe.apply([x])[0]
33 |
34 | def decode(self, x: str) -> str:
35 | return (x + ' ').replace(self.bpe_symbol, '').rstrip()
36 |
--------------------------------------------------------------------------------
/fairseq/data/encoders/gpt2_bpe.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from fairseq import file_utils
7 | from fairseq.data.encoders import register_bpe
8 |
9 | from .gpt2_bpe_utils import get_encoder
10 |
11 |
12 | DEFAULT_ENCODER_JSON = 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
13 | DEFAULT_VOCAB_BPE = 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
14 |
15 |
16 | @register_bpe('gpt2')
17 | class GPT2BPE(object):
18 |
19 | @staticmethod
20 | def add_args(parser):
21 | # fmt: off
22 | parser.add_argument('--gpt2-encoder-json', type=str,
23 | default=DEFAULT_ENCODER_JSON,
24 | help='path to encoder.json')
25 | parser.add_argument('--gpt2-vocab-bpe', type=str,
26 | default=DEFAULT_VOCAB_BPE,
27 | help='path to vocab.bpe')
28 | # fmt: on
29 |
30 | def __init__(self, args):
31 | encoder_json = file_utils.cached_path(
32 | getattr(args, 'gpt2_encoder_json', DEFAULT_ENCODER_JSON)
33 | )
34 | vocab_bpe = file_utils.cached_path(
35 | getattr(args, 'gpt2_vocab_bpe', DEFAULT_VOCAB_BPE)
36 | )
37 | self.bpe = get_encoder(encoder_json, vocab_bpe)
38 |
39 | def encode(self, x: str) -> str:
40 | return ' '.join(map(str, self.bpe.encode(x)))
41 |
42 | def decode(self, x: str) -> str:
43 | return self.bpe.decode(map(int, x.split()))
44 |
45 | def is_beginning_of_word(self, x: str) -> bool:
46 | return self.decode(x).startswith(' ')
47 |
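Usage sketch (not part of the repository): `encode` returns a space-joined string of GPT-2 BPE token ids and `decode` inverts it. This assumes network access, since the default encoder.json/vocab.bpe files are fetched through `file_utils.cached_path` on first use.

    from argparse import Namespace
    from fairseq.data.encoders.gpt2_bpe import GPT2BPE

    bpe = GPT2BPE(Namespace())         # falls back to the DEFAULT_* URLs
    ids = bpe.encode('Hello world')    # e.g. '15496 995'
    assert bpe.decode(ids) == 'Hello world'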
--------------------------------------------------------------------------------
/fairseq/data/encoders/hf_bert_bpe.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from fairseq.data.encoders import register_bpe
7 |
8 |
9 | @register_bpe('bert')
10 | class BertBPE(object):
11 |
12 | @staticmethod
13 | def add_args(parser):
14 | # fmt: off
15 | parser.add_argument('--bpe-cased', action='store_true',
16 | help='set for cased BPE',
17 | default=False)
18 | parser.add_argument('--bpe-vocab-file', type=str,
19 | help='bpe vocab file.')
20 | # fmt: on
21 |
22 | def __init__(self, args):
23 | try:
24 | from pytorch_transformers import BertTokenizer
25 | from pytorch_transformers.tokenization_utils import clean_up_tokenization
26 | except ImportError:
27 | raise ImportError(
28 |                 'Please install 1.0.0 version of pytorch_transformers '
29 |                 'with: pip install pytorch-transformers'
30 | )
31 |
32 |         if getattr(args, 'bpe_vocab_file', None):
33 | self.bert_tokenizer = BertTokenizer(
34 | args.bpe_vocab_file,
35 | do_lower_case=not args.bpe_cased
36 | )
37 | else:
38 | vocab_file_name = 'bert-base-cased' if args.bpe_cased else 'bert-base-uncased'
39 | self.bert_tokenizer = BertTokenizer.from_pretrained(vocab_file_name)
40 | self.clean_up_tokenization = clean_up_tokenization
41 |
42 | def encode(self, x: str) -> str:
43 | return ' '.join(self.bert_tokenizer.tokenize(x))
44 |
45 | def decode(self, x: str) -> str:
46 | return self.clean_up_tokenization(
47 | self.bert_tokenizer.convert_tokens_to_string(x.split(' '))
48 | )
49 |
50 | def is_beginning_of_word(self, x: str) -> bool:
51 | return not x.startswith('##')
52 |
--------------------------------------------------------------------------------
/fairseq/data/encoders/moses_tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from fairseq.data.encoders import register_tokenizer
7 |
8 |
9 | @register_tokenizer('moses')
10 | class MosesTokenizer(object):
11 |
12 | @staticmethod
13 | def add_args(parser):
14 | # fmt: off
15 | parser.add_argument('--moses-source-lang', metavar='SRC',
16 | help='source language')
17 | parser.add_argument('--moses-target-lang', metavar='TARGET',
18 | help='target language')
19 | parser.add_argument('--moses-no-dash-splits', action='store_true', default=False,
20 | help='don\'t apply dash split rules')
21 | parser.add_argument('--moses-no-escape', action='store_true', default=False,
22 |                             help='don\'t perform HTML escaping on apostrophe, quotes, etc.')
23 | # fmt: on
24 |
25 | def __init__(self, args):
26 | self.args = args
27 |
28 | if getattr(args, 'moses_source_lang', None) is None:
29 | args.moses_source_lang = getattr(args, 'source_lang', 'en')
30 | if getattr(args, 'moses_target_lang', None) is None:
31 | args.moses_target_lang = getattr(args, 'target_lang', 'en')
32 |
33 | try:
34 | from sacremoses import MosesTokenizer, MosesDetokenizer
35 | self.tok = MosesTokenizer(args.moses_source_lang)
36 | self.detok = MosesDetokenizer(args.moses_target_lang)
37 | except ImportError:
38 | raise ImportError('Please install Moses tokenizer with: pip install sacremoses')
39 |
40 | def encode(self, x: str) -> str:
41 | return self.tok.tokenize(
42 | x,
43 | aggressive_dash_splits=(not self.args.moses_no_dash_splits),
44 | return_str=True,
45 | escape=(not self.args.moses_no_escape),
46 | )
47 |
48 | def decode(self, x: str) -> str:
49 | return self.detok.detokenize(x.split())
50 |
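Usage sketch (not part of the repository): a minimal round trip with sacremoses, assuming `pip install sacremoses`. When the Moses language flags are unset, the tokenizer falls back to `args.source_lang`/`args.target_lang` and finally to 'en'.

    from argparse import Namespace
    from fairseq.data.encoders.moses_tokenizer import MosesTokenizer

    args = Namespace(moses_source_lang=None, moses_target_lang=None,
                     moses_no_dash_splits=False, moses_no_escape=True)
    tok = MosesTokenizer(args)
    tokenized = tok.encode("Don't panic.")   # roughly "Don 't panic ."
    restored = tok.decode(tokenized)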
--------------------------------------------------------------------------------
/fairseq/data/encoders/nltk_tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from fairseq.data.encoders import register_tokenizer
7 |
8 |
9 | @register_tokenizer('nltk')
10 | class NLTKTokenizer(object):
11 |
12 | def __init__(self, source_lang=None, target_lang=None):
13 | try:
14 | from nltk.tokenize import word_tokenize
15 | self.word_tokenize = word_tokenize
16 | except ImportError:
17 | raise ImportError('Please install nltk with: pip install nltk')
18 |
19 | def encode(self, x: str) -> str:
20 | return ' '.join(self.word_tokenize(x))
21 |
22 | def decode(self, x: str) -> str:
23 | return x
24 |
--------------------------------------------------------------------------------
/fairseq/data/encoders/sentencepiece_bpe.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from fairseq import file_utils
7 | from fairseq.data.encoders import register_bpe
8 |
9 |
10 | @register_bpe('sentencepiece')
11 | class SentencepieceBPE(object):
12 |
13 | @staticmethod
14 | def add_args(parser):
15 | # fmt: off
16 | parser.add_argument('--sentencepiece-vocab', type=str,
17 | help='path to sentencepiece vocab')
18 | # fmt: on
19 |
20 | def __init__(self, args):
21 | vocab = file_utils.cached_path(args.sentencepiece_vocab)
22 | try:
23 | import sentencepiece as spm
24 | self.sp = spm.SentencePieceProcessor()
25 | self.sp.Load(vocab)
26 | except ImportError:
27 | raise ImportError('Please install sentencepiece with: pip install sentencepiece')
28 |
29 | def encode(self, x: str) -> str:
30 | return ' '.join(self.sp.EncodeAsPieces(x))
31 |
32 | def decode(self, x: str) -> str:
33 | return x.replace(' ', '').replace('\u2581', ' ').strip()
34 |
--------------------------------------------------------------------------------
/fairseq/data/encoders/space_tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import re
7 |
8 | from fairseq.data.encoders import register_tokenizer
9 |
10 |
11 | @register_tokenizer('space')
12 | class SpaceTokenizer(object):
13 |
14 | def __init__(self, source_lang=None, target_lang=None):
15 | self.space_tok = re.compile(r"\s+")
16 |
17 | def encode(self, x: str) -> str:
18 | return self.space_tok.sub(' ', x)
19 |
20 | def decode(self, x: str) -> str:
21 | return x
22 |
--------------------------------------------------------------------------------
/fairseq/data/encoders/subword_nmt_bpe.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from fairseq import file_utils
7 | from fairseq.data.encoders import register_bpe
8 |
9 |
10 | @register_bpe('subword_nmt')
11 | class SubwordNMTBPE(object):
12 |
13 | @staticmethod
14 | def add_args(parser):
15 | # fmt: off
16 | parser.add_argument('--bpe-codes', type=str,
17 | help='path to subword NMT BPE')
18 | parser.add_argument('--bpe-separator', default='@@',
19 | help='BPE separator')
20 | # fmt: on
21 |
22 | def __init__(self, args):
23 | if args.bpe_codes is None:
24 | raise ValueError('--bpe-codes is required for --bpe=subword_nmt')
25 | codes = file_utils.cached_path(args.bpe_codes)
26 | try:
27 | from subword_nmt import apply_bpe
28 | bpe_parser = apply_bpe.create_parser()
29 | bpe_args = bpe_parser.parse_args([
30 | '--codes', codes,
31 | '--separator', args.bpe_separator,
32 | ])
33 | self.bpe = apply_bpe.BPE(
34 | bpe_args.codes,
35 | bpe_args.merges,
36 | bpe_args.separator,
37 | None,
38 | bpe_args.glossaries,
39 | )
40 | self.bpe_symbol = bpe_args.separator + ' '
41 | except ImportError:
42 | raise ImportError('Please install subword_nmt with: pip install subword-nmt')
43 |
44 | def encode(self, x: str) -> str:
45 | return self.bpe.process_line(x)
46 |
47 | def decode(self, x: str) -> str:
48 | return (x + ' ').replace(self.bpe_symbol, '').rstrip()
49 |
--------------------------------------------------------------------------------
/fairseq/data/fairseq_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy as np
7 | import torch.utils.data
8 |
9 |
10 | class FairseqDataset(torch.utils.data.Dataset):
11 | """A dataset that provides helpers for batching."""
12 |
13 | def __getitem__(self, index):
14 | raise NotImplementedError
15 |
16 | def __len__(self):
17 | raise NotImplementedError
18 |
19 | def collater(self, samples):
20 | """Merge a list of samples to form a mini-batch.
21 |
22 | Args:
23 | samples (List[dict]): samples to collate
24 |
25 | Returns:
26 | dict: a mini-batch suitable for forwarding with a Model
27 | """
28 | raise NotImplementedError
29 |
30 | def num_tokens(self, index):
31 | """Return the number of tokens in a sample. This value is used to
32 | enforce ``--max-tokens`` during batching."""
33 | raise NotImplementedError
34 |
35 | def size(self, index):
36 | """Return an example's size as a float or tuple. This value is used when
37 | filtering a dataset with ``--max-positions``."""
38 | raise NotImplementedError
39 |
40 | def ordered_indices(self):
41 | """Return an ordered list of indices. Batches will be constructed based
42 | on this order."""
43 | return np.arange(len(self))
44 |
45 | @property
46 | def supports_prefetch(self):
47 | """Whether this dataset supports prefetching."""
48 | return False
49 |
50 | def attr(self, attr: str, index: int):
51 | return getattr(self, attr, None)
52 |
53 | def prefetch(self, indices):
54 | """Prefetch the data required for this epoch."""
55 | raise NotImplementedError
56 |
57 | def set_epoch(self, epoch):
58 | pass
59 |
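A minimal sketch of the interface (not part of the repository): a toy subclass that stores 1-D LongTensors and implements the methods the batching machinery relies on. Real datasets typically pad with `data_utils.collate_tokens` instead of the naive padding shown here.

    import torch
    from fairseq.data import FairseqDataset

    class ToyDataset(FairseqDataset):
        def __init__(self, items):
            self.items = [torch.LongTensor(x) for x in items]

        def __getitem__(self, index):
            return self.items[index]

        def __len__(self):
            return len(self.items)

        def collater(self, samples):
            # naive right-padding with 0
            max_len = max(s.numel() for s in samples)
            batch = samples[0].new_full((len(samples), max_len), 0)
            for i, s in enumerate(samples):
                batch[i, :s.numel()] = s
            return batch

        def num_tokens(self, index):
            return self.items[index].numel()

        def size(self, index):
            return self.items[index].numel()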
--------------------------------------------------------------------------------
/fairseq/data/id_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 |
8 | from . import FairseqDataset
9 |
10 |
11 | class IdDataset(FairseqDataset):
12 |
13 | def __getitem__(self, index):
14 | return index
15 |
16 | def __len__(self):
17 | return 0
18 |
19 | def collater(self, samples):
20 | return torch.tensor(samples)
21 |
--------------------------------------------------------------------------------
/fairseq/data/legacy/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .masked_lm_dictionary import BertDictionary, MaskedLMDictionary
7 | from .block_pair_dataset import BlockPairDataset
8 | from .masked_lm_dataset import MaskedLMDataset
9 |
10 | __all__ = [
11 | 'BertDictionary',
12 | 'BlockPairDataset',
13 | 'MaskedLMDataset',
14 | 'MaskedLMDictionary',
15 | ]
16 |
--------------------------------------------------------------------------------
/fairseq/data/legacy/masked_lm_dictionary.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from fairseq.data import Dictionary
7 |
8 |
9 | class MaskedLMDictionary(Dictionary):
10 | """
11 | Dictionary for Masked Language Modelling tasks. This extends Dictionary by
12 | adding the mask symbol.
13 | """
14 | def __init__(
15 | self,
16 |         pad='<pad>',
17 |         eos='</s>',
18 |         unk='<unk>',
19 |         mask='<mask>',
20 | ):
21 | super().__init__(pad, eos, unk)
22 | self.mask_word = mask
23 | self.mask_index = self.add_symbol(mask)
24 | self.nspecial = len(self.symbols)
25 |
26 | def mask(self):
27 | """Helper to get index of mask symbol"""
28 | return self.mask_index
29 |
30 |
31 | class BertDictionary(MaskedLMDictionary):
32 | """
33 | Dictionary for BERT task. This extends MaskedLMDictionary by adding support
34 | for cls and sep symbols.
35 | """
36 | def __init__(
37 | self,
38 |         pad='<pad>',
39 |         eos='</s>',
40 |         unk='<unk>',
41 |         mask='<mask>',
42 |         cls='<cls>',
43 |         sep='<sep>'
44 | ):
45 | super().__init__(pad, eos, unk, mask)
46 | self.cls_word = cls
47 | self.sep_word = sep
48 | self.cls_index = self.add_symbol(cls)
49 | self.sep_index = self.add_symbol(sep)
50 | self.nspecial = len(self.symbols)
51 |
52 | def cls(self):
53 | """Helper to get index of cls symbol"""
54 | return self.cls_index
55 |
56 | def sep(self):
57 | """Helper to get index of sep symbol"""
58 | return self.sep_index
59 |
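A small sketch (not part of the repository) showing how the extra special symbols are exposed; the exact index values depend on the base `Dictionary`, so they are not hard-coded here.

    from fairseq.data.legacy import BertDictionary

    d = BertDictionary()
    print(d.mask(), d.cls(), d.sep())   # indices of <mask>, <cls>, <sep>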
--------------------------------------------------------------------------------
/fairseq/data/list_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from . import BaseWrapperDataset
7 |
8 |
9 | class ListDataset(BaseWrapperDataset):
10 |
11 | def __init__(self, dataset, sizes=None):
12 | super().__init__(dataset)
13 | self._sizes = sizes
14 |
15 | def collater(self, samples):
16 | return samples
17 |
18 | @property
19 | def sizes(self):
20 | return self._sizes
21 |
22 | def num_tokens(self, index):
23 | return self.sizes[index]
24 |
25 | def size(self, index):
26 | return self.sizes[index]
27 |
28 | def set_epoch(self, epoch):
29 | pass
30 |
--------------------------------------------------------------------------------
/fairseq/data/lm_context_window_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy as np
7 | import torch
8 |
9 | from fairseq.data.monolingual_dataset import MonolingualDataset
10 |
11 | from . import FairseqDataset
12 |
13 |
14 | class LMContextWindowDataset(FairseqDataset):
15 | """Wraps a MonolingualDataset and provides more context for evaluation."""
16 |
17 | def __init__(self, dataset, tokens_per_sample, context_window, pad_idx):
18 | assert isinstance(dataset, MonolingualDataset)
19 | assert context_window > 0
20 | self.dataset = dataset
21 | self.tokens_per_sample = tokens_per_sample
22 | self.context_window = context_window
23 | self.pad_idx = pad_idx
24 | self.prev_tokens = np.empty([0])
25 |
26 | def __getitem__(self, index):
27 | return self.dataset[index]
28 |
29 | def __len__(self):
30 | return len(self.dataset)
31 |
32 | def collater(self, samples):
33 | sample = self.dataset.collater(samples)
34 |
35 | pad = self.pad_idx
36 | max_sample_len = self.tokens_per_sample + self.context_window
37 |
38 | bsz, tsz = sample['net_input']['src_tokens'].shape
39 | start_idxs = [0] * bsz
40 | toks = sample['net_input']['src_tokens']
41 | lengths = sample['net_input']['src_lengths']
42 | tgt = sample['target']
43 | new_toks = np.empty([bsz, tsz + self.context_window], dtype=np.int64)
44 | new_tgt = np.full([bsz, tsz + self.context_window], pad, dtype=np.int64)
45 | sample_lens = toks.ne(pad).long().sum(dim=1).cpu()
46 | for i in range(bsz):
47 | sample_len = sample_lens[i]
48 | extra = len(self.prev_tokens) + sample_len - max_sample_len
49 | if extra > 0:
50 | self.prev_tokens = self.prev_tokens[extra:]
51 | pads = np.full(self.context_window - len(self.prev_tokens), pad)
52 | new_toks[i] = np.concatenate([self.prev_tokens, toks[i].numpy(), pads])
53 | new_tgt[i, len(self.prev_tokens):len(self.prev_tokens) + len(tgt[i])] = tgt[i]
54 | start_idxs[i] = len(self.prev_tokens)
55 | lengths[i] += len(self.prev_tokens)
56 | self.prev_tokens = new_toks[i][new_toks[i] != pad][-self.context_window:]
57 | sample['net_input']['src_tokens'] = torch.from_numpy(new_toks)
58 | sample['target'] = torch.from_numpy(new_tgt)
59 | sample['start_indices'] = start_idxs
60 |
61 | return sample
62 |
63 | def num_tokens(self, index):
64 | return self.dataset.num_tokens(index)
65 |
66 | def size(self, index):
67 | return self.dataset.size(index)
68 |
69 | def ordered_indices(self):
70 | # NOTE we don't shuffle the data to retain access to the previous dataset elements
71 | return np.arange(len(self.dataset))
72 |
73 | @property
74 | def supports_prefetch(self):
75 | return getattr(self.dataset, 'supports_prefetch', False)
76 |
77 | def prefetch(self, indices):
78 | return self.dataset.prefetch(indices)
79 |
--------------------------------------------------------------------------------
/fairseq/data/lru_cache_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from functools import lru_cache
7 |
8 | from . import BaseWrapperDataset
9 |
10 |
11 | class LRUCacheDataset(BaseWrapperDataset):
12 |
13 | def __init__(self, dataset, token=None):
14 | super().__init__(dataset)
15 |
16 | @lru_cache(maxsize=8)
17 | def __getitem__(self, index):
18 | return self.dataset[index]
19 |
20 | @lru_cache(maxsize=8)
21 | def collater(self, samples):
22 | return self.dataset.collater(samples)
23 |
--------------------------------------------------------------------------------
/fairseq/data/num_samples_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from . import FairseqDataset
7 |
8 |
9 | class NumSamplesDataset(FairseqDataset):
10 |
11 | def __getitem__(self, index):
12 | return 1
13 |
14 | def __len__(self):
15 | return 0
16 |
17 | def collater(self, samples):
18 | return sum(samples)
19 |
--------------------------------------------------------------------------------
/fairseq/data/numel_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy as np
7 | import torch
8 |
9 | from . import BaseWrapperDataset
10 |
11 |
12 | class NumelDataset(BaseWrapperDataset):
13 |
14 | def __init__(self, dataset, reduce=False):
15 | super().__init__(dataset)
16 | self.reduce = reduce
17 |
18 | def __getitem__(self, index):
19 | item = self.dataset[index]
20 | if torch.is_tensor(item):
21 | return torch.numel(item)
22 | else:
23 | return np.size(item)
24 |
25 | def __len__(self):
26 | return len(self.dataset)
27 |
28 | def collater(self, samples):
29 | if self.reduce:
30 | return sum(samples)
31 | else:
32 | return torch.tensor(samples)
33 |
--------------------------------------------------------------------------------
/fairseq/data/offset_tokens_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from . import BaseWrapperDataset
7 |
8 |
9 | class OffsetTokensDataset(BaseWrapperDataset):
10 |
11 | def __init__(self, dataset, offset):
12 | super().__init__(dataset)
13 | self.offset = offset
14 |
15 | def __getitem__(self, idx):
16 | return self.dataset[idx] + self.offset
17 |
--------------------------------------------------------------------------------
/fairseq/data/pad_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from fairseq.data import data_utils
7 |
8 | from . import BaseWrapperDataset
9 |
10 |
11 | class PadDataset(BaseWrapperDataset):
12 |
13 | def __init__(self, dataset, pad_idx, left_pad):
14 | super().__init__(dataset)
15 | self.pad_idx = pad_idx
16 | self.left_pad = left_pad
17 |
18 | def collater(self, samples):
19 | return data_utils.collate_tokens(samples, self.pad_idx, left_pad=self.left_pad)
20 |
21 |
22 | class LeftPadDataset(PadDataset):
23 |
24 | def __init__(self, dataset, pad_idx):
25 | super().__init__(dataset, pad_idx, left_pad=True)
26 |
27 |
28 | class RightPadDataset(PadDataset):
29 |
30 | def __init__(self, dataset, pad_idx):
31 | super().__init__(dataset, pad_idx, left_pad=False)
32 |
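Usage sketch (not part of the repository): RightPadDataset delegates padding to `data_utils.collate_tokens`, so wrapping a list of variable-length 1-D tensors yields a rectangular batch.

    import torch
    from fairseq.data.list_dataset import ListDataset
    from fairseq.data.pad_dataset import RightPadDataset

    items = [torch.LongTensor([5, 6, 7]), torch.LongTensor([8, 9])]
    ds = RightPadDataset(ListDataset(items), pad_idx=1)
    batch = ds.collater([ds[0], ds[1]])   # shape (2, 3); second row ends with the pad index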
--------------------------------------------------------------------------------
/fairseq/data/plasma_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import subprocess
7 | import tempfile
8 |
9 |
10 | class PlasmaArray(object):
11 | """
12 | Wrapper around numpy arrays that automatically moves the data to shared
13 | memory upon serialization. This is particularly helpful when passing numpy
14 | arrays through multiprocessing, so that data is not unnecessarily
15 | duplicated or pickled.
16 | """
17 |
18 | def __init__(self, array):
19 | super().__init__()
20 | self.array = array
21 | self.disable = array.nbytes < 134217728 # disable for arrays <128MB
22 | self.object_id = None
23 | self.path = None
24 |
25 | # variables with underscores shouldn't be pickled
26 | self._client = None
27 | self._server = None
28 | self._server_tmp = None
29 | self._plasma = None
30 |
31 | @property
32 | def plasma(self):
33 | if self._plasma is None and not self.disable:
34 | try:
35 | import pyarrow.plasma as plasma
36 | self._plasma = plasma
37 | except ImportError:
38 | self._plasma = None
39 | return self._plasma
40 |
41 | def start_server(self):
42 | if self.plasma is None or self._server is not None:
43 | return
44 | assert self.object_id is None
45 | assert self.path is None
46 | self._server_tmp = tempfile.NamedTemporaryFile()
47 | self.path = self._server_tmp.name
48 | self._server = subprocess.Popen([
49 | 'plasma_store',
50 | '-m', str(int(1.05 * self.array.nbytes)),
51 | '-s', self.path,
52 | ])
53 |
54 | @property
55 | def client(self):
56 | if self._client is None:
57 | assert self.path is not None
58 | self._client = self.plasma.connect(self.path)
59 | return self._client
60 |
61 | def __getstate__(self):
62 | if self.plasma is None:
63 | return self.__dict__
64 | if self.object_id is None:
65 | self.start_server()
66 | self.object_id = self.client.put(self.array)
67 | state = self.__dict__.copy()
68 | del state['array']
69 | state['_client'] = None
70 | state['_server'] = None
71 | state['_server_tmp'] = None
72 | state['_plasma'] = None
73 | return state
74 |
75 | def __setstate__(self, state):
76 | self.__dict__.update(state)
77 | if self.plasma is None:
78 | return
79 | self.array = self.client.get(self.object_id)
80 |
81 | def __del__(self):
82 | if self._server is not None:
83 | self._server.kill()
84 | self._server = None
85 | self._server_tmp.close()
86 | self._server_tmp = None
87 |
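A small sketch (not part of the repository): arrays under 128MB bypass the plasma store entirely and pickle like ordinary numpy arrays; only larger arrays are pushed to shared memory, which requires pyarrow and the `plasma_store` binary.

    import pickle
    import numpy as np
    from fairseq.data.plasma_utils import PlasmaArray

    arr = PlasmaArray(np.zeros(10, dtype=np.int64))   # small, so plasma is disabled
    restored = pickle.loads(pickle.dumps(arr))
    assert (restored.array == arr.array).all()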
--------------------------------------------------------------------------------
/fairseq/data/prepend_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy as np
7 | import torch
8 |
9 | from . import BaseWrapperDataset
10 |
11 |
12 | class PrependDataset(BaseWrapperDataset):
13 | def __init__(self, dataset, prepend_getter, ensure_first_token_is=None):
14 | super().__init__(dataset)
15 | self.prepend_getter = prepend_getter
16 | self.ensure_first_token = ensure_first_token_is
17 |
18 | def __getitem__(self, idx):
19 | item = self.dataset[idx]
20 | is_tuple = isinstance(item, tuple)
21 | src = item[0] if is_tuple else item
22 |
23 | assert self.ensure_first_token is None or src[0] == self.ensure_first_token
24 | prepend_idx = self.prepend_getter(self.dataset, idx)
25 | assert isinstance(prepend_idx, int)
26 | src[0] = prepend_idx
27 | item = tuple((src,) + item[1:]) if is_tuple else src
28 | return item
29 |
--------------------------------------------------------------------------------
/fairseq/data/prepend_token_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy as np
7 | import torch
8 |
9 | from . import BaseWrapperDataset
10 |
11 |
12 | class PrependTokenDataset(BaseWrapperDataset):
13 |
14 | def __init__(self, dataset, token=None):
15 | super().__init__(dataset)
16 | self.token = token
17 | if token is not None:
18 | self._sizes = np.array(dataset.sizes) + 1
19 | else:
20 | self._sizes = dataset.sizes
21 |
22 | def __getitem__(self, idx):
23 | item = self.dataset[idx]
24 | if self.token is not None:
25 | item = torch.cat([item.new([self.token]), item])
26 | return item
27 |
28 | @property
29 | def sizes(self):
30 | return self._sizes
31 |
32 | def num_tokens(self, index):
33 | n = self.dataset.num_tokens(index)
34 | if self.token is not None:
35 | n += 1
36 | return n
37 |
38 | def size(self, index):
39 | n = self.dataset.size(index)
40 | if self.token is not None:
41 | n += 1
42 | return n
43 |
--------------------------------------------------------------------------------
/fairseq/data/raw_label_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 |
8 | from . import FairseqDataset
9 |
10 |
11 | class RawLabelDataset(FairseqDataset):
12 |
13 | def __init__(self, labels):
14 | super().__init__()
15 | self.labels = labels
16 |
17 | def __getitem__(self, index):
18 | return self.labels[index]
19 |
20 | def __len__(self):
21 | return len(self.labels)
22 |
23 | def collater(self, samples):
24 | return torch.tensor(samples)
25 |
--------------------------------------------------------------------------------
/fairseq/data/replace_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from . import BaseWrapperDataset
7 |
8 |
9 | class ReplaceDataset(BaseWrapperDataset):
10 |     """Replaces tokens found in the dataset with specified replacement tokens.
11 |
12 |     Args:
13 |         dataset (~torch.utils.data.Dataset): dataset in which to replace tokens
14 |         replace_map (Dict[int, int]): map from token to replace -> replacement token
15 |         offsets (List[int]): do not replace tokens before this offset (counting from the left
16 |             if positive, from the right if negative); provide one offset per item returned by the dataset's __getitem__
17 | """
18 |
19 | def __init__(self, dataset, replace_map, offsets):
20 | super().__init__(dataset)
21 | assert len(replace_map) > 0
22 | self.replace_map = replace_map
23 | self.offsets = offsets
24 |
25 | def __getitem__(self, index):
26 | item = self.dataset[index]
27 | is_tuple = isinstance(item, tuple)
28 | srcs = item if is_tuple else [item]
29 |
30 | for offset, src in zip(self.offsets, srcs):
31 | for k, v in self.replace_map.items():
32 | src_off = src[offset:] if offset >= 0 else src[:offset]
33 | src_off.masked_fill_(src_off == k, v)
34 |
35 | item = srcs if is_tuple else srcs[0]
36 | return item
37 |
--------------------------------------------------------------------------------
/fairseq/data/sharded_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import itertools
7 | import os
8 | import random
9 |
10 | from . import BaseWrapperDataset
11 | from fairseq.data import data_utils
12 |
13 |
14 | class ShardedDataset(BaseWrapperDataset):
15 |     """A :class:`~fairseq.data.FairseqDataset` wrapper for sharded data.
16 |
17 |     Loads a dataset that has been sharded into multiple files; a single
18 |     shard is selected and loaded for each epoch.
19 |     """
20 |
21 | def __init__(
22 | self,
23 | dictionary,
24 | dataset_impl: str,
25 | path: str,
26 | split: str,
27 | epoch: int,
28 | name: str = None,
29 | combine: bool = False,
30 | seed: int = 0,
31 | ):
32 | self._name = name if name is not None else os.path.basename(path)
33 | num_shards = 0
34 | for i in itertools.count():
35 | if not os.path.exists(os.path.join(path, "shard" + str(i))):
36 | break
37 | num_shards += 1
38 |
39 | if num_shards > 0 and split == "train":
40 | random.seed(seed ^ epoch)
41 | shard = random.randint(0, num_shards - 1)
42 | split_path = os.path.join(path, "shard" + str(shard), split)
43 | else:
44 | split_path = os.path.join(path, split)
45 | if os.path.isdir(split_path):
46 | split_path = os.path.join(split_path, split)
47 |
48 | dataset = data_utils.load_indexed_dataset(
49 | split_path, dictionary, dataset_impl, combine=combine
50 | )
51 | if dataset is None:
52 | raise FileNotFoundError(
53 | "Dataset not found: {} ({})".format(split, split_path)
54 | )
55 |
56 | super().__init__(dataset)
57 |
58 | @property
59 | def name(self):
60 | return self._name
61 |
--------------------------------------------------------------------------------
/fairseq/data/sort_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy as np
7 |
8 | from . import BaseWrapperDataset
9 |
10 |
11 | class SortDataset(BaseWrapperDataset):
12 |
13 | def __init__(self, dataset, sort_order):
14 | super().__init__(dataset)
15 | if not isinstance(sort_order, (list, tuple)):
16 | sort_order = [sort_order]
17 | self.sort_order = sort_order
18 |
19 | assert all(len(so) == len(dataset) for so in sort_order)
20 |
21 | def ordered_indices(self):
22 | return np.lexsort(self.sort_order)
23 |
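Usage sketch (not part of the repository): `ordered_indices` is `np.lexsort` over the given keys, so with a single size array the indices come back sorted by ascending size (with multiple keys, the last key is the primary sort key).

    import numpy as np
    from fairseq.data.list_dataset import ListDataset
    from fairseq.data.sort_dataset import SortDataset

    sizes = np.array([3, 1, 2])
    ds = SortDataset(ListDataset([10, 11, 12]), sort_order=[sizes])
    print(ds.ordered_indices())   # [1 2 0]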
--------------------------------------------------------------------------------
/fairseq/data/strip_token_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from . import BaseWrapperDataset
7 |
8 |
9 | class StripTokenDataset(BaseWrapperDataset):
10 |
11 | def __init__(self, dataset, id_to_strip):
12 | super().__init__(dataset)
13 | self.id_to_strip = id_to_strip
14 |
15 | def __getitem__(self, index):
16 | item = self.dataset[index]
17 | return item[item.ne(self.id_to_strip)]
18 |
--------------------------------------------------------------------------------
/fairseq/data/subsample_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy as np
7 |
8 | from . import BaseWrapperDataset
9 |
10 |
11 | class SubsampleDataset(BaseWrapperDataset):
12 |     """Subsamples a given dataset to a specified ratio of its examples.
13 |
14 | Args:
15 | dataset (~torch.utils.data.Dataset): dataset to subsample
16 |         size_ratio (float): the ratio to subsample to; must be between 0 and 1 (exclusive)
17 | """
18 |
19 | def __init__(self, dataset, size_ratio):
20 | super().__init__(dataset)
21 | assert size_ratio < 1
22 | self.actual_size = np.ceil(len(dataset) * size_ratio).astype(int)
23 | self.indices = np.random.choice(
24 | list(range(len(self.dataset))), self.actual_size, replace=False
25 | )
26 | print(
27 | "subsampled dataset from {} to {} (ratio={})".format(
28 | len(self.dataset), self.actual_size, size_ratio
29 | )
30 | )
31 |
32 | def __getitem__(self, index):
33 | return self.dataset[self.indices[index]]
34 |
35 | def __len__(self):
36 | return self.actual_size
37 |
38 | def collater(self, samples):
39 | return self.dataset.collater(samples)
40 |
41 | @property
42 | def sizes(self):
43 | return self.dataset.sizes[self.indices]
44 |
45 | @property
46 | def name(self):
47 | return self.dataset.name
48 |
49 | def num_tokens(self, index):
50 | return self.dataset.num_tokens(self.indices[index])
51 |
52 | def size(self, index):
53 | return self.dataset.size(self.indices[index])
54 |
55 | def ordered_indices(self):
56 | """Return an ordered list of indices. Batches will be constructed based
57 | on this order."""
58 |         if getattr(self.dataset, 'shuffle', False):  # no shuffle flag of its own; defer to the wrapped dataset
59 | order = [np.random.permutation(len(self))]
60 | else:
61 | order = [np.arange(len(self))]
62 | order.append(self.sizes)
63 | return np.lexsort(order)
64 |
65 | def prefetch(self, indices):
66 | self.dataset.prefetch(self.indices[indices])
67 |
--------------------------------------------------------------------------------
/fairseq/data/transform_eos_lang_pair_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 |
7 | from . import FairseqDataset
8 | from typing import Optional
9 |
10 |
11 | class TransformEosLangPairDataset(FairseqDataset):
12 |     """A :class:`~fairseq.data.FairseqDataset` wrapper that transforms the EOS/BOS
13 |     symbols on collated samples of a language pair dataset.
14 |
15 | Note that the transformation is applied in :func:`collater`.
16 |
17 | Args:
18 | dataset (~fairseq.data.FairseqDataset): dataset that collates sample into
19 | LanguagePairDataset schema
20 | src_eos (int): original source end-of-sentence symbol index to be replaced
21 | new_src_eos (int, optional): new end-of-sentence symbol index to replace source eos symbol
22 | tgt_bos (int, optional): original target beginning-of-sentence symbol index to be replaced
23 | new_tgt_bos (int, optional): new beginning-of-sentence symbol index to replace at the
24 | beginning of 'prev_output_tokens'
25 | """
26 |
27 | def __init__(
28 | self,
29 | dataset: FairseqDataset,
30 | src_eos: int,
31 | new_src_eos: Optional[int] = None,
32 | tgt_bos: Optional[int] = None,
33 | new_tgt_bos: Optional[int] = None,
34 | ):
35 | self.dataset = dataset
36 | self.src_eos = src_eos
37 | self.new_src_eos = new_src_eos
38 | self.tgt_bos = tgt_bos
39 | self.new_tgt_bos = new_tgt_bos
40 |
41 | def __getitem__(self, index):
42 | return self.dataset[index]
43 |
44 | def __len__(self):
45 | return len(self.dataset)
46 |
47 | def collater(self, samples):
48 | samples = self.dataset.collater(samples)
49 |
50 | # TODO: support different padding direction
51 | if self.new_src_eos is not None:
52 |             assert (samples['net_input']['src_tokens'][:, -1] != self.src_eos).sum() == 0
53 | samples['net_input']['src_tokens'][:, -1] = self.new_src_eos
54 |
55 | if self.new_tgt_bos is not None:
56 | assert (samples['net_input']['prev_output_tokens'][:, 0] != self.tgt_bos).sum() == 0
57 | samples['net_input']['prev_output_tokens'][:, 0] = self.new_tgt_bos
58 |
59 | return samples
60 |
61 | def num_tokens(self, index):
62 | return self.dataset.num_tokens(index)
63 |
64 | def size(self, index):
65 | return self.dataset.size(index)
66 |
67 | def ordered_indices(self):
68 | return self.dataset.ordered_indices()
69 |
70 | @property
71 | def supports_prefetch(self):
72 | return getattr(self.dataset, 'supports_prefetch', False)
73 |
74 | def prefetch(self, indices):
75 | return self.dataset.prefetch(indices)
76 |
--------------------------------------------------------------------------------
/fairseq/data/truncate_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy as np
7 |
8 | from . import BaseWrapperDataset
9 |
10 |
11 | class TruncateDataset(BaseWrapperDataset):
12 |
13 | def __init__(self, dataset, truncation_length):
14 | super().__init__(dataset)
15 | assert truncation_length is not None
16 | self.truncation_length = truncation_length
17 | self.dataset = dataset
18 |
19 | def __getitem__(self, index):
20 | item = self.dataset[index]
21 | item_len = item.size(0)
22 | if item_len > self.truncation_length:
23 | item = item[:self.truncation_length]
24 | return item
25 |
26 | @property
27 | def sizes(self):
28 | return np.minimum(self.dataset.sizes, self.truncation_length)
29 |
30 | def __len__(self):
31 | return len(self.dataset)
32 |
--------------------------------------------------------------------------------
/fairseq/meters.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import time
7 |
8 |
9 | class AverageMeter(object):
10 | """Computes and stores the average and current value"""
11 | def __init__(self):
12 | self.reset()
13 |
14 | def reset(self):
15 | self.val = 0
16 | self.avg = 0
17 | self.sum = 0
18 | self.count = 0
19 |
20 | def update(self, val, n=1):
21 | self.val = val
22 | self.sum += val * n
23 | self.count += n
24 | self.avg = self.sum / self.count
25 |
26 |
27 | class TimeMeter(object):
28 | """Computes the average occurrence of some event per second"""
29 | def __init__(self, init=0):
30 | self.reset(init)
31 |
32 | def reset(self, init=0):
33 | self.init = init
34 | self.start = time.time()
35 | self.n = 0
36 |
37 | def update(self, val=1):
38 | self.n += val
39 |
40 | @property
41 | def avg(self):
42 | return self.n / self.elapsed_time
43 |
44 | @property
45 | def elapsed_time(self):
46 | return self.init + (time.time() - self.start)
47 |
48 |
49 | class StopwatchMeter(object):
50 | """Computes the sum/avg duration of some event in seconds"""
51 | def __init__(self):
52 | self.reset()
53 |
54 | def start(self):
55 | self.start_time = time.time()
56 |
57 | def stop(self, n=1):
58 | if self.start_time is not None:
59 | delta = time.time() - self.start_time
60 | self.sum += delta
61 | self.n += n
62 | self.start_time = None
63 |
64 | def reset(self):
65 | self.sum = 0
66 | self.n = 0
67 | self.start_time = None
68 |
69 | @property
70 | def avg(self):
71 | return self.sum / self.n
72 |
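Usage sketch (not part of the repository): AverageMeter tracks a running average, while StopwatchMeter accumulates elapsed time across start/stop pairs.

    import time
    from fairseq.meters import AverageMeter, StopwatchMeter

    loss_meter = AverageMeter()
    timer = StopwatchMeter()
    for loss in [4.0, 3.0, 2.0]:
        timer.start()
        time.sleep(0.01)       # stand-in for a training step
        timer.stop()
        loss_meter.update(loss)
    print(loss_meter.avg)      # 3.0
    print(timer.avg)           # average seconds per recorded step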
--------------------------------------------------------------------------------
/fairseq/models/composite_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from fairseq.models import FairseqEncoder
7 |
8 |
9 | class CompositeEncoder(FairseqEncoder):
10 | """
11 | A wrapper around a dictionary of :class:`FairseqEncoder` objects.
12 |
13 | We run forward on each encoder and return a dictionary of outputs. The first
14 | encoder's dictionary is used for initialization.
15 |
16 | Args:
17 | encoders (dict): a dictionary of :class:`FairseqEncoder` objects.
18 | """
19 |
20 | def __init__(self, encoders):
21 | super().__init__(next(iter(encoders.values())).dictionary)
22 | self.encoders = encoders
23 | for key in self.encoders:
24 | self.add_module(key, self.encoders[key])
25 |
26 | def forward(self, src_tokens, src_lengths):
27 | """
28 | Args:
29 | src_tokens (LongTensor): tokens in the source language of shape
30 | `(batch, src_len)`
31 | src_lengths (LongTensor): lengths of each source sentence of shape
32 | `(batch)`
33 |
34 | Returns:
35 | dict:
36 | the outputs from each Encoder
37 | """
38 | encoder_out = {}
39 | for key in self.encoders:
40 | encoder_out[key] = self.encoders[key](src_tokens, src_lengths)
41 | return encoder_out
42 |
43 | def reorder_encoder_out(self, encoder_out, new_order):
44 | """Reorder encoder output according to new_order."""
45 | for key in self.encoders:
46 | encoder_out[key] = self.encoders[key].reorder_encoder_out(encoder_out[key], new_order)
47 | return encoder_out
48 |
49 | def max_positions(self):
50 | return min([self.encoders[key].max_positions() for key in self.encoders])
51 |
52 | def upgrade_state_dict(self, state_dict):
53 | for key in self.encoders:
54 | self.encoders[key].upgrade_state_dict(state_dict)
55 | return state_dict
56 |
--------------------------------------------------------------------------------
/fairseq/models/distributed_fairseq_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import inspect
7 |
8 | import torch.nn as nn
9 |
10 | from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel
11 | from fairseq.models import BaseFairseqModel
12 |
13 |
14 | def DistributedFairseqModel(args, model):
15 | """
16 | Wrap a *model* to support distributed data parallel training.
17 |
18 | This is similar to the built-in DistributedDataParallel, but allows
19 | additional configuration of the DistributedDataParallel class to
20 | use, and also provides easier access to the wrapped model by
21 | forwarding requests for missing attributes to the wrapped model.
22 |
23 | Args:
24 | args (argparse.Namespace): fairseq args
25 | model (BaseFairseqModel): model to wrap
26 | """
27 | # determine which DDP class to extend
28 | assert isinstance(model, nn.Module)
29 | if args.ddp_backend == 'c10d':
30 | ddp_class = nn.parallel.DistributedDataParallel
31 | init_kwargs = dict(
32 | module=model,
33 | device_ids=[args.device_id],
34 | output_device=args.device_id,
35 | broadcast_buffers=False,
36 | bucket_cap_mb=args.bucket_cap_mb,
37 | )
38 | # Maintain backward compatibility
39 | if 'check_reduction' in inspect.getargspec(ddp_class)[0]:
40 | init_kwargs['check_reduction'] = True
41 | if 'find_unused_parameters' in inspect.getargspec(ddp_class)[0]:
42 | init_kwargs['find_unused_parameters'] = args.find_unused_parameters
43 | elif args.ddp_backend == 'no_c10d':
44 | ddp_class = LegacyDistributedDataParallel
45 | init_kwargs = dict(
46 | module=model,
47 | world_size=args.distributed_world_size,
48 | buffer_size=2**28,
49 | )
50 | else:
51 | raise ValueError('Unknown --ddp-backend: ' + args.ddp_backend)
52 |
53 | class _DistributedFairseqModel(ddp_class):
54 | """Extend DistributedDataParallel to check for missing
55 | attributes in the wrapped module."""
56 |
57 | def __init__(self, *args, **kwargs):
58 | super().__init__(*args, **kwargs)
59 |
60 | def __getattr__(self, name):
61 | wrapped_module = super().__getattr__('module')
62 | if hasattr(wrapped_module, name):
63 | return getattr(wrapped_module, name)
64 | return super().__getattr__(name)
65 |
66 | return _DistributedFairseqModel(**init_kwargs)
67 |
--------------------------------------------------------------------------------
/fairseq/models/fairseq_decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch.nn as nn
7 |
8 | from fairseq import utils
9 |
10 |
11 | class FairseqDecoder(nn.Module):
12 | """Base class for decoders."""
13 |
14 | def __init__(self, dictionary):
15 | super().__init__()
16 | self.dictionary = dictionary
17 | self.onnx_trace = False
18 |
19 | def forward(self, prev_output_tokens, encoder_out=None, **kwargs):
20 | """
21 | Args:
22 | prev_output_tokens (LongTensor): shifted output tokens of shape
23 | `(batch, tgt_len)`, for teacher forcing
24 | encoder_out (dict, optional): output from the encoder, used for
25 | encoder-side attention
26 |
27 | Returns:
28 | tuple:
29 | - the decoder's output of shape `(batch, tgt_len, vocab)`
30 | - a dictionary with any model-specific outputs
31 | """
32 | x, extra = self.extract_features(prev_output_tokens, encoder_out=encoder_out, **kwargs)
33 | x = self.output_layer(x)
34 | return x, extra
35 |
36 | def extract_features(self, prev_output_tokens, encoder_out=None, **kwargs):
37 | """
38 | Returns:
39 | tuple:
40 | - the decoder's features of shape `(batch, tgt_len, embed_dim)`
41 | - a dictionary with any model-specific outputs
42 | """
43 | raise NotImplementedError
44 |
45 | def output_layer(self, features, **kwargs):
46 | """
47 | Project features to the default output size, e.g., vocabulary size.
48 |
49 | Args:
50 | features (Tensor): features returned by *extract_features*.
51 | """
52 | raise NotImplementedError
53 |
54 | def get_normalized_probs(self, net_output, log_probs, sample):
55 | """Get normalized probabilities (or log probs) from a net's output."""
56 |
57 | if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None:
58 | if sample is not None:
59 | assert 'target' in sample
60 | target = sample['target']
61 | else:
62 | target = None
63 | out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
64 | return out.exp_() if not log_probs else out
65 |
66 | logits = net_output[0]
67 |         # If you have bugs in computing importance values,
68 | # comment the previous line and uncomment the next line.
69 | #logits = net_output['lang'][0] # 'lang' is one of your training language pairs.
70 |
71 | if log_probs:
72 | return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
73 | else:
74 | return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
75 |
76 | def max_positions(self):
77 | """Maximum input length supported by the decoder."""
78 | return 1e6 # an arbitrary large number
79 |
80 | def upgrade_state_dict(self, state_dict):
81 | """Upgrade a (possibly old) state dict for new versions of fairseq."""
82 | return state_dict
83 |
84 | def prepare_for_onnx_export_(self):
85 | self.onnx_trace = True
86 |
--------------------------------------------------------------------------------
/fairseq/models/fairseq_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch.nn as nn
7 |
8 |
9 | class FairseqEncoder(nn.Module):
10 | """Base class for encoders."""
11 |
12 | def __init__(self, dictionary):
13 | super().__init__()
14 | self.dictionary = dictionary
15 |
16 | def forward(self, src_tokens, src_lengths=None, **kwargs):
17 | """
18 | Args:
19 | src_tokens (LongTensor): tokens in the source language of shape
20 | `(batch, src_len)`
21 | src_lengths (LongTensor): lengths of each source sentence of shape
22 | `(batch)`
23 | """
24 | raise NotImplementedError
25 |
26 | def reorder_encoder_out(self, encoder_out, new_order):
27 | """
28 | Reorder encoder output according to `new_order`.
29 |
30 | Args:
31 | encoder_out: output from the ``forward()`` method
32 | new_order (LongTensor): desired order
33 |
34 | Returns:
35 | `encoder_out` rearranged according to `new_order`
36 | """
37 | raise NotImplementedError
38 |
39 | def max_positions(self):
40 | """Maximum input length supported by the encoder."""
41 | return 1e6 # an arbitrary large number
42 |
43 | def upgrade_state_dict(self, state_dict):
44 | """Upgrade a (possibly old) state dict for new versions of fairseq."""
45 | return state_dict
46 |
--------------------------------------------------------------------------------
/fairseq/models/model_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 |
8 |
9 | def skip_tensors(x, mask):
10 | """
11 |     Get a tensor sliced along dim=0 by a mask. Supports tensors and lists/dicts of tensors.
12 | """
13 | if isinstance(x, int):
14 | return x
15 |
16 | if x is None:
17 | return None
18 |
19 | if isinstance(x, torch.Tensor):
20 | if x.size(0) == mask.size(0):
21 | return x[mask]
22 | elif x.size(1) == mask.size(0):
23 | return x[:, mask]
24 |
25 | if isinstance(x, list):
26 | return [skip_tensors(x_i, mask) for x_i in x]
27 |
28 | if isinstance(x, dict):
29 | return {k: skip_tensors(v, mask) for k, v in x.items()}
30 |
31 | raise NotImplementedError
32 |
33 |
34 | def expand_2d_or_3d_tensor(x, trg_dim, padding_idx):
35 | """
36 | Expand 2D/3D tensor on dim=1
37 | """
38 | if x is None:
39 | return None
40 |
41 | assert x.dim() == 2 or x.dim() == 3
42 | assert trg_dim >= x.size(1), (trg_dim, x.size())
43 | if trg_dim == x.size(1):
44 | return x
45 |
46 | dims = [x.size(0), trg_dim - x.size(1)]
47 | if x.dim() == 3:
48 | dims.append(x.size(2))
49 | x = torch.cat([x, x.new_zeros(*dims).fill_(padding_idx)], 1)
50 |
51 | return x
52 |
53 |
54 | def fill_tensors(x, mask, y, padding_idx):
55 | """
56 | Filling tensor x with y at masked positions (dim=0).
57 | """
58 | if x is None:
59 | return None
60 |
61 | assert x.dim() == y.dim() and mask.size(0) == x.size(0)
62 | assert x.dim() == 2 or (x.dim() == 3 and x.size(2) == y.size(2))
63 |
64 | n_selected = mask.sum()
65 | if n_selected == 0:
66 | return x
67 | assert n_selected == y.size(0)
68 | if n_selected == x.size(0):
69 | return y
70 |
71 | if x.size(1) < y.size(1):
72 | x = expand_2d_or_3d_tensor(x, y.size(1), padding_idx)
73 | x[mask] = y
74 | elif x.size(1) > y.size(1):
75 | x[mask] = padding_idx
76 | if x.dim() == 2:
77 | x[mask, :y.size(1)] = y
78 | else:
79 | x[mask, :y.size(1), :] = y
80 | else:
81 | x[mask] = y
82 | return x
83 |
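A small sketch (not part of the repository) of `expand_2d_or_3d_tensor`, which pads dim=1 out to the target length with the padding index.

    import torch
    from fairseq.models.model_utils import expand_2d_or_3d_tensor

    x = torch.LongTensor([[1, 2], [3, 4]])
    y = expand_2d_or_3d_tensor(x, trg_dim=4, padding_idx=0)
    # y == tensor([[1, 2, 0, 0],
    #              [3, 4, 0, 0]])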
--------------------------------------------------------------------------------
/fairseq/models/roberta/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .hub_interface import * # noqa
7 | from .model import * # noqa
8 |
--------------------------------------------------------------------------------
/fairseq/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .adaptive_input import AdaptiveInput
7 | from .adaptive_softmax import AdaptiveSoftmax
8 | from .beamable_mm import BeamableMM
9 | from .character_token_embedder import CharacterTokenEmbedder
10 | from .conv_tbc import ConvTBC
11 | from .downsampled_multihead_attention import DownsampledMultiHeadAttention
12 | from .dynamic_convolution import DynamicConv, DynamicConv1dTBC
13 | from .gelu import gelu, gelu_accurate
14 | from .grad_multiply import GradMultiply
15 | from .highway import Highway
16 | from .layer_norm import LayerNorm
17 | from .learned_positional_embedding import LearnedPositionalEmbedding
18 | from .lightweight_convolution import LightweightConv, LightweightConv1dTBC
19 | from .linearized_convolution import LinearizedConvolution
20 | from .logsumexp_moe import LogSumExpMoE
21 | from .mean_pool_gating_network import MeanPoolGatingNetwork
22 | from .multihead_attention import MultiheadAttention
23 | from .positional_embedding import PositionalEmbedding
24 | from .scalar_bias import ScalarBias
25 | from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
26 | from .transformer_sentence_encoder_layer import TransformerSentenceEncoderLayer
27 | from .transformer_sentence_encoder import TransformerSentenceEncoder
28 | from .unfold import unfold1d
29 | from .transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer
30 | from .vggblock import VGGBlock
31 |
32 | __all__ = [
33 | 'AdaptiveInput',
34 | 'AdaptiveSoftmax',
35 | 'BeamableMM',
36 | 'CharacterTokenEmbedder',
37 | 'ConvTBC',
38 | 'DownsampledMultiHeadAttention',
39 | 'DynamicConv1dTBC',
40 | 'DynamicConv',
41 | 'gelu',
42 | 'gelu_accurate',
43 | 'GradMultiply',
44 | 'Highway',
45 | 'LayerNorm',
46 | 'LearnedPositionalEmbedding',
47 | 'LightweightConv1dTBC',
48 | 'LightweightConv',
49 | 'LinearizedConvolution',
50 | 'LogSumExpMoE',
51 | 'MeanPoolGatingNetwork',
52 | 'MultiheadAttention',
53 | 'PositionalEmbedding',
54 | 'ScalarBias',
55 | 'SinusoidalPositionalEmbedding',
56 | 'TransformerSentenceEncoderLayer',
57 | 'TransformerSentenceEncoder',
58 | 'TransformerDecoderLayer',
59 | 'TransformerEncoderLayer',
60 | 'VGGBlock',
61 | 'unfold1d',
62 | ]
63 |
--------------------------------------------------------------------------------
/fairseq/modules/adaptive_input.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 |
7 | import torch
8 | from torch import nn
9 |
10 | from typing import List
11 |
12 |
13 | class AdaptiveInput(nn.Module):
14 |
15 | def __init__(
16 | self,
17 | vocab_size: int,
18 | padding_idx: int,
19 | initial_dim: int,
20 | factor: float,
21 | output_dim: int,
22 | cutoff: List[int],
23 | ):
24 | super().__init__()
25 |
26 | if vocab_size > cutoff[-1]:
27 | cutoff = cutoff + [vocab_size]
28 | else:
29 | assert vocab_size == cutoff[
30 | -1], 'cannot specify cutoff larger than vocab size'
31 |
32 | self.cutoff = cutoff
33 | self.embedding_dim = output_dim
34 | self.padding_idx = padding_idx
35 |
36 | self.embeddings = nn.ModuleList()
37 | for i in range(len(self.cutoff)):
38 | prev = self.cutoff[i - 1] if i > 0 else 0
39 | size = self.cutoff[i] - prev
40 | dim = int(initial_dim // (factor ** i))
41 | seq = nn.Sequential(
42 | nn.Embedding(size, dim, padding_idx),
43 | nn.Linear(dim, output_dim, bias=False)
44 | )
45 | self.embeddings.append(seq)
46 |
47 | def init_weights(m):
48 | if isinstance(m, nn.Embedding):
49 | nn.init.normal_(m.weight, mean=0, std=m.weight.shape[1] ** -0.5)
50 | nn.init.constant_(m.weight[padding_idx], 0)
51 | elif hasattr(m, 'weight'):
52 | nn.init.xavier_uniform_(m.weight)
53 |
54 | self.apply(init_weights)
55 |
56 | self.register_buffer('_float_tensor', torch.FloatTensor(1))
57 |
58 | def weights_for_band(self, band: int):
59 | return self.embeddings[band][0].weight, self.embeddings[band][1].weight
60 |
61 | def forward(self, input: torch.Tensor):
62 | result = self._float_tensor.new(input.shape + (self.embedding_dim,))
63 | for i in range(len(self.cutoff)):
64 | mask = input.lt(self.cutoff[i])
65 | if i > 0:
66 | mask.mul_(input.ge(self.cutoff[i - 1]))
67 | chunk_input = input[mask] - self.cutoff[i - 1]
68 | else:
69 | chunk_input = input[mask]
70 | if mask.any():
71 | result[mask] = self.embeddings[i](chunk_input)
72 | return result
73 |
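A minimal usage sketch of AdaptiveInput (assuming the fairseq package in this tree is importable; the vocabulary size, cutoffs, and dimensions below are hypothetical). Token ids below each cutoff get a progressively smaller embedding dimension (initial_dim // factor**i) and are projected back up to output_dim:

import torch
from fairseq.modules import AdaptiveInput

# Hypothetical 10k-token vocabulary split into three frequency bands:
# ids [0, 2000) use 512-dim embeddings, [2000, 6000) use 128-dim,
# and [6000, 10000) use 32-dim (512 // 4**i), all projected to 512.
emb = AdaptiveInput(
    vocab_size=10000,
    padding_idx=1,
    initial_dim=512,
    factor=4.0,
    output_dim=512,
    cutoff=[2000, 6000],
)

tokens = torch.randint(0, 10000, (8, 16))  # bsz x seqlen of token ids
out = emb(tokens)
print(out.shape)  # torch.Size([8, 16, 512])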
--------------------------------------------------------------------------------
/fairseq/modules/beamable_mm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 |
10 | class BeamableMM(nn.Module):
11 | """This module provides an optimized MM for beam decoding with attention.
12 |
13 | It leverages the fact that the source-side of the input is replicated beam
14 | times and the target-side of the input is of width one. This layer speeds up
15 | inference by replacing the inputs {(bsz x 1 x nhu), (bsz x sz2 x nhu)}
16 | with smaller inputs {(bsz/beam x beam x nhu), (bsz/beam x sz2 x nhu)}.
17 | """
18 | def __init__(self, beam_size=None):
19 | super(BeamableMM, self).__init__()
20 | self.beam_size = beam_size
21 |
22 | def forward(self, input1, input2):
23 | if (
24 | not self.training and # test mode
25 | self.beam_size is not None and # beam size is set
26 | input1.dim() == 3 and # only support batched input
27 | input1.size(1) == 1 # single time step update
28 | ):
29 | bsz, beam = input1.size(0), self.beam_size
30 |
31 | # bsz x 1 x nhu --> bsz/beam x beam x nhu
32 | input1 = input1[:, 0, :].unfold(0, beam, beam).transpose(2, 1)
33 |
34 | # bsz x sz2 x nhu --> bsz/beam x sz2 x nhu
35 | input2 = input2.unfold(0, beam, beam)[:, :, :, 0]
36 |
37 | # use non batched operation if bsz = beam
38 | if input1.size(0) == 1:
39 | output = torch.mm(input1[0, :, :], input2[0, :, :])
40 | else:
41 | output = input1.bmm(input2)
42 | return output.view(bsz, 1, -1)
43 | else:
44 | return input1.bmm(input2)
45 |
46 | def set_beam_size(self, beam_size):
47 | self.beam_size = beam_size
48 |
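A small self-contained sketch of the beam trick (shapes and values are hypothetical), assuming the usual convention that dimension 0 is laid out as contiguous groups of beam hypotheses per sentence. In eval mode with a beam size set, the replicated source-side tensor is collapsed to one copy per sentence before the matmul, and the result matches the plain bmm:

import torch
from fairseq.modules import BeamableMM

beam, nsent, nhu, src_len = 5, 2, 8, 7
mm = BeamableMM(beam_size=beam).eval()  # the fast path is only taken in eval mode

# one decoding step per hypothesis; source keys replicated beam times per sentence
queries = torch.randn(nsent * beam, 1, nhu)
keys = torch.randn(nsent, nhu, src_len).repeat_interleave(beam, dim=0)

out = mm(queries, keys)                        # (nsent*beam) x 1 x src_len
assert torch.allclose(out, queries.bmm(keys), atol=1e-5)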
--------------------------------------------------------------------------------
/fairseq/modules/conv_tbc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | from torch.nn.modules.utils import _single
8 |
9 |
10 | class ConvTBC(torch.nn.Module):
11 | """1D convolution over an input of shape (time x batch x channel)
12 |
13 | The implementation uses gemm to perform the convolution. This implementation
14 | is faster than cuDNN for small kernel sizes.
15 | """
16 | def __init__(self, in_channels, out_channels, kernel_size, padding=0):
17 | super(ConvTBC, self).__init__()
18 | self.in_channels = in_channels
19 | self.out_channels = out_channels
20 | self.kernel_size = _single(kernel_size)
21 | self.padding = _single(padding)
22 |
23 | self.weight = torch.nn.Parameter(torch.Tensor(
24 | self.kernel_size[0], in_channels, out_channels))
25 | self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
26 |
27 | def forward(self, input):
28 | return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding[0])
29 |
30 | def __repr__(self):
31 | s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
32 | ', padding={padding}')
33 | if self.bias is None:
34 | s += ', bias=False'
35 | s += ')'
36 | return s.format(name=self.__class__.__name__, **self.__dict__)
37 |
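A usage sketch with hypothetical sizes (assuming fairseq is importable). ConvTBC leaves its weight and bias uninitialized, so callers such as fconv initialize them; the explicit init here is only for the example:

import torch
from torch import nn
from fairseq.modules import ConvTBC

conv = ConvTBC(in_channels=16, out_channels=32, kernel_size=3, padding=1)
nn.init.normal_(conv.weight, std=0.1)  # parameters are not initialized by ConvTBC itself
nn.init.zeros_(conv.bias)

x = torch.randn(20, 4, 16)   # time x batch x channel, not batch x channel x time
y = conv(x)
print(y.shape)               # torch.Size([20, 4, 32]) with kernel_size=3, padding=1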
--------------------------------------------------------------------------------
/fairseq/modules/dynamicconv_layer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .dynamicconv_layer import DynamicconvLayer # noqa
7 |
--------------------------------------------------------------------------------
/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cpp:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Facebook, Inc. and its affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | #include <torch/extension.h>
9 | #include <vector>
10 |
11 | std::vector<at::Tensor> dynamicconv_cuda_forward(
12 | at::Tensor input,
13 | at::Tensor filters,
14 | int padding_l);
15 |
16 | std::vector<at::Tensor> dynamicconv_cuda_backward(
17 | at::Tensor gradOutput,
18 | int padding_l,
19 | at::Tensor input,
20 | at::Tensor filters);
21 |
22 |
23 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
24 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
25 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
26 |
27 | std::vector<at::Tensor> dynamicconv_forward(
28 | at::Tensor input,
29 | at::Tensor filters,
30 | int padding_l) {
31 |
32 | CHECK_INPUT(input);
33 | CHECK_INPUT(filters);
34 |
35 | return dynamicconv_cuda_forward(input, filters,
36 | padding_l);
37 | }
38 |
39 | std::vector<at::Tensor> dynamicconv_backward(
40 | at::Tensor gradOutput,
41 | int padding_l,
42 | at::Tensor input,
43 | at::Tensor filters) {
44 |
45 | CHECK_INPUT(gradOutput);
46 | CHECK_INPUT(input);
47 | CHECK_INPUT(filters);
48 |
49 | return dynamicconv_cuda_backward(gradOutput, padding_l,
50 | input, filters);
51 | }
52 |
53 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
54 | m.def("forward", &dynamicconv_forward, "dynamicconv forward (CUDA)");
55 | m.def("backward", &dynamicconv_backward, "dynamicconv backward (CUDA)");
56 | }
57 |
--------------------------------------------------------------------------------
/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cuh:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Facebook, Inc. and its affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | #include
9 | #include
10 |
11 | #include
12 | #include
13 | #include
14 |
15 | #include
16 | #include
17 | #include
18 | #include
19 | #include
20 | #include
21 |
22 | #include
23 | #include
24 | #include
25 |
26 | #define SHFL_MASK 0xffffffff
27 |
28 | template
29 | __global__
30 | void dynamicconv_forward_kernel(const scalar_t* input,
31 | const scalar_t* weight,
32 | int minibatch,
33 | int sequenceLength,
34 | int numFeatures,
35 | int numFiltersInBlock,
36 | int numHeads,
37 | scalar_t* output);
38 |
39 | template
40 | __global__
41 | void dynamicconv_backward_kernel(
42 | const scalar_t* gradOutput, // B * C * T
43 | const scalar_t* input, // B * C * T
44 | const scalar_t* weight,
45 | int minibatch,
46 | int sequenceLength,
47 | int numFeatures,
48 | int numFiltersInBlock,
49 | int numHeads,
50 | scalar_t* gradWeight,
51 | scalar_t* gradInput); // B * H * k * T
52 |
--------------------------------------------------------------------------------
/fairseq/modules/dynamicconv_layer/dynamiconv_cpu.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <vector>
3 |
4 | std::vector<float*> dynamicconv_cpu_forward(
5 | float* input,
6 | float* filters,
7 | int padding_l);
8 |
9 | std::vector<float*> dynamicconv_cpu_backward(
10 | float* gradOutput,
11 | int padding_l,
12 | float* input,
13 | float* filters);
14 |
15 | std::vector<float*> dynamicconv_forward(
16 | float* input,
17 | float* filters,
18 | int padding_l) {
19 |
20 | return dynamicconv_cpu_forward(input, filters, padding_l);
21 | }
22 |
23 | std::vector<float*> dynamicconv_backward(
24 | float* gradOutput,
25 | int padding_l,
26 | float* input,
27 | float* filters) {
28 |
29 | return dynamicconv_cpu_backward(gradOutput, padding_l, input, filters);
30 | }
31 |
32 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
33 | m.def("forward", &dynamicconv_forward, "dynamicconv forward (CPU)");
34 | m.def("backward", &dynamicconv_backward, "dynamicconv backward (CPU)");
35 | }
36 |
--------------------------------------------------------------------------------
/fairseq/modules/dynamicconv_layer/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from setuptools import setup
8 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension
9 |
10 | setup(
11 | name='dynamicconv_layer',
12 | ext_modules=[
13 | CUDAExtension(
14 | name='dynamicconv_cuda',
15 | sources=[
16 | 'dynamicconv_cuda.cpp',
17 | 'dynamicconv_cuda_kernel.cu',
18 | ],
19 | ),
20 | ],
21 | cmdclass={
22 | 'build_ext': BuildExtension
23 | })
24 |
--------------------------------------------------------------------------------
/fairseq/modules/gelu.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | """
6 | See "Gaussian Error Linear Units (GELUs)" by Dan Hendrycks and Kevin Gimpel with
7 | the corresponding GitHub repo: https://github.com/hendrycks/GELUs
8 | """
9 |
10 | import math
11 |
12 | import torch
13 |
14 |
15 | def gelu_accurate(x):
16 | if not hasattr(gelu_accurate, "_a"):
17 | gelu_accurate._a = math.sqrt(2 / math.pi)
18 | return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
19 |
20 |
21 | def gelu(x: torch.Tensor) -> torch.Tensor:
22 | if hasattr(torch.nn.functional, 'gelu'):
23 | return torch.nn.functional.gelu(x.float()).type_as(x)
24 | else:
25 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
26 |
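A quick numerical comparison of the two variants (a sketch with arbitrary inputs): gelu uses the exact erf-based form (via F.gelu when available), while gelu_accurate is the tanh approximation, and the two agree closely:

import torch
from fairseq.modules import gelu, gelu_accurate

x = torch.linspace(-3, 3, steps=13)
print((gelu(x) - gelu_accurate(x)).abs().max())  # small, on the order of 1e-3 or below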
--------------------------------------------------------------------------------
/fairseq/modules/grad_multiply.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 |
8 |
9 | class GradMultiply(torch.autograd.Function):
10 | @staticmethod
11 | def forward(ctx, x, scale):
12 | ctx.scale = scale
13 | res = x.new(x)
14 | return res
15 |
16 | @staticmethod
17 | def backward(ctx, grad):
18 | return grad * ctx.scale, None
19 |
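A minimal sketch of what GradMultiply does (values are arbitrary): the forward pass is an identity copy, but gradients flowing back through it are scaled, which is how fconv dampens the encoder gradients coming from attention:

import torch
from fairseq.modules import GradMultiply

x = torch.randn(3, requires_grad=True)
y = GradMultiply.apply(x, 0.1)   # forward returns a copy of x
y.sum().backward()
print(x.grad)                    # every entry is 0.1: gradients were scaled, values were not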
--------------------------------------------------------------------------------
/fairseq/modules/highway.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 |
8 | from torch import nn
9 |
10 |
11 | class Highway(torch.nn.Module):
12 | """
13 | A `Highway layer <https://arxiv.org/abs/1505.00387>`_.
14 | Adopted from the AllenNLP implementation.
15 | """
16 |
17 | def __init__(
18 | self,
19 | input_dim: int,
20 | num_layers: int = 1
21 | ):
22 | super(Highway, self).__init__()
23 | self.input_dim = input_dim
24 | self.layers = nn.ModuleList([nn.Linear(input_dim, input_dim * 2)
25 | for _ in range(num_layers)])
26 | self.activation = nn.ReLU()
27 |
28 | self.reset_parameters()
29 |
30 | def reset_parameters(self):
31 | for layer in self.layers:
32 | # As per comment in AllenNLP:
33 | # We should bias the highway layer to just carry its input forward. We do that by
34 | # setting the bias on `B(x)` to be positive, because that means `g` will be biased to
35 | # be high, so we will carry the input forward. The bias on `B(x)` is the second half
36 | # of the bias vector in each Linear layer.
37 | nn.init.constant_(layer.bias[self.input_dim:], 1)
38 |
39 | nn.init.constant_(layer.bias[:self.input_dim], 0)
40 | nn.init.xavier_normal_(layer.weight)
41 |
42 | def forward(
43 | self,
44 | x: torch.Tensor
45 | ):
46 | for layer in self.layers:
47 | projection = layer(x)
48 | proj_x, gate = projection.chunk(2, dim=-1)
49 | proj_x = self.activation(proj_x)
50 | gate = torch.sigmoid(gate)
51 | x = gate * x + (gate.new_tensor([1]) - gate) * proj_x
52 | return x
53 |
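A small sketch with hypothetical sizes: each layer computes a ReLU projection and a sigmoid gate from the same linear map and mixes them with the input; because the gate bias is initialized to 1, a freshly constructed Highway is biased toward carrying its input through unchanged:

import torch
from fairseq.modules import Highway

hw = Highway(input_dim=64, num_layers=2)
x = torch.randn(10, 64)
y = hw(x)                     # same shape as x
print((y - x).abs().mean())   # the gate bias of 1 biases the layer toward passing x through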
--------------------------------------------------------------------------------
/fairseq/modules/layer_norm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 |
8 |
9 | def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
10 | if not export and torch.cuda.is_available():
11 | try:
12 | from apex.normalization import FusedLayerNorm
13 | return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
14 | except ImportError:
15 | pass
16 | return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
17 |
--------------------------------------------------------------------------------
/fairseq/modules/learned_positional_embedding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch.nn as nn
7 |
8 | from fairseq import utils
9 |
10 |
11 | class LearnedPositionalEmbedding(nn.Embedding):
12 | """
13 | This module learns positional embeddings up to a fixed maximum size.
14 | Padding ids are ignored by either offsetting based on padding_idx
15 | or by setting padding_idx to None and ensuring that the appropriate
16 | position ids are passed to the forward function.
17 | """
18 |
19 | def __init__(
20 | self,
21 | num_embeddings: int,
22 | embedding_dim: int,
23 | padding_idx: int,
24 | ):
25 | super().__init__(num_embeddings, embedding_dim, padding_idx)
26 | self.onnx_trace = False
27 |
28 | def forward(self, input, incremental_state=None, positions=None):
29 | """Input is expected to be of size [bsz x seqlen]."""
30 | assert (
31 | (positions is None) or (self.padding_idx is None)
32 | ), "If positions is pre-computed then padding_idx should not be set."
33 |
34 | if positions is None:
35 | if incremental_state is not None:
36 | # positions is the same for every token when decoding a single step
37 | # Without the int() cast, it doesn't work in some cases when exporting to ONNX
38 | positions = input.data.new(1, 1).fill_(int(self.padding_idx + input.size(1)))
39 | else:
40 | positions = utils.make_positions(
41 | input, self.padding_idx, onnx_trace=self.onnx_trace,
42 | )
43 | return super().forward(positions)
44 |
45 | def max_positions(self):
46 | """Maximum number of supported positions."""
47 | if self.padding_idx is not None:
48 | return self.num_embeddings - self.padding_idx - 1
49 | else:
50 | return self.num_embeddings
51 |
--------------------------------------------------------------------------------
/fairseq/modules/lightconv_layer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .lightconv_layer import LightconvLayer # noqa
7 |
--------------------------------------------------------------------------------
/fairseq/modules/lightconv_layer/lightconv_cuda.cpp:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Facebook, Inc. and its affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | #include <torch/extension.h>
9 | #include <vector>
10 |
11 | std::vector<at::Tensor> lightconv_cuda_forward(
12 | at::Tensor input,
13 | at::Tensor filters,
14 | int padding_l);
15 |
16 | std::vector<at::Tensor> lightconv_cuda_backward(
17 | at::Tensor gradOutput,
18 | int padding_l,
19 | at::Tensor input,
20 | at::Tensor filters);
21 |
22 |
23 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
24 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
25 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
26 |
27 | std::vector<at::Tensor> lightconv_forward(
28 | at::Tensor input,
29 | at::Tensor filters,
30 | int padding_l) {
31 |
32 | CHECK_INPUT(input);
33 | CHECK_INPUT(filters);
34 |
35 | return lightconv_cuda_forward(input, filters, padding_l);
36 | }
37 |
38 | std::vector<at::Tensor> lightconv_backward(
39 | at::Tensor gradOutput,
40 | int padding_l,
41 | at::Tensor input,
42 | at::Tensor filters) {
43 |
44 | CHECK_INPUT(gradOutput);
45 | CHECK_INPUT(input);
46 | CHECK_INPUT(filters);
47 |
48 | return lightconv_cuda_backward(gradOutput, padding_l, input, filters);
49 | }
50 |
51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
52 | m.def("forward", &lightconv_forward, "lighconv forward (CUDA)");
53 | m.def("backward", &lightconv_backward, "lighconv backward (CUDA)");
54 | }
55 |
--------------------------------------------------------------------------------
/fairseq/modules/lightconv_layer/lightconv_cuda.cuh:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Facebook, Inc. and its affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | #include
9 | #include
10 |
11 | #include
12 | #include
13 |
14 | #include
15 | #include
16 | #include
17 | #include
18 | #include
19 | #include
20 |
21 | #include
22 | #include
23 |
24 | #define SHFL_MASK 0xffffffff
25 |
26 | template
27 | __global__
28 | void lightconv_forward_kernel(const scalar_t* input,
29 | const scalar_t* filters,
30 | int minibatch, int sequenceLength,
31 | int numFeatures, int numFiltersInBlock,
32 | scalar_t* output);
33 |
34 | template
35 | __global__
36 | void lightconv_grad_wrt_input_kernel(
37 | const scalar_t* input,
38 | const scalar_t* filters,
39 | int minibatch,
40 | int sequenceLength,
41 | int numFeatures,
42 | int numFiltersInBlock,
43 | scalar_t* output);
44 |
45 | template
46 | __global__
47 | void lightconv_grad_wrt_weights_firstpass_short_kernel(
48 | const scalar_t* input,
49 | const scalar_t* gradInput,
50 | int minibatch,
51 | int sequenceLength,
52 | int numFeatures,
53 | int numFiltersInBlock,
54 | int numHeads,
55 | float* output);
56 |
57 | template
58 | __global__
59 | void lightconv_grad_wrt_weights_secondpass_short_kernel(
60 | const float* input,
61 | const int minibatch,
62 | const int numFiltersInBlock,
63 | scalar_t* output);
64 |
65 | template
66 | __global__
67 | void lightconv_grad_wrt_weights_firstpass_kernel(
68 | const scalar_t* input,
69 | const scalar_t* gradInput,
70 | int minibatch,
71 | int sequenceLength,
72 | int numFeatures,
73 | int numFiltersInBlock,
74 | float* output);
75 |
76 | template
77 | __global__
78 | void lightconv_grad_wrt_weights_secondpass_kernel(
79 | const float* input,
80 | const int minibatch,
81 | const int numFiltersInBlock,
82 | scalar_t* output);
83 |
84 |
--------------------------------------------------------------------------------
/fairseq/modules/lightconv_layer/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from setuptools import setup
8 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension
9 |
10 | setup(
11 | name='lightconv_layer',
12 | ext_modules=[
13 | CUDAExtension('lightconv_cuda', [
14 | 'lightconv_cuda.cpp',
15 | 'lightconv_cuda_kernel.cu',
16 | ]),
17 | ],
18 | cmdclass={
19 | 'build_ext': BuildExtension
20 | })
21 |
--------------------------------------------------------------------------------
/fairseq/modules/logsumexp_moe.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 |
8 |
9 | class LogSumExpMoE(torch.autograd.Function):
10 | """Standard LogSumExp forward pass, but use *posterior* for the backward.
11 |
12 | See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
13 | (Shen et al., 2019) <https://arxiv.org/abs/1902.07816>`_.
14 | """
15 |
16 | @staticmethod
17 | def forward(ctx, logp, posterior, dim=-1):
18 | ctx.save_for_backward(posterior)
19 | ctx.dim = dim
20 | return torch.logsumexp(logp, dim=dim)
21 |
22 | @staticmethod
23 | def backward(ctx, grad_output):
24 | posterior, = ctx.saved_tensors
25 | grad_logp = grad_output.unsqueeze(ctx.dim) * posterior
26 | return grad_logp, None, None
27 |
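A minimal sketch (hypothetical shapes) of the custom backward: the forward is an ordinary logsumexp over experts, but the gradient routed to each expert's log-probability is weighted by the supplied posterior rather than by the softmax of logp itself:

import torch
from fairseq.modules import LogSumExpMoE

logp = torch.randn(4, 3, requires_grad=True)            # log p(y | x, expert), 3 experts
posterior = torch.softmax(torch.randn(4, 3), dim=-1)    # p(expert | x, y), held fixed
loss = -LogSumExpMoE.apply(logp, posterior, -1).sum()
loss.backward()
print(torch.allclose(logp.grad, -posterior))             # True: the gradient is -posterior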
--------------------------------------------------------------------------------
/fairseq/modules/mean_pool_gating_network.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | import torch.nn.functional as F
8 |
9 |
10 | class MeanPoolGatingNetwork(torch.nn.Module):
11 | """A simple mean-pooling gating network for selecting experts.
12 |
13 | This module applies mean pooling over an encoder's output and returns
14 | responsibilities for each expert. The encoder format is expected to match
15 | :class:`fairseq.models.transformer.TransformerEncoder`.
16 | """
17 |
18 | def __init__(self, embed_dim, num_experts, dropout=None):
19 | super().__init__()
20 | self.embed_dim = embed_dim
21 | self.num_experts = num_experts
22 |
23 | self.fc1 = torch.nn.Linear(embed_dim, embed_dim)
24 | self.dropout = torch.nn.Dropout(dropout) if dropout is not None else None
25 | self.fc2 = torch.nn.Linear(embed_dim, num_experts)
26 |
27 | def forward(self, encoder_out):
28 | if not (
29 | isinstance(encoder_out, dict)
30 | and 'encoder_out' in encoder_out
31 | and 'encoder_padding_mask' in encoder_out
32 | and encoder_out['encoder_out'].size(2) == self.embed_dim
33 | ):
34 | raise ValueError('Unexpected format for encoder_out')
35 |
36 | # mean pooling over time
37 | encoder_padding_mask = encoder_out['encoder_padding_mask'] # B x T
38 | encoder_out = encoder_out['encoder_out'].transpose(0, 1) # B x T x C
39 | if encoder_padding_mask is not None:
40 | encoder_out = encoder_out.clone() # required because of transpose above
41 | encoder_out[encoder_padding_mask] = 0
42 | ntokens = torch.sum(~encoder_padding_mask, dim=1, keepdim=True)
43 | x = torch.sum(encoder_out, dim=1) / ntokens.type_as(encoder_out)
44 | else:
45 | x = torch.mean(encoder_out, dim=1)
46 |
47 | x = torch.tanh(self.fc1(x))
48 | if self.dropout is not None:
49 | x = self.dropout(x)
50 | x = self.fc2(x)
51 | return F.log_softmax(x, dim=-1, dtype=torch.float32).type_as(x)
52 |
--------------------------------------------------------------------------------
/fairseq/modules/positional_embedding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch.nn as nn
7 |
8 | from .learned_positional_embedding import LearnedPositionalEmbedding
9 | from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
10 |
11 |
12 | def PositionalEmbedding(
13 | num_embeddings: int,
14 | embedding_dim: int,
15 | padding_idx: int,
16 | learned: bool = False,
17 | ):
18 | if learned:
19 | # if padding_idx is specified then offset the embedding ids by
20 | # this index and adjust num_embeddings appropriately
21 | # TODO: The right place for this offset would be inside
22 | # LearnedPositionalEmbedding. Move this there for a cleaner implementation.
23 | if padding_idx is not None:
24 | num_embeddings = num_embeddings + padding_idx + 1
25 | m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)
26 | nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
27 | if padding_idx is not None:
28 | nn.init.constant_(m.weight[padding_idx], 0)
29 | else:
30 | m = SinusoidalPositionalEmbedding(
31 | embedding_dim, padding_idx, init_size=num_embeddings + padding_idx + 1,
32 | )
33 | return m
34 |
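A short sketch of the factory with hypothetical sizes: with learned=True it returns a LearnedPositionalEmbedding (note the num_embeddings offset for padding_idx), otherwise a SinusoidalPositionalEmbedding; in both cases padding positions map to the zero vector:

import torch
from fairseq.modules import PositionalEmbedding

pad = 1
pos = PositionalEmbedding(num_embeddings=512, embedding_dim=8, padding_idx=pad, learned=True)

tokens = torch.tensor([[5, 6, 7, pad]])   # last position is padding
out = pos(tokens)                          # 1 x 4 x 8
print(out[0, -1].abs().sum())              # 0: the padding position gets the zero vector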
--------------------------------------------------------------------------------
/fairseq/modules/scalar_bias.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | #
6 |
7 | import torch
8 |
9 |
10 | class ScalarBias(torch.autograd.Function):
11 | """
12 | Adds a vector of scalars, used in self-attention mechanism to allow
13 | the model to optionally attend to this vector instead of the past
14 | """
15 |
16 | @staticmethod
17 | def forward(ctx, input, dim, bias_init):
18 | size = list(input.size())
19 | size[dim] += 1
20 | output = input.new(*size).fill_(bias_init)
21 | output.narrow(dim, 1, size[dim] - 1).copy_(input)
22 | ctx.dim = dim
23 | return output
24 |
25 | @staticmethod
26 | def backward(ctx, grad):
27 | return grad.narrow(ctx.dim, 1, grad.size(ctx.dim) - 1), None, None
28 |
29 |
30 | def scalar_bias(input, dim, bias_init=0):
31 | return ScalarBias.apply(input, dim, bias_init)
32 |
--------------------------------------------------------------------------------
/fairseq/modules/sparse_transformer_sentence_encoder_layer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from fairseq.modules import TransformerSentenceEncoderLayer
7 | from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention
8 |
9 |
10 | class SparseTransformerSentenceEncoderLayer(TransformerSentenceEncoderLayer):
11 | """
12 | Implements a Sparse Transformer Encoder Layer (see SparseMultiheadAttention)
13 | """
14 |
15 | def __init__(
16 | self,
17 | embedding_dim: float = 768,
18 | ffn_embedding_dim: float = 3072,
19 | num_attention_heads: float = 8,
20 | dropout: float = 0.1,
21 | attention_dropout: float = 0.1,
22 | activation_dropout: float = 0.1,
23 | activation_fn: str = 'relu',
24 | add_bias_kv: bool = False,
25 | add_zero_attn: bool = False,
26 | export: bool = False,
27 | is_bidirectional: bool = True,
28 | stride: int = 32,
29 | expressivity: int = 8,
30 | ) -> None:
31 |
32 | super().__init__(
33 | embedding_dim, ffn_embedding_dim, num_attention_heads, dropout,
34 | attention_dropout, activation_dropout, activation_fn, add_bias_kv,
35 | add_zero_attn, export
36 | )
37 |
38 | self.self_attn = SparseMultiheadAttention(
39 | self.embedding_dim,
40 | num_attention_heads,
41 | dropout=attention_dropout,
42 | add_bias_kv=add_bias_kv,
43 | add_zero_attn=add_zero_attn,
44 | self_attention=True,
45 | is_bidirectional=is_bidirectional,
46 | stride=stride,
47 | expressivity=expressivity,
48 | )
49 |
--------------------------------------------------------------------------------
/fairseq/modules/transformer_sentence_encoder_layer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | from fairseq import utils
11 | from fairseq.modules import (
12 | LayerNorm,
13 | MultiheadAttention,
14 | )
15 |
16 |
17 | class TransformerSentenceEncoderLayer(nn.Module):
18 | """
19 | Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
20 | models.
21 | """
22 |
23 | def __init__(
24 | self,
25 | embedding_dim: float = 768,
26 | ffn_embedding_dim: float = 3072,
27 | num_attention_heads: float = 8,
28 | dropout: float = 0.1,
29 | attention_dropout: float = 0.1,
30 | activation_dropout: float = 0.1,
31 | activation_fn: str = 'relu',
32 | add_bias_kv: bool = False,
33 | add_zero_attn: bool = False,
34 | export: bool = False,
35 | ) -> None:
36 |
37 | super().__init__()
38 | # Initialize parameters
39 | self.embedding_dim = embedding_dim
40 | self.dropout = dropout
41 | self.activation_dropout = activation_dropout
42 |
43 | # Initialize blocks
44 | self.activation_fn = utils.get_activation_fn(activation_fn)
45 | self.self_attn = MultiheadAttention(
46 | self.embedding_dim,
47 | num_attention_heads,
48 | dropout=attention_dropout,
49 | add_bias_kv=add_bias_kv,
50 | add_zero_attn=add_zero_attn,
51 | self_attention=True
52 | )
53 |
54 | # layer norm associated with the self attention layer
55 | self.self_attn_layer_norm = LayerNorm(self.embedding_dim, export=export)
56 | self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
57 | self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
58 |
59 | # layer norm associated with the position wise feed-forward NN
60 | self.final_layer_norm = LayerNorm(self.embedding_dim, export=export)
61 |
62 | def forward(
63 | self,
64 | x: torch.Tensor,
65 | self_attn_mask: torch.Tensor = None,
66 | self_attn_padding_mask: torch.Tensor = None,
67 | ):
68 | """
69 | LayerNorm is applied either before or after the self-attention/ffn
71 | modules, similar to the original Transformer implementation.
71 | """
72 | residual = x
73 | x, attn = self.self_attn(
74 | query=x,
75 | key=x,
76 | value=x,
77 | key_padding_mask=self_attn_padding_mask,
78 | need_weights=False,
79 | attn_mask=self_attn_mask,
80 | )
81 | x = F.dropout(x, p=self.dropout, training=self.training)
82 | x = residual + x
83 | x = self.self_attn_layer_norm(x)
84 |
85 | residual = x
86 | x = self.activation_fn(self.fc1(x))
87 | x = F.dropout(x, p=self.activation_dropout, training=self.training)
88 | x = self.fc2(x)
89 | x = F.dropout(x, p=self.dropout, training=self.training)
90 | x = residual + x
91 | x = self.final_layer_norm(x)
92 | return x, attn
93 |
--------------------------------------------------------------------------------
/fairseq/modules/unfold.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch.nn.functional as F
7 |
8 |
9 | def unfold1d(x, kernel_size, padding_l, pad_value=0):
10 | '''unfold T x B x C to T x B x C x K'''
11 | if kernel_size > 1:
12 | T, B, C = x.size()
13 | x = F.pad(x, (0, 0, 0, 0, padding_l, kernel_size - 1 - padding_l), value=pad_value)
14 | x = x.as_strided((T, B, C, kernel_size), (B*C, C, 1, B*C))
15 | else:
16 | x = x.unsqueeze(3)
17 | return x
18 |
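A small sketch of the unfold (hypothetical sizes): every time step is expanded into a window of K neighbouring steps, zero-padded at the sequence edges, with the K taps sharing storage with the padded tensor via as_strided. This is how the lightweight/dynamic convolutions gather their receptive field:

import torch
from fairseq.modules import unfold1d

T, B, C, K = 6, 2, 4, 3
x = torch.randn(T, B, C)
win = unfold1d(x, kernel_size=K, padding_l=1)    # T x B x C x K
print(win.shape)                                  # torch.Size([6, 2, 4, 3])
print(torch.equal(win[2, 0, :, 1], x[2, 0, :]))   # True: with padding_l=1, tap k=1 is the centre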
--------------------------------------------------------------------------------
/fairseq/optim/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import importlib
7 | import os
8 |
9 | from fairseq import registry
10 | from fairseq.optim.fairseq_optimizer import FairseqOptimizer
11 | from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer
12 | from fairseq.optim.bmuf import FairseqBMUF # noqa
13 |
14 |
15 | __all__ = [
16 | 'FairseqOptimizer',
17 | 'FP16Optimizer',
18 | 'MemoryEfficientFP16Optimizer',
19 | ]
20 |
21 |
22 | build_optimizer, register_optimizer, OPTIMIZER_REGISTRY = registry.setup_registry(
23 | '--optimizer',
24 | base_class=FairseqOptimizer,
25 | default='nag',
26 | )
27 |
28 |
29 | # automatically import any Python files in the optim/ directory
30 | for file in os.listdir(os.path.dirname(__file__)):
31 | if file.endswith('.py') and not file.startswith('_'):
32 | module = file[:file.find('.py')]
33 | importlib.import_module('fairseq.optim.' + module)
34 |
--------------------------------------------------------------------------------
/fairseq/optim/adadelta.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch.optim
7 |
8 | from . import FairseqOptimizer, register_optimizer
9 |
10 |
11 | @register_optimizer('adadelta')
12 | class Adadelta(FairseqOptimizer):
13 | def __init__(self, args, params):
14 | super().__init__(args)
15 | self._optimizer = torch.optim.Adadelta(params, **self.optimizer_config)
16 |
17 | @staticmethod
18 | def add_args(parser):
19 | """Add optimizer-specific arguments to the parser."""
20 | # fmt: off
21 | parser.add_argument('--adadelta-rho', type=float, default=0.9, metavar='RHO',
22 | help='coefficient used for computing a running average of squared gradients')
23 | parser.add_argument('--adadelta-eps', type=float, default=1e-6, metavar='EPS',
24 | help='term added to the denominator to improve numerical stability')
25 | parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
26 | help='weight decay')
27 | parser.add_argument('--anneal-eps', action='store_true', help='flag to anneal eps')
28 | # fmt: on
29 |
30 | @property
31 | def optimizer_config(self):
32 | """
33 | Return a kwarg dictionary that will be used to override optimizer
34 | args stored in checkpoints. This allows us to load a checkpoint and
35 | resume training using a different set of optimizer args, e.g., with a
36 | different learning rate.
37 | """
38 | return {
39 | 'lr': self.args.lr[0],
40 | 'rho': self.args.adadelta_rho,
41 | 'eps': self.args.adadelta_eps,
42 | 'weight_decay': self.args.weight_decay,
43 | }
44 |
--------------------------------------------------------------------------------
/fairseq/optim/adagrad.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch.optim
7 |
8 | from . import FairseqOptimizer, register_optimizer
9 |
10 |
11 | @register_optimizer('adagrad')
12 | class Adagrad(FairseqOptimizer):
13 | def __init__(self, args, params):
14 | super().__init__(args)
15 | self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config)
16 |
17 | @staticmethod
18 | def add_args(parser):
19 | """Add optimizer-specific arguments to the parser."""
20 | # fmt: off
21 | parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
22 | help='weight decay')
23 | # fmt: on
24 |
25 | @property
26 | def optimizer_config(self):
27 | """
28 | Return a kwarg dictionary that will be used to override optimizer
29 | args stored in checkpoints. This allows us to load a checkpoint and
30 | resume training using a different set of optimizer args, e.g., with a
31 | different learning rate.
32 | """
33 | return {
34 | 'lr': self.args.lr[0],
35 | 'weight_decay': self.args.weight_decay,
36 | }
37 |
--------------------------------------------------------------------------------
/fairseq/optim/lr_scheduler/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import importlib
7 | import os
8 |
9 | from fairseq import registry
10 | from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import FairseqLRScheduler
11 |
12 |
13 | build_lr_scheduler, register_lr_scheduler, LR_SCHEDULER_REGISTRY = registry.setup_registry(
14 | '--lr-scheduler',
15 | base_class=FairseqLRScheduler,
16 | default='fixed',
17 | )
18 |
19 | # automatically import any Python files in the optim/lr_scheduler/ directory
20 | for file in os.listdir(os.path.dirname(__file__)):
21 | if file.endswith('.py') and not file.startswith('_'):
22 | module = file[:file.find('.py')]
23 | importlib.import_module('fairseq.optim.lr_scheduler.' + module)
24 |
--------------------------------------------------------------------------------
/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .. import FairseqOptimizer
7 |
8 |
9 | class FairseqLRScheduler(object):
10 |
11 | def __init__(self, args, optimizer):
12 | super().__init__()
13 | if not isinstance(optimizer, FairseqOptimizer):
14 | raise ValueError('optimizer must be an instance of FairseqOptimizer')
15 | self.args = args
16 | self.optimizer = optimizer
17 | self.best = None
18 |
19 | @staticmethod
20 | def add_args(parser):
21 | """Add arguments to the parser for this LR scheduler."""
22 | pass
23 |
24 | def state_dict(self):
25 | """Return the LR scheduler state dict."""
26 | return {'best': self.best}
27 |
28 | def load_state_dict(self, state_dict):
29 | """Load an LR scheduler state dict."""
30 | self.best = state_dict['best']
31 |
32 | def step(self, epoch, val_loss=None):
33 | """Update the learning rate at the end of the given epoch."""
34 | if val_loss is not None:
35 | if self.best is None:
36 | self.best = val_loss
37 | else:
38 | self.best = min(self.best, val_loss)
39 |
40 | def step_update(self, num_updates):
41 | """Update the learning rate after each update."""
42 | return self.optimizer.get_lr()
43 |
--------------------------------------------------------------------------------
/fairseq/optim/lr_scheduler/fixed_schedule.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from . import FairseqLRScheduler, register_lr_scheduler
7 |
8 |
9 | @register_lr_scheduler('fixed')
10 | class FixedSchedule(FairseqLRScheduler):
11 | """Decay the LR on a fixed schedule."""
12 |
13 | def __init__(self, args, optimizer):
14 | super().__init__(args, optimizer)
15 |
16 | # set defaults
17 | args.warmup_updates = getattr(args, 'warmup_updates', 0) or 0
18 |
19 | self.lr = args.lr[0]
20 | if args.warmup_updates > 0:
21 | self.warmup_factor = 1. / args.warmup_updates
22 | else:
23 | self.warmup_factor = 1
24 |
25 | @staticmethod
26 | def add_args(parser):
27 | """Add arguments to the parser for this LR scheduler."""
28 | # fmt: off
29 | parser.add_argument('--force-anneal', '--fa', type=int, metavar='N',
30 | help='force annealing at specified epoch')
31 | parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
32 | help='shrink factor for annealing, lr_new = (lr * lr_shrink)')
33 | parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
34 | help='warmup the learning rate linearly for the first N updates')
35 | # fmt: on
36 |
37 | def get_next_lr(self, epoch):
38 | lrs = self.args.lr
39 | if self.args.force_anneal is None or epoch < self.args.force_anneal:
40 | # use fixed LR schedule
41 | next_lr = lrs[min(epoch, len(lrs) - 1)]
42 | else:
43 | # anneal based on lr_shrink
44 | next_lr = lrs[-1] * self.args.lr_shrink ** (epoch + 1 - self.args.force_anneal)
45 | return next_lr
46 |
47 | def step(self, epoch, val_loss=None):
48 | """Update the learning rate at the end of the given epoch."""
49 | super().step(epoch, val_loss)
50 | self.lr = self.get_next_lr(epoch)
51 | self.optimizer.set_lr(self.warmup_factor * self.lr)
52 | return self.optimizer.get_lr()
53 |
54 | def step_update(self, num_updates):
55 | """Update the learning rate after each update."""
56 | if self.args.warmup_updates > 0 and num_updates <= self.args.warmup_updates:
57 | self.warmup_factor = num_updates / float(self.args.warmup_updates)
58 | self.optimizer.set_lr(self.warmup_factor * self.lr)
59 | return self.optimizer.get_lr()
60 |
--------------------------------------------------------------------------------
/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from . import FairseqLRScheduler, register_lr_scheduler
7 |
8 |
9 | @register_lr_scheduler('inverse_sqrt')
10 | class InverseSquareRootSchedule(FairseqLRScheduler):
11 | """Decay the LR based on the inverse square root of the update number.
12 |
13 | We also support a warmup phase where we linearly increase the learning rate
14 | from some initial learning rate (``--warmup-init-lr``) until the configured
15 | learning rate (``--lr``). Thereafter we decay proportional to the number of
16 | updates, with a decay factor set to align with the configured learning rate.
17 |
18 | During warmup::
19 |
20 | lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
21 | lr = lrs[update_num]
22 |
23 | After warmup::
24 |
25 | decay_factor = args.lr * sqrt(args.warmup_updates)
26 | lr = decay_factor / sqrt(update_num)
27 | """
28 |
29 | def __init__(self, args, optimizer):
30 | super().__init__(args, optimizer)
31 | if len(args.lr) > 1:
32 | raise ValueError(
33 | 'Cannot use a fixed learning rate schedule with inverse_sqrt.'
34 | ' Consider --lr-scheduler=fixed instead.'
35 | )
36 | warmup_end_lr = args.lr[0]
37 | if args.warmup_init_lr < 0:
38 | args.warmup_init_lr = 0 if args.warmup_updates > 0 else warmup_end_lr
39 |
40 | # linearly warmup for the first args.warmup_updates
41 | self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates
42 |
43 | # then, decay prop. to the inverse square root of the update number
44 | self.decay_factor = warmup_end_lr * args.warmup_updates**0.5
45 |
46 | # initial learning rate
47 | self.lr = args.warmup_init_lr
48 | self.optimizer.set_lr(self.lr)
49 |
50 | @staticmethod
51 | def add_args(parser):
52 | """Add arguments to the parser for this LR scheduler."""
53 | # fmt: off
54 | parser.add_argument('--warmup-updates', default=4000, type=int, metavar='N',
55 | help='warmup the learning rate linearly for the first N updates')
56 | parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR',
57 | help='initial learning rate during warmup phase; default is args.lr')
58 | # fmt: on
59 |
60 | def step(self, epoch, val_loss=None):
61 | """Update the learning rate at the end of the given epoch."""
62 | super().step(epoch, val_loss)
63 | # we don't change the learning rate at epoch boundaries
64 | return self.optimizer.get_lr()
65 |
66 | def step_update(self, num_updates):
67 | """Update the learning rate after each update."""
68 | if num_updates < self.args.warmup_updates:
69 | self.lr = self.args.warmup_init_lr + num_updates*self.lr_step
70 | else:
71 | self.lr = self.decay_factor * num_updates**-0.5
72 | self.optimizer.set_lr(self.lr)
73 | return self.lr
74 |
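A worked numerical sketch of the schedule above, using the hypothetical (but typical) settings --lr 5e-4 and --warmup-updates 4000 with warmup_init_lr defaulting to 0: the LR climbs linearly to the peak and then falls off as 1/sqrt(update):

import math

lr, warmup_updates, warmup_init_lr = 5e-4, 4000, 0.0
lr_step = (lr - warmup_init_lr) / warmup_updates
decay_factor = lr * math.sqrt(warmup_updates)

def lr_at(num_updates):
    if num_updates < warmup_updates:
        return warmup_init_lr + num_updates * lr_step
    return decay_factor / math.sqrt(num_updates)

print(lr_at(2000))    # 2.5e-04: halfway through warmup
print(lr_at(4000))    # 5.0e-04: the peak
print(lr_at(16000))   # 2.5e-04: decayed back to half the peak at 4x the warmup length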
--------------------------------------------------------------------------------
/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from . import FairseqLRScheduler, register_lr_scheduler
7 |
8 |
9 | @register_lr_scheduler('polynomial_decay')
10 | class PolynomialDecaySchedule(FairseqLRScheduler):
11 | """Decay the LR on a fixed schedule."""
12 |
13 | def __init__(self, args, optimizer):
14 | super().__init__(args, optimizer)
15 |
16 | # set defaults
17 | args.warmup_updates = getattr(args, 'warmup_updates', 0) or 0
18 |
19 | self.lr = args.lr[0]
20 | if args.warmup_updates > 0:
21 | self.warmup_factor = 1. / args.warmup_updates
22 | else:
23 | self.warmup_factor = 1
24 | self.end_learning_rate = args.end_learning_rate
25 | self.total_num_update = args.total_num_update
26 | self.power = args.power
27 | self.optimizer.set_lr(self.warmup_factor * self.lr)
28 |
29 | @staticmethod
30 | def add_args(parser):
31 | """Add arguments to the parser for this LR scheduler."""
32 | parser.add_argument('--force-anneal', '--fa', type=int, metavar='N',
33 | help='force annealing at specified epoch')
34 | parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
35 | help='warmup the learning rate linearly for the first N updates')
36 | parser.add_argument('--end-learning-rate', default=0.0, type=float)
37 | parser.add_argument('--power', default=1.0, type=float)
38 | parser.add_argument('--total-num-update', default=1000000, type=int)
39 |
40 | def get_next_lr(self, epoch):
41 | lrs = self.args.lr
42 | if self.args.force_anneal is None or epoch < self.args.force_anneal:
43 | # use fixed LR schedule
44 | next_lr = lrs[min(epoch, len(lrs) - 1)]
45 | else:
46 | # anneal based on lr_shrink
47 | next_lr = self.optimizer.get_lr()
48 | return next_lr
49 |
50 | def step(self, epoch, val_loss=None):
51 | """Update the learning rate at the end of the given epoch."""
52 | super().step(epoch, val_loss)
53 | self.lr = self.get_next_lr(epoch)
54 | self.optimizer.set_lr(self.warmup_factor * self.lr)
55 | return self.optimizer.get_lr()
56 |
57 | def step_update(self, num_updates):
58 | """Update the learning rate after each update."""
59 | if self.args.warmup_updates > 0 and num_updates <= self.args.warmup_updates:
60 | self.warmup_factor = num_updates / float(self.args.warmup_updates)
61 | lr = self.warmup_factor * self.lr
62 | elif num_updates >= self.total_num_update:
63 | lr = self.end_learning_rate
64 | else:
65 | warmup = self.args.warmup_updates
66 | lr_range = self.lr - self.end_learning_rate
67 | pct_remaining = 1 - (num_updates - warmup) / (self.total_num_update - warmup)
68 | lr = lr_range * pct_remaining ** (self.power) + self.end_learning_rate
69 | self.optimizer.set_lr(lr)
70 | return self.optimizer.get_lr()
71 |
--------------------------------------------------------------------------------
/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch.optim.lr_scheduler
7 |
8 | from . import FairseqLRScheduler, register_lr_scheduler
9 |
10 |
11 | @register_lr_scheduler('reduce_lr_on_plateau')
12 | class ReduceLROnPlateau(FairseqLRScheduler):
13 | """Decay the LR by a factor every time the validation loss plateaus."""
14 |
15 | def __init__(self, args, optimizer):
16 | super().__init__(args, optimizer)
17 | if len(args.lr) > 1:
18 | raise ValueError(
19 | 'Cannot use a fixed learning rate schedule with reduce_lr_on_plateau.'
20 | ' Consider --lr-scheduler=fixed instead.'
21 | )
22 | self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
23 | self.optimizer.optimizer, patience=0, factor=args.lr_shrink,
24 | threshold=args.lr_threshold)
25 |
26 | @staticmethod
27 | def add_args(parser):
28 | """Add arguments to the parser for this LR scheduler."""
29 | # fmt: off
30 | parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
31 | help='shrink factor for annealing, lr_new = (lr * lr_shrink)')
32 | parser.add_argument('--lr-threshold', default=1e-4, type=float, metavar='LT',
33 | help='Threshold for measuring the new optimum, \
34 | to only focus on significant changes')
35 | # fmt: on
36 |
37 | def state_dict(self):
38 | """Return the LR scheduler state dict."""
39 | return {
40 | 'best': self.lr_scheduler.best,
41 | 'last_epoch': self.lr_scheduler.last_epoch,
42 | }
43 |
44 | def load_state_dict(self, state_dict):
45 | """Load an LR scheduler state dict."""
46 | self.lr_scheduler.best = state_dict['best']
47 | if 'last_epoch' in state_dict:
48 | self.lr_scheduler.last_epoch = state_dict['last_epoch']
49 |
50 | def step(self, epoch, val_loss=None):
51 | """Update the learning rate at the end of the given epoch."""
52 | if val_loss is not None:
53 | self.lr_scheduler.step(val_loss, epoch)
54 | else:
55 | self.lr_scheduler.last_epoch = epoch
56 | return self.optimizer.get_lr()
57 |
--------------------------------------------------------------------------------
/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import math
7 |
8 | from . import FairseqLRScheduler, register_lr_scheduler
9 |
10 |
11 | @register_lr_scheduler('triangular')
12 | class TriangularSchedule(FairseqLRScheduler):
13 | """Assign LR based on a triangular cyclical schedule.
14 |
15 | See https://arxiv.org/pdf/1506.01186.pdf for details.
16 | """
17 |
18 | def __init__(self, args, optimizer):
19 | super().__init__(args, optimizer)
20 | if len(args.lr) > 1:
21 | raise ValueError(
22 | 'Cannot use a fixed learning rate schedule with triangular.'
23 | ' Consider --lr-scheduler=fixed instead.'
24 | )
25 |
26 | lr = args.lr[0]
27 |
28 | assert args.max_lr > lr, 'max_lr must be more than lr'
29 | self.min_lr = lr
30 | self.max_lr = args.max_lr
31 | self.stepsize = args.lr_period_updates // 2
32 | self.lr_shrink = args.lr_shrink
33 | self.shrink_min = args.shrink_min
34 |
35 | # initial learning rate
36 | self.lr = self.min_lr
37 | self.optimizer.set_lr(self.lr)
38 |
39 | @staticmethod
40 | def add_args(parser):
41 | """Add arguments to the parser for this LR scheduler."""
42 | # fmt: off
43 | parser.add_argument('--max-lr', required=True, type=float, metavar='LR',
44 | help='max learning rate, must be more than args.lr')
45 | parser.add_argument('--lr-period-updates', default=5000, type=float, metavar='LR',
46 | help='initial number of updates per period (cycle length)')
47 | parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
48 | help='shrink factor for annealing')
49 | parser.add_argument('--shrink-min', action='store_true',
50 | help='if set, also shrinks min lr')
51 | # fmt: on
52 |
53 | def step(self, epoch, val_loss=None):
54 | """Update the learning rate at the end of the given epoch."""
55 | super().step(epoch, val_loss)
56 | # we don't change the learning rate at epoch boundaries
57 | return self.optimizer.get_lr()
58 |
59 | def step_update(self, num_updates):
60 | """Update the learning rate after each update."""
61 | cycle = math.floor(num_updates / (2 * self.stepsize))
62 |
63 | lr_shrink = self.lr_shrink ** cycle
64 | max_lr = self.max_lr * lr_shrink
65 | if self.shrink_min:
66 | min_lr = self.min_lr * lr_shrink
67 | else:
68 | min_lr = self.min_lr
69 |
70 | x = abs(num_updates / self.stepsize - 2 * (cycle + 1) + 1)
71 | self.lr = min_lr + (max_lr - min_lr) * max(0, (1 - x))
72 |
73 | self.optimizer.set_lr(self.lr)
74 | return self.lr
75 |
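A worked sketch of step_update with hypothetical settings (floor lr=1e-5, --max-lr 1e-3, --lr-period-updates 5000, --lr-shrink 0.5, --shrink-min unset): the LR ramps from the floor to the cycle peak and back over each period, and the peak halves every cycle:

import math

min_lr, max_lr, stepsize, lr_shrink = 1e-5, 1e-3, 5000 // 2, 0.5

def lr_at(num_updates):
    cycle = math.floor(num_updates / (2 * stepsize))
    cur_max = max_lr * lr_shrink ** cycle         # the peak shrinks every cycle
    x = abs(num_updates / stepsize - 2 * (cycle + 1) + 1)
    return min_lr + (cur_max - min_lr) * max(0, 1 - x)

print(lr_at(0))      # 1e-05: start of the first cycle
print(lr_at(2500))   # 1e-03: peak in the middle of the first cycle
print(lr_at(5000))   # 1e-05: back at the floor; the next cycle's peak will be 5e-04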
--------------------------------------------------------------------------------
/fairseq/optim/sgd.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch.optim
7 |
8 | from . import FairseqOptimizer, register_optimizer
9 |
10 |
11 | @register_optimizer('sgd')
12 | class SGD(FairseqOptimizer):
13 | def __init__(self, args, params):
14 | super().__init__(args)
15 | self._optimizer = torch.optim.SGD(params, **self.optimizer_config)
16 |
17 | @staticmethod
18 | def add_args(parser):
19 | """Add optimizer-specific arguments to the parser."""
20 | # fmt: off
21 | parser.add_argument('--momentum', default=0.0, type=float, metavar='M',
22 | help='momentum factor')
23 | parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
24 | help='weight decay')
25 | # fmt: on
26 |
27 | @property
28 | def optimizer_config(self):
29 | """
30 | Return a kwarg dictionary that will be used to override optimizer
31 | args stored in checkpoints. This allows us to load a checkpoint and
32 | resume training using a different set of optimizer args, e.g., with a
33 | different learning rate.
34 | """
35 | return {
36 | 'lr': self.args.lr[0],
37 | 'momentum': self.args.momentum,
38 | 'weight_decay': self.args.weight_decay,
39 | }
40 |
--------------------------------------------------------------------------------
/fairseq/pdb.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import multiprocessing
7 | import os
8 | import pdb
9 | import sys
10 |
11 |
12 | __all__ = ['set_trace']
13 |
14 |
15 | _stdin = [None]
16 | _stdin_lock = multiprocessing.Lock()
17 | try:
18 | _stdin_fd = sys.stdin.fileno()
19 | except Exception:
20 | _stdin_fd = None
21 |
22 |
23 | class MultiprocessingPdb(pdb.Pdb):
24 | """A Pdb wrapper that works in a multiprocessing environment.
25 |
26 | Usage: `from fairseq import pdb; pdb.set_trace()`
27 | """
28 |
29 | def __init__(self):
30 | pdb.Pdb.__init__(self, nosigint=True)
31 |
32 | def _cmdloop(self):
33 | stdin_bak = sys.stdin
34 | with _stdin_lock:
35 | try:
36 | if _stdin_fd is not None:
37 | if not _stdin[0]:
38 | _stdin[0] = os.fdopen(_stdin_fd)
39 | sys.stdin = _stdin[0]
40 | self.cmdloop()
41 | finally:
42 | sys.stdin = stdin_bak
43 |
44 |
45 | def set_trace():
46 | pdb = MultiprocessingPdb()
47 | pdb.set_trace(sys._getframe().f_back)
48 |
--------------------------------------------------------------------------------
/fairseq/registry.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import argparse
7 |
8 |
9 | REGISTRIES = {}
10 |
11 |
12 | def setup_registry(
13 | registry_name: str,
14 | base_class=None,
15 | default=None,
16 | ):
17 | assert registry_name.startswith('--')
18 | registry_name = registry_name[2:].replace('-', '_')
19 |
20 | REGISTRY = {}
21 | REGISTRY_CLASS_NAMES = set()
22 |
23 | # maintain a registry of all registries
24 | if registry_name in REGISTRIES:
25 | return # registry already exists
26 | REGISTRIES[registry_name] = {
27 | 'registry': REGISTRY,
28 | 'default': default,
29 | }
30 |
31 | def build_x(args, *extra_args, **extra_kwargs):
32 | choice = getattr(args, registry_name, None)
33 | if choice is None:
34 | return None
35 | cls = REGISTRY[choice]
36 | if hasattr(cls, 'build_' + registry_name):
37 | builder = getattr(cls, 'build_' + registry_name)
38 | else:
39 | builder = cls
40 | set_defaults(args, cls)
41 | return builder(args, *extra_args, **extra_kwargs)
42 |
43 | def register_x(name):
44 |
45 | def register_x_cls(cls):
46 | if name in REGISTRY:
47 | raise ValueError('Cannot register duplicate {} ({})'.format(registry_name, name))
48 | if cls.__name__ in REGISTRY_CLASS_NAMES:
49 | raise ValueError(
50 | 'Cannot register {} with duplicate class name ({})'.format(
51 | registry_name, cls.__name__,
52 | )
53 | )
54 | if base_class is not None and not issubclass(cls, base_class):
55 | raise ValueError('{} must extend {}'.format(cls.__name__, base_class.__name__))
56 | REGISTRY[name] = cls
57 | REGISTRY_CLASS_NAMES.add(cls.__name__)
58 | return cls
59 |
60 | return register_x_cls
61 |
62 | return build_x, register_x, REGISTRY
63 |
64 |
65 | def set_defaults(args, cls):
66 | """Helper to set default arguments based on *add_args*."""
67 | if not hasattr(cls, 'add_args'):
68 | return
69 | parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, allow_abbrev=False)
70 | cls.add_args(parser)
71 | # copied from argparse.py:
72 | defaults = argparse.Namespace()
73 | for action in parser._actions:
74 | if action.dest is not argparse.SUPPRESS:
75 | if not hasattr(defaults, action.dest):
76 | if action.default is not argparse.SUPPRESS:
77 | setattr(defaults, action.dest, action.default)
78 | for key, default_value in vars(defaults).items():
79 | if not hasattr(args, key):
80 | setattr(args, key, default_value)
81 |
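A hypothetical sketch of how `setup_registry()` is intended to be used; the `--dummy-component` registry and the `LowercaseComponent` class below are illustrations only and do not exist in this repository.

    import argparse
    from fairseq.registry import setup_registry

    build_dummy_component, register_dummy_component, DUMMY_REGISTRY = setup_registry(
        '--dummy-component', default=None,
    )

    @register_dummy_component('lowercase')
    class LowercaseComponent(object):
        @staticmethod
        def add_args(parser):
            # explicit default so set_defaults() above copies it onto args
            parser.add_argument('--lowercase-strict', default=False, action='store_true')

        def __init__(self, args):
            self.strict = args.lowercase_strict

    # build_x() resolves args.dummy_component in the registry, applies the
    # defaults declared by add_args(), and instantiates the matching class.
    args = argparse.Namespace(dummy_component='lowercase')
    component = build_dummy_component(args)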
--------------------------------------------------------------------------------
/fairseq/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import argparse
7 | import importlib
8 | import os
9 |
10 | from .fairseq_task import FairseqTask
11 |
12 | TASK_REGISTRY = {}
13 | TASK_CLASS_NAMES = set()
14 |
15 |
16 | def setup_task(args, **kwargs):
17 | return TASK_REGISTRY[args.task].setup_task(args, **kwargs)
18 |
19 |
20 | def register_task(name):
21 | """
22 | New tasks can be added to fairseq with the
23 | :func:`~fairseq.tasks.register_task` function decorator.
24 |
25 | For example::
26 |
27 | @register_task('classification')
28 | class ClassificationTask(FairseqTask):
29 | (...)
30 |
31 | .. note::
32 |
33 | All Tasks must implement the :class:`~fairseq.tasks.FairseqTask`
34 | interface.
35 |
36 | Please see the
37 |     Please see the existing tasks in :mod:`fairseq.tasks` for more examples.
38 | Args:
39 | name (str): the name of the task
40 | """
41 |
42 | def register_task_cls(cls):
43 | if name in TASK_REGISTRY:
44 | raise ValueError('Cannot register duplicate task ({})'.format(name))
45 | if not issubclass(cls, FairseqTask):
46 | raise ValueError('Task ({}: {}) must extend FairseqTask'.format(name, cls.__name__))
47 | if cls.__name__ in TASK_CLASS_NAMES:
48 | raise ValueError('Cannot register task with duplicate class name ({})'.format(cls.__name__))
49 | TASK_REGISTRY[name] = cls
50 | TASK_CLASS_NAMES.add(cls.__name__)
51 | return cls
52 |
53 | return register_task_cls
54 |
55 |
56 | # automatically import any Python files in the tasks/ directory
57 | for file in os.listdir(os.path.dirname(__file__)):
58 | if file.endswith('.py') and not file.startswith('_'):
59 | task_name = file[:file.find('.py')]
60 | importlib.import_module('fairseq.tasks.' + task_name)
61 |
62 | # expose `task_parser` for sphinx
63 | if task_name in TASK_REGISTRY:
64 | parser = argparse.ArgumentParser(add_help=False)
65 | group_task = parser.add_argument_group('Task name')
66 | # fmt: off
67 | group_task.add_argument('--task', metavar=task_name,
68 | help='Enable this task with: ``--task=' + task_name + '``')
69 | # fmt: on
70 | group_args = parser.add_argument_group('Additional command-line arguments')
71 | TASK_REGISTRY[task_name].add_args(group_args)
72 | globals()[task_name + '_parser'] = parser
73 |
74 |
75 | def get_task(name):
76 | return TASK_REGISTRY[name]
77 |
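A minimal sketch of the intended flow, assuming the stock `translation` task has been picked up by the import loop above; the data directory and language codes are placeholders, and `setup_task` only succeeds if that directory actually contains the binarized data and dictionaries.

    import argparse
    from fairseq import tasks

    parser = argparse.ArgumentParser()
    parser.add_argument('--task', default='translation')
    tasks.get_task('translation').add_args(parser)  # adds the data dir, --source-lang, --target-lang, ...
    args = parser.parse_args(['data-bin/example', '--source-lang', 'de', '--target-lang', 'en'])

    task = tasks.setup_task(args)  # dispatches to TranslationTask.setup_task(args)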
--------------------------------------------------------------------------------
/fairseq/tasks/audio_pretraining.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import os
7 |
8 | from fairseq.data import FileAudioDataset
9 | from . import FairseqTask, register_task
10 |
11 |
12 | @register_task('audio_pretraining')
13 | class AudioPretrainingTask(FairseqTask):
14 |     """Task for pretraining speech representations (e.g., wav2vec) on raw audio
15 |     listed in a {split}.tsv manifest (see scripts/wav2vec_manifest.py).
16 |     """
17 |
18 | @staticmethod
19 | def add_args(parser):
20 | """Add task-specific arguments to the parser."""
21 | parser.add_argument('data', help='path to data directory')
22 | parser.add_argument('--sample-rate', default=16000, type=int,
23 | help='target sample rate. audio files will be up/down sampled to this rate')
24 | parser.add_argument('--max-sample-size', default=None, type=int,
25 | help='max sample size to crop to for batching. default = min sample length')
26 | parser.add_argument('--min-sample-size', default=None, type=int,
27 | help='min sample size to crop to for batching. default = same as --max-sample-size')
28 |
29 | def __init__(self, args):
30 | super().__init__(args)
31 |
32 | @classmethod
33 | def setup_task(cls, args, **kwargs):
34 | """Setup the task (e.g., load dictionaries).
35 |
36 | Args:
37 | args (argparse.Namespace): parsed command-line arguments
38 | """
39 | return cls(args)
40 |
41 | def load_dataset(self, split, **kwargs):
42 | """Load a given dataset split.
43 |
44 | Args:
45 | split (str): name of the split (e.g., train, valid, test)
46 | """
47 |
48 | manifest = os.path.join(self.args.data, '{}.tsv'.format(split))
49 | self.datasets[split] = FileAudioDataset(manifest,
50 | sample_rate=self.args.sample_rate,
51 | max_sample_size=self.args.max_sample_size,
52 | min_sample_size=self.args.min_sample_size)
53 |
54 | @property
55 | def target_dictionary(self):
56 | """Return the :class:`~fairseq.data.Dictionary` for the language
57 | model."""
58 | return None
59 |
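A small sketch of how this task is wired together; the manifest directory below is a placeholder, and `load_dataset` expects `<data>/<split>.tsv` files such as those written by `scripts/wav2vec_manifest.py`.

    import argparse
    from fairseq.tasks.audio_pretraining import AudioPretrainingTask

    args = argparse.Namespace(
        data='/path/to/manifests',   # must contain train.tsv / valid.tsv
        sample_rate=16000,
        max_sample_size=None,
        min_sample_size=None,
    )
    task = AudioPretrainingTask.setup_task(args)
    task.load_dataset('train')            # builds a FileAudioDataset from train.tsv
    train_data = task.datasets['train']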
--------------------------------------------------------------------------------
/fairseq/tasks/translation_from_pretrained_xlm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary
7 | from fairseq.tasks.translation import TranslationTask
8 |
9 | from . import register_task
10 |
11 |
12 | @register_task("translation_from_pretrained_xlm")
13 | class TranslationFromPretrainedXLMTask(TranslationTask):
14 | """
15 | Same as TranslationTask except use the MaskedLMDictionary class so that
16 | we can load data that was binarized with the MaskedLMDictionary class.
17 |
18 | This task should be used for the entire training pipeline when we want to
19 | train an NMT model from a pretrained XLM checkpoint: binarizing NMT data,
20 | training NMT with the pretrained XLM checkpoint, and subsequent evaluation
21 | of that trained model.
22 | """
23 |
24 | @classmethod
25 | def load_dictionary(cls, filename):
26 | """Load the masked LM dictionary from the filename
27 |
28 | Args:
29 | filename (str): the filename
30 | """
31 | return MaskedLMDictionary.load(filename)
32 |
--------------------------------------------------------------------------------
/fairseq/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import re
7 |
8 | SPACE_NORMALIZER = re.compile(r"\s+")
9 |
10 |
11 | def tokenize_line(line):
12 | line = SPACE_NORMALIZER.sub(" ", line)
13 | line = line.strip()
14 | return line.split()
15 |
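For reference, a quick illustration of `tokenize_line()`: runs of whitespace (spaces, tabs, newlines) are collapsed before splitting, so no empty tokens are produced.

    from fairseq.tokenizer import tokenize_line

    assert tokenize_line("  hello \t world\n") == ["hello", "world"]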
--------------------------------------------------------------------------------
/fairseq_cli/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/NA-MNMT/120843696d6d9ae24b7ba12fbaf85158512ee714/fairseq_cli/__init__.py
--------------------------------------------------------------------------------
/fairseq_cli/eval_lm.py:
--------------------------------------------------------------------------------
1 | ../eval_lm.py
--------------------------------------------------------------------------------
/fairseq_cli/generate.py:
--------------------------------------------------------------------------------
1 | ../generate.py
--------------------------------------------------------------------------------
/fairseq_cli/interactive.py:
--------------------------------------------------------------------------------
1 | ../interactive.py
--------------------------------------------------------------------------------
/fairseq_cli/preprocess.py:
--------------------------------------------------------------------------------
1 | ../preprocess.py
--------------------------------------------------------------------------------
/fairseq_cli/score.py:
--------------------------------------------------------------------------------
1 | ../score.py
--------------------------------------------------------------------------------
/fairseq_cli/setup.py:
--------------------------------------------------------------------------------
1 | ../setup.py
--------------------------------------------------------------------------------
/fairseq_cli/train.py:
--------------------------------------------------------------------------------
1 | ../train.py
--------------------------------------------------------------------------------
/hubconf.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import functools
7 |
8 | from fairseq.hub_utils import BPEHubInterface as bpe # noqa
9 | from fairseq.hub_utils import TokenizerHubInterface as tokenizer # noqa
10 | from fairseq.models import MODEL_REGISTRY
11 |
12 |
13 | dependencies = [
14 | 'numpy',
15 | 'regex',
16 | 'requests',
17 | 'torch',
18 | ]
19 |
20 |
21 | for _model_type, _cls in MODEL_REGISTRY.items():
22 | for model_name in _cls.hub_models().keys():
23 | globals()[model_name] = functools.partial(
24 | _cls.from_pretrained,
25 | model_name,
26 | )
27 | # to simplify the interface we only expose named models
28 | # globals()[_model_type] = _cls.from_pretrained
29 |
--------------------------------------------------------------------------------
/importanceModel.sh:
--------------------------------------------------------------------------------
1 | save_dir={checkpoints}/model.pt
2 | langs=lang1,lang2,lang3,lang4 # All language pairs of your model, such as "en-zh,zh-en,en-ar,ar-en"
3 | lang=lang1 # Current language pair, such as "en-zh"
4 |
5 | python importance.py data-bin/{data} \
6 | --arch multilingual_transformer --reset-optimizer \
7 | --encoder-langtok "tgt" \
8 | --task multilingual_translation --lang-pairs $langs \
9 | --share-encoders --share-decoders \
10 | --share-all-embeddings \
11 | --focus-lang $lang --fp16 \
12 | --max-tokens 2048 --save-dir $save_dir
13 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | model={path_to_ckpt}
2 | #model=test
3 | langs=lang1,lang2,lang3,lang4 # all language pairs of your model, e.g. "en-zh,zh-en,en-ar,ar-en"
4 | python3 train.py data-bin/{data} \
5 | --arch multilingual_transformer \
6 | --fp16 \
7 | --encoder-langtok "tgt" \
8 | --restore-file /path_to_baseline/model.pt \
9 | --task multilingual_translation --lang-pairs $langs \
10 | --share-encoders --share-decoders \
11 | --share-all-embeddings --share-decoder-input-output-embed \
12 | --reset-lr-scheduler --reset-optimizer \
13 | --optimizer adam --adam-betas '(0.9, 0.98)' \
14 | --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
15 | --lr 0.0001 --min-lr 1e-09 --ddp-backend=no_c10d \
16 | --dropout 0.1 \
17 | --weight-decay 0.0 --clip-norm 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
18 | --max-tokens 4096 --update-freq 2 \
19 | --no-progress-bar --log-format json --log-interval 20 \
20 | --save-dir checkpoints/$model |tee -a logs/$model.log
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/NA-MNMT/120843696d6d9ae24b7ba12fbaf85158512ee714/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/compare_namespaces.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Helper script to compare two argparse.Namespace objects."""
3 |
4 | from argparse import Namespace # noqa
5 |
6 |
7 | def main():
8 |
9 | ns1 = eval(input('Namespace 1: '))
10 | ns2 = eval(input('Namespace 2: '))
11 |
12 | def keys(ns):
13 | ks = set()
14 | for k in dir(ns):
15 | if not k.startswith('_'):
16 | ks.add(k)
17 | return ks
18 |
19 | k1 = keys(ns1)
20 | k2 = keys(ns2)
21 |
22 | def print_keys(ks, ns1, ns2=None):
23 | for k in ks:
24 | if ns2 is None:
25 | print('{}\t{}'.format(k, getattr(ns1, k, None)))
26 | else:
27 | print('{}\t{}\t{}'.format(k, getattr(ns1, k, None), getattr(ns2, k, None)))
28 |
29 | print('Keys unique to namespace 1:')
30 | print_keys(k1 - k2, ns1)
31 | print()
32 |
33 | print('Keys unique to namespace 2:')
34 | print_keys(k2 - k1, ns2)
35 | print()
36 |
37 | print('Overlapping keys with different values:')
38 | ks = [k for k in k1 & k2 if getattr(ns1, k, 'None') != getattr(ns2, k, 'None')]
39 | print_keys(ks, ns1, ns2)
40 | print()
41 |
42 |
43 | if __name__ == '__main__':
44 | main()
45 |
--------------------------------------------------------------------------------
/scripts/compound_split_bleu.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ $# -ne 1 ]; then
4 | echo "usage: $0 GENERATE_PY_OUTPUT"
5 | exit 1
6 | fi
7 |
8 | GEN=$1
9 |
10 | SYS=$GEN.sys
11 | REF=$GEN.ref
12 |
13 | if [ $(tail -n 1 $GEN | grep BLEU | wc -l) -ne 1 ]; then
14 | echo "not done generating"
15 | exit
16 | fi
17 |
18 | grep ^H $GEN | cut -f3- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $SYS
19 | grep ^T $GEN | cut -f2- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $REF
20 | fairseq-score --sys $SYS --ref $REF
21 |
--------------------------------------------------------------------------------
/scripts/convert_dictionary.lua:
--------------------------------------------------------------------------------
1 | -- Copyright (c) Facebook, Inc. and its affiliates.
2 | --
3 | -- This source code is licensed under the MIT license found in the
4 | -- LICENSE file in the root directory of this source tree.
5 | --
6 | -- Usage: convert_dictionary.lua <dict.th7>
7 | require 'fairseq'
8 | require 'torch'
9 | require 'paths'
10 |
11 | if #arg < 1 then
12 | print('usage: convert_dictionary.lua <dict.th7>')
13 | os.exit(1)
14 | end
15 | if not paths.filep(arg[1]) then
16 | print('error: file does not exist: ' .. arg[1])
17 | os.exit(1)
18 | end
19 |
20 | dict = torch.load(arg[1])
21 | dst = paths.basename(arg[1]):gsub('.th7', '.txt')
22 | assert(dst:match('.txt$'))
23 |
24 | f = io.open(dst, 'w')
25 | for idx, symbol in ipairs(dict.index_to_symbol) do
26 | if idx > dict.cutoff then
27 | break
28 | end
29 | f:write(symbol)
30 | f:write(' ')
31 | f:write(dict.index_to_freq[idx])
32 | f:write('\n')
33 | end
34 | f:close()
35 |
--------------------------------------------------------------------------------
/scripts/count_docs.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """
7 | Count the number of documents and average number of lines and tokens per
8 | document in a large file. Documents should be separated by a single empty line.
9 | """
10 |
11 | import argparse
12 | import gzip
13 | import sys
14 |
15 | import numpy as np
16 |
17 |
18 | def main():
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument('input')
21 | parser.add_argument('--gzip', action='store_true')
22 | args = parser.parse_args()
23 |
24 | def gopen():
25 | if args.gzip:
26 | return gzip.open(args.input, 'r')
27 | else:
28 | return open(args.input, 'r', encoding='utf-8')
29 |
30 | num_lines = []
31 | num_toks = []
32 | with gopen() as h:
33 | num_docs = 1
34 | num_lines_in_doc = 0
35 | num_toks_in_doc = 0
36 | for i, line in enumerate(h):
37 | if len(line.strip()) == 0: # empty line indicates new document
38 | num_docs += 1
39 | num_lines.append(num_lines_in_doc)
40 | num_toks.append(num_toks_in_doc)
41 | num_lines_in_doc = 0
42 | num_toks_in_doc = 0
43 | else:
44 | num_lines_in_doc += 1
45 | num_toks_in_doc += len(line.rstrip().split())
46 | if i % 1000000 == 0:
47 | print(i, file=sys.stderr, end="", flush=True)
48 | elif i % 100000 == 0:
49 | print(".", file=sys.stderr, end="", flush=True)
50 | print(file=sys.stderr, flush=True)
51 |
52 | print("found {} docs".format(num_docs))
53 | print("average num lines per doc: {}".format(np.mean(num_lines)))
54 | print("average num toks per doc: {}".format(np.mean(num_toks)))
55 |
56 |
57 | if __name__ == '__main__':
58 | main()
59 |
--------------------------------------------------------------------------------
/scripts/read_binarized.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 |
9 | from fairseq.data import data_utils, Dictionary, indexed_dataset
10 |
11 |
12 | def get_parser():
13 | parser = argparse.ArgumentParser(
14 | description='writes text from binarized file to stdout')
15 | # fmt: off
16 | parser.add_argument('--dataset-impl', help='dataset implementation',
17 | choices=indexed_dataset.get_available_dataset_impl())
18 | parser.add_argument('--dict', metavar='FP', help='dictionary containing known words', default=None)
19 | parser.add_argument('--input', metavar='FP', required=True, help='binarized file to read')
20 | # fmt: on
21 |
22 | return parser
23 |
24 |
25 | def main():
26 | parser = get_parser()
27 | args = parser.parse_args()
28 |
29 | dictionary = Dictionary.load(args.dict) if args.dict is not None else None
30 | dataset = data_utils.load_indexed_dataset(
31 | args.input,
32 | dictionary,
33 | dataset_impl=args.dataset_impl,
34 | default='lazy',
35 | )
36 |
37 | for tensor_line in dataset:
38 | if dictionary is None:
39 | line = ' '.join([str(int(x)) for x in tensor_line])
40 | else:
41 | line = dictionary.string(tensor_line)
42 |
43 | print(line)
44 |
45 |
46 | if __name__ == '__main__':
47 | main()
48 |
--------------------------------------------------------------------------------
/scripts/sacrebleu_pregen.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ $# -ne 4 ]; then
4 | echo "usage: $0 TESTSET SRCLANG TGTLANG GEN"
5 | exit 1
6 | fi
7 |
8 | TESTSET=$1
9 | SRCLANG=$2
10 | TGTLANG=$3
11 |
12 | GEN=$4
13 |
14 | echo 'Cloning Moses github repository (for tokenization scripts)...'
15 | git clone https://github.com/moses-smt/mosesdecoder.git
16 |
17 | SCRIPTS=mosesdecoder/scripts
18 | DETOKENIZER=$SCRIPTS/tokenizer/detokenizer.perl
19 |
20 | grep ^H $GEN \
21 | | sed 's/^H\-//' \
22 | | sort -n -k 1 \
23 | | cut -f 3 \
24 | | perl $DETOKENIZER -l $TGTLANG \
25 | | sed "s/ - /-/g" \
26 | > $GEN.sorted.detok
27 |
28 | sacrebleu --test-set $TESTSET --language-pair "${SRCLANG}-${TGTLANG}" < $GEN.sorted.detok
29 |
--------------------------------------------------------------------------------
/scripts/shard_docs.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """
7 | Split a large file into shards while respecting document boundaries. Documents
8 | should be separated by a single empty line.
9 | """
10 |
11 | import argparse
12 | import contextlib
13 |
14 |
15 | def main():
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument('input')
18 | parser.add_argument('--num-shards', type=int)
19 | args = parser.parse_args()
20 |
21 | assert args.num_shards is not None and args.num_shards > 1
22 |
23 | with open(args.input, 'r', encoding='utf-8') as h:
24 | with contextlib.ExitStack() as stack:
25 | outputs = [
26 | stack.enter_context(open(args.input + ".shard" + str(i), "w", encoding="utf-8"))
27 | for i in range(args.num_shards)
28 | ]
29 |
30 | doc = []
31 | first_doc = [True]*args.num_shards
32 | def output_doc(i):
33 | if not first_doc[i]:
34 | outputs[i].write("\n")
35 | first_doc[i] = False
36 | for line in doc:
37 | outputs[i].write(line)
38 | doc.clear()
39 |
40 | num_docs = 0
41 | for line in h:
42 | if line.strip() == "": # empty line indicates new document
43 | output_doc(num_docs % args.num_shards)
44 | num_docs += 1
45 | else:
46 | doc.append(line)
47 | output_doc(num_docs % args.num_shards)
48 |
49 |
50 | if __name__ == '__main__':
51 | main()
52 |
--------------------------------------------------------------------------------
/scripts/split_train_valid_docs.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """
7 | Split a large file into a train and valid set while respecting document
8 | boundaries. Documents should be separated by a single empty line.
9 | """
10 |
11 | import argparse
12 | import random
13 | import sys
14 |
15 |
16 | def main():
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('input')
19 | parser.add_argument('sample_output', help='train output file')
20 | parser.add_argument('remainder_output', help='valid output file')
21 |     parser.add_argument('-k', type=int, help="number of documents to sample into sample_output")
22 | parser.add_argument('--lines', action='store_true',
23 | help='split lines instead of docs')
24 | args = parser.parse_args()
25 |
26 | assert args.k is not None
27 |
28 | sample = []
29 | remainder = []
30 | num_docs = [0]
31 |
32 | def update_sample(doc):
33 | if len(sample) < args.k:
34 | sample.append(doc.copy())
35 | else:
36 | i = num_docs[0]
37 | j = random.randrange(i + 1)
38 | if j < args.k:
39 | remainder.append(sample[j])
40 | sample[j] = doc.copy()
41 | else:
42 | remainder.append(doc.copy())
43 | num_docs[0] += 1
44 | doc.clear()
45 |
46 | with open(args.input, 'r', encoding='utf-8') as h:
47 | doc = []
48 | for i, line in enumerate(h):
49 | if line.strip() == "": # empty line indicates new document
50 | update_sample(doc)
51 | else:
52 | doc.append(line)
53 | if args.lines:
54 | update_sample(doc)
55 | if i % 1000000 == 0:
56 | print(i, file=sys.stderr, end="", flush=True)
57 | elif i % 100000 == 0:
58 | print(".", file=sys.stderr, end="", flush=True)
59 | if len(doc) > 0:
60 | update_sample(doc)
61 | print(file=sys.stderr, flush=True)
62 |
63 | assert len(sample) == args.k
64 |
65 | with open(args.sample_output, 'w', encoding='utf-8') as out:
66 | first = True
67 | for doc in sample:
68 | if not first and not args.lines:
69 | out.write("\n")
70 | first = False
71 | for line in doc:
72 | out.write(line)
73 |
74 | with open(args.remainder_output, 'w', encoding='utf-8') as out:
75 | first = True
76 | for doc in remainder:
77 | if not first and not args.lines:
78 | out.write("\n")
79 | first = False
80 | for line in doc:
81 | out.write(line)
82 |
83 |
84 | if __name__ == '__main__':
85 | main()
86 |
--------------------------------------------------------------------------------
/scripts/spm_decode.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | from __future__ import absolute_import, division, print_function, unicode_literals
9 |
10 | import argparse
11 |
12 | import sentencepiece as spm
13 |
14 |
15 | def main():
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument("--model", required=True,
18 | help="sentencepiece model to use for decoding")
19 | parser.add_argument("--input", required=True, help="input file to decode")
20 | parser.add_argument("--input_format", choices=["piece", "id"], default="piece")
21 | args = parser.parse_args()
22 |
23 | sp = spm.SentencePieceProcessor()
24 | sp.Load(args.model)
25 |
26 | if args.input_format == "piece":
27 | def decode(l):
28 | return "".join(sp.DecodePieces(l))
29 | elif args.input_format == "id":
30 | def decode(l):
31 | return "".join(sp.DecodeIds(l))
32 | else:
33 | raise NotImplementedError
34 |
35 |     def tok2int(tok):
36 |         # remap reference-side <unk> (represented as <<unk>>) to 0
37 |         return int(tok) if tok != "<<unk>>" else 0
38 | 
39 |     with open(args.input, "r", encoding="utf-8") as h:
40 |         for line in h:  # the int remapping only applies to "id" inputs; pieces are decoded as-is
41 |             print(decode(list(map(tok2int, line.rstrip().split())) if args.input_format == "id" else line.rstrip().split()))
42 |
43 |
44 | if __name__ == "__main__":
45 | main()
46 |
--------------------------------------------------------------------------------
/scripts/spm_train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | from __future__ import absolute_import, division, print_function, unicode_literals
9 |
10 | import sys
11 |
12 | import sentencepiece as spm
13 |
14 |
15 | if __name__ == "__main__":
16 | spm.SentencePieceTrainer.Train(" ".join(sys.argv[1:]))
17 |
--------------------------------------------------------------------------------
/scripts/wav2vec_manifest.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """
7 | Build train/valid .tsv manifests (audio file paths and frame counts) for wav2vec pre-training.
8 | """
9 |
10 | import argparse
11 | import glob
12 | import os
13 | import soundfile
14 | import random
15 |
16 |
17 | def get_parser():
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('root', metavar='DIR', help='root directory containing flac files to index')
20 | parser.add_argument('--valid-percent', default=0.01, type=float, metavar='D',
21 | help='percentage of data to use as validation set (between 0 and 1)')
22 | parser.add_argument('--dest', default='.', type=str, metavar='DIR', help='output directory')
23 | parser.add_argument('--ext', default='flac', type=str, metavar='EXT', help='extension to look for')
24 | parser.add_argument('--seed', default=42, type=int, metavar='N', help='random seed')
25 | parser.add_argument('--path-must-contain', default=None, type=str, metavar='FRAG',
26 | help='if set, path must contain this substring for a file to be included in the manifest')
27 | return parser
28 |
29 |
30 | def main(args):
31 | assert args.valid_percent >= 0 and args.valid_percent <= 1.
32 |
33 | dir_path = os.path.realpath(args.root)
34 | search_path = os.path.join(dir_path, '**/*.' + args.ext)
35 | rand = random.Random(args.seed)
36 |
37 | with open(os.path.join(args.dest, 'train.tsv'), 'w') as train_f, open(
38 | os.path.join(args.dest, 'valid.tsv'), 'w') as valid_f:
39 | print(dir_path, file=train_f)
40 | print(dir_path, file=valid_f)
41 |
42 | for fname in glob.iglob(search_path, recursive=True):
43 | file_path = os.path.realpath(fname)
44 |
45 | if args.path_must_contain and args.path_must_contain not in file_path:
46 | continue
47 |
48 | frames = soundfile.info(fname).frames
49 | dest = train_f if rand.random() > args.valid_percent else valid_f
50 | print('{}\t{}'.format(os.path.relpath(file_path, dir_path), frames), file=dest)
51 |
52 |
53 | if __name__ == '__main__':
54 | parser = get_parser()
55 | args = parser.parse_args()
56 | main(args)
57 |
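For reference, the manifests this script writes (train.tsv / valid.tsv) have the absolute root directory on the first line, followed by one tab-separated "relative/path<TAB>frame_count" entry per audio file; the paths and frame counts below are only illustrative.

    /abs/path/to/audio_root
    speaker1/utt1.flac	93680
    speaker1/utt2.flac	121040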
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=2,3
2 | model=424/test.log
3 | #model=test
4 |
5 | python3 train.py data-bin/424/baseline \
6 | --arch multilingual_transformer \
7 | --max-epoch 80 --fp16 \
8 | --encoder-langtok "tgt" \
9 | --restore-file /data/wanying/1.research/specific/checkpoints/424/baseline/checkpoint36.pt \
10 | --task multilingual_translation --lang-pairs it-en,ro-en,nl-en,it-ro,en-it,en-ro,en-nl,ro-it \
11 | --share-encoders --share-decoders \
12 | --share-all-embeddings --share-decoder-input-output-embed \
13 | --reset-lr-scheduler --reset-optimizer \
14 | --optimizer adam --adam-betas '(0.9, 0.98)' \
15 | --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
16 | --lr 0.0001 --min-lr 1e-09 --ddp-backend=no_c10d \
17 | --dropout 0.3 \
18 | --weight-decay 0.0 --clip-norm 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
19 | --max-tokens 4096 --update-freq 2 \
20 | --no-progress-bar --log-format json --log-interval 20 \
21 | --save-dir checkpoints/$model |tee -a logs/$model.log
22 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/NA-MNMT/120843696d6d9ae24b7ba12fbaf85158512ee714/tests/__init__.py
--------------------------------------------------------------------------------
/tests/speech_recognition/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/NA-MNMT/120843696d6d9ae24b7ba12fbaf85158512ee714/tests/speech_recognition/__init__.py
--------------------------------------------------------------------------------
/tests/speech_recognition/test_collaters.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import unittest
8 |
9 | import numpy as np
10 | import torch
11 | from examples.speech_recognition.data.collaters import Seq2SeqCollater
12 |
13 |
14 | class TestSeq2SeqCollator(unittest.TestCase):
15 | def test_collate(self):
16 |
17 | eos_idx = 1
18 | pad_idx = 0
19 | collater = Seq2SeqCollater(
20 | feature_index=0, label_index=1, pad_index=pad_idx, eos_index=eos_idx
21 | )
22 |
23 | # 2 frames in the first sample and 3 frames in the second one
24 | frames1 = np.array([[7, 8], [9, 10]])
25 | frames2 = np.array([[1, 2], [3, 4], [5, 6]])
26 | target1 = np.array([4, 2, 3, eos_idx])
27 | target2 = np.array([3, 2, eos_idx])
28 | sample1 = {"id": 0, "data": [frames1, target1]}
29 | sample2 = {"id": 1, "data": [frames2, target2]}
30 | batch = collater.collate([sample1, sample2])
31 |
32 |         # collate() sorts inputs by frame length before creating the batch
33 | self.assertTensorEqual(batch["id"], torch.tensor([1, 0]))
34 | self.assertEqual(batch["ntokens"], 7)
35 | self.assertTensorEqual(
36 | batch["net_input"]["src_tokens"],
37 | torch.tensor(
38 | [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [pad_idx, pad_idx]]]
39 | ),
40 | )
41 | self.assertTensorEqual(
42 | batch["net_input"]["prev_output_tokens"],
43 | torch.tensor([[eos_idx, 3, 2, pad_idx], [eos_idx, 4, 2, 3]]),
44 | )
45 | self.assertTensorEqual(batch["net_input"]["src_lengths"], torch.tensor([3, 2]))
46 | self.assertTensorEqual(
47 | batch["target"],
48 | torch.tensor([[3, 2, eos_idx, pad_idx], [4, 2, 3, eos_idx]]),
49 | )
50 | self.assertEqual(batch["nsentences"], 2)
51 |
52 | def assertTensorEqual(self, t1, t2):
53 | self.assertEqual(t1.size(), t2.size(), "size mismatch")
54 | self.assertEqual(t1.ne(t2).long().sum(), 0)
55 |
56 |
57 | if __name__ == "__main__":
58 | unittest.main()
59 |
--------------------------------------------------------------------------------
/tests/speech_recognition/test_cross_entropy.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from examples.speech_recognition.criterions.cross_entropy_acc import CrossEntropyWithAccCriterion
8 | from .asr_test_base import CrossEntropyCriterionTestBase
9 |
10 |
11 | class CrossEntropyWithAccCriterionTest(CrossEntropyCriterionTestBase):
12 | def setUp(self):
13 | self.criterion_cls = CrossEntropyWithAccCriterion
14 | super().setUp()
15 |
16 | def test_cross_entropy_all_correct(self):
17 | sample = self.get_test_sample(correct=True, soft_target=False, aggregate=False)
18 | loss, sample_size, logging_output = self.criterion(
19 | self.model, sample, "sum", log_probs=True
20 | )
21 | assert logging_output["correct"] == 20
22 | assert logging_output["total"] == 20
23 | assert logging_output["sample_size"] == 20
24 | assert logging_output["ntokens"] == 20
25 |
26 | def test_cross_entropy_all_wrong(self):
27 | sample = self.get_test_sample(correct=False, soft_target=False, aggregate=False)
28 | loss, sample_size, logging_output = self.criterion(
29 | self.model, sample, "sum", log_probs=True
30 | )
31 | assert logging_output["correct"] == 0
32 | assert logging_output["total"] == 20
33 | assert logging_output["sample_size"] == 20
34 | assert logging_output["ntokens"] == 20
35 |
--------------------------------------------------------------------------------
/tests/test_character_token_embedder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | import unittest
8 |
9 | from fairseq.data import Dictionary
10 | from fairseq.modules import CharacterTokenEmbedder
11 |
12 |
13 | class TestCharacterTokenEmbedder(unittest.TestCase):
14 | def test_character_token_embedder(self):
15 | vocab = Dictionary()
16 | vocab.add_symbol('hello')
17 | vocab.add_symbol('there')
18 |
19 | embedder = CharacterTokenEmbedder(vocab, [(2, 16), (4, 32), (8, 64), (16, 2)], 64, 5, 2)
20 |
21 | test_sents = [['hello', 'unk', 'there'], ['there'], ['hello', 'there']]
22 | max_len = max(len(s) for s in test_sents)
23 | input = torch.LongTensor(len(test_sents), max_len + 2).fill_(vocab.pad())
24 | for i in range(len(test_sents)):
25 | input[i][0] = vocab.eos()
26 | for j in range(len(test_sents[i])):
27 | input[i][j + 1] = vocab.index(test_sents[i][j])
28 | input[i][j + 2] = vocab.eos()
29 | embs = embedder(input)
30 |
31 | assert embs.size() == (len(test_sents), max_len + 2, 5)
32 | self.assertAlmostEqual(embs[0][0], embs[1][0])
33 | self.assertAlmostEqual(embs[0][0], embs[0][-1])
34 | self.assertAlmostEqual(embs[0][1], embs[2][1])
35 | self.assertAlmostEqual(embs[0][3], embs[1][1])
36 |
37 | embs.sum().backward()
38 | assert embedder.char_embeddings.weight.grad is not None
39 |
40 | def assertAlmostEqual(self, t1, t2):
41 | self.assertEqual(t1.size(), t2.size(), "size mismatch")
42 | self.assertLess((t1 - t2).abs().max(), 1e-6)
43 |
44 |
45 | if __name__ == '__main__':
46 | unittest.main()
47 |
--------------------------------------------------------------------------------
/tests/test_concat_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import unittest
7 |
8 | import torch
9 | from fairseq.data import LanguagePairDataset, TokenBlockDataset
10 | from fairseq.data.concat_dataset import ConcatDataset
11 | from tests.test_train import mock_dict
12 |
13 |
14 | class TestConcatDataset(unittest.TestCase):
15 | def setUp(self):
16 | d = mock_dict()
17 | tokens_1 = torch.LongTensor([1]).view(1, -1)
18 | tokens_ds1 = TokenBlockDataset(
19 | tokens_1,
20 | sizes=[tokens_1.size(-1)],
21 | block_size=1,
22 | pad=0,
23 | eos=1,
24 | include_targets=False,
25 | )
26 | self.dataset_1 = LanguagePairDataset(
27 | tokens_ds1, tokens_ds1.sizes, d, shuffle=False
28 | )
29 | tokens_2 = torch.LongTensor([2]).view(1, -1)
30 | tokens_ds2 = TokenBlockDataset(
31 | tokens_2,
32 | sizes=[tokens_2.size(-1)],
33 | block_size=1,
34 | pad=0,
35 | eos=1,
36 | include_targets=False,
37 | )
38 | self.dataset_2 = LanguagePairDataset(
39 | tokens_ds2, tokens_ds2.sizes, d, shuffle=False
40 | )
41 |
42 | def test_concat_dataset_basics(self):
43 | d = ConcatDataset(
44 | [self.dataset_1, self.dataset_2]
45 | )
46 | assert(len(d) == 2)
47 | assert(d[0]['source'][0] == 1)
48 | assert(d[1]['source'][0] == 2)
49 |
50 | d = ConcatDataset(
51 | [self.dataset_1, self.dataset_2], sample_ratios=[1, 2]
52 | )
53 | assert(len(d) == 3)
54 | assert(d[0]['source'][0] == 1)
55 | assert(d[1]['source'][0] == 2)
56 | assert(d[2]['source'][0] == 2)
57 |
58 | d = ConcatDataset(
59 | [self.dataset_1, self.dataset_2], sample_ratios=[2, 1]
60 | )
61 | assert(len(d) == 3)
62 | assert(d[0]['source'][0] == 1)
63 | assert(d[1]['source'][0] == 1)
64 | assert(d[2]['source'][0] == 2)
65 |
--------------------------------------------------------------------------------
/tests/test_convtbc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | import unittest
8 | from fairseq.modules import ConvTBC
9 | import torch.nn as nn
10 |
11 |
12 | class TestConvTBC(unittest.TestCase):
13 |
14 | def test_convtbc(self):
15 | # ksz, in_channels, out_channels
16 | conv_tbc = ConvTBC(4, 5, kernel_size=3, padding=1)
17 | # out_channels, in_channels, ksz
18 | conv1d = nn.Conv1d(4, 5, kernel_size=3, padding=1)
19 |
20 | conv_tbc.weight.data.copy_(conv1d.weight.data.transpose(0, 2))
21 | conv_tbc.bias.data.copy_(conv1d.bias.data)
22 |
23 | input_tbc = torch.randn(7, 2, 4, requires_grad=True)
24 | input1d = input_tbc.data.transpose(0, 1).transpose(1, 2)
25 | input1d.requires_grad = True
26 |
27 | output_tbc = conv_tbc(input_tbc)
28 | output1d = conv1d(input1d)
29 |
30 | self.assertAlmostEqual(output_tbc.data.transpose(0, 1).transpose(1, 2), output1d.data)
31 |
32 | grad_tbc = torch.randn(output_tbc.size())
33 | grad1d = grad_tbc.transpose(0, 1).transpose(1, 2).contiguous()
34 |
35 | output_tbc.backward(grad_tbc)
36 | output1d.backward(grad1d)
37 |
38 | self.assertAlmostEqual(conv_tbc.weight.grad.data.transpose(0, 2), conv1d.weight.grad.data)
39 | self.assertAlmostEqual(conv_tbc.bias.grad.data, conv1d.bias.grad.data)
40 | self.assertAlmostEqual(input_tbc.grad.data.transpose(0, 1).transpose(1, 2), input1d.grad.data)
41 |
42 | def assertAlmostEqual(self, t1, t2):
43 | self.assertEqual(t1.size(), t2.size(), "size mismatch")
44 | self.assertLess((t1 - t2).abs().max(), 1e-4)
45 |
46 |
47 | if __name__ == '__main__':
48 | unittest.main()
49 |
--------------------------------------------------------------------------------
/tests/test_dictionary.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import tempfile
7 | import unittest
8 |
9 | import torch
10 |
11 | from fairseq.data import Dictionary
12 |
13 |
14 | class TestDictionary(unittest.TestCase):
15 |
16 | def test_finalize(self):
17 | txt = [
18 | 'A B C D',
19 | 'B C D',
20 | 'C D',
21 | 'D',
22 | ]
23 | ref_ids1 = list(map(torch.IntTensor, [
24 | [4, 5, 6, 7, 2],
25 | [5, 6, 7, 2],
26 | [6, 7, 2],
27 | [7, 2],
28 | ]))
29 | ref_ids2 = list(map(torch.IntTensor, [
30 | [7, 6, 5, 4, 2],
31 | [6, 5, 4, 2],
32 | [5, 4, 2],
33 | [4, 2],
34 | ]))
35 |
36 | # build dictionary
37 | d = Dictionary()
38 | for line in txt:
39 | d.encode_line(line, add_if_not_exist=True)
40 |
41 | def get_ids(dictionary):
42 | ids = []
43 | for line in txt:
44 | ids.append(dictionary.encode_line(line, add_if_not_exist=False))
45 | return ids
46 |
47 | def assertMatch(ids, ref_ids):
48 | for toks, ref_toks in zip(ids, ref_ids):
49 | self.assertEqual(toks.size(), ref_toks.size())
50 | self.assertEqual(0, (toks != ref_toks).sum().item())
51 |
52 | ids = get_ids(d)
53 | assertMatch(ids, ref_ids1)
54 |
55 | # check finalized dictionary
56 | d.finalize()
57 | finalized_ids = get_ids(d)
58 | assertMatch(finalized_ids, ref_ids2)
59 |
60 | # write to disk and reload
61 | with tempfile.NamedTemporaryFile(mode='w') as tmp_dict:
62 | d.save(tmp_dict.name)
63 | d = Dictionary.load(tmp_dict.name)
64 | reload_ids = get_ids(d)
65 | assertMatch(reload_ids, ref_ids2)
66 | assertMatch(finalized_ids, reload_ids)
67 |
68 |
69 | if __name__ == '__main__':
70 | unittest.main()
71 |
--------------------------------------------------------------------------------
/tests/test_iterators.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import unittest
7 |
8 | from fairseq.data import iterators
9 |
10 |
11 | class TestIterators(unittest.TestCase):
12 |
13 | def test_counting_iterator(self):
14 | x = list(range(10))
15 | itr = iterators.CountingIterator(x)
16 | self.assertTrue(itr.has_next())
17 | self.assertEqual(next(itr), 0)
18 | self.assertEqual(next(itr), 1)
19 | itr.skip(3)
20 | self.assertEqual(next(itr), 5)
21 | itr.skip(3)
22 | self.assertEqual(next(itr), 9)
23 | self.assertFalse(itr.has_next())
24 |
25 |
26 | if __name__ == '__main__':
27 | unittest.main()
28 |
--------------------------------------------------------------------------------
/tests/test_memory_efficient_fp16.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import argparse
7 | import unittest
8 |
9 | import torch
10 |
11 | from fairseq.optim.adam import FairseqAdam
12 | from fairseq.optim.fp16_optimizer import MemoryEfficientFP16Optimizer
13 |
14 |
15 | class TestMemoryEfficientFP16(unittest.TestCase):
16 |
17 | def test_load_state_dict(self):
18 | # define simple FP16 model
19 | model = torch.nn.Linear(5, 5).cuda().half()
20 | params = list(model.parameters())
21 |
22 | # initialize memory efficient FP16 optimizer
23 | optimizer = FairseqAdam(
24 | argparse.Namespace(
25 | lr=[0.00001],
26 | adam_betas='(0.9, 0.999)',
27 | adam_eps=1e-8,
28 | weight_decay=0.0,
29 | ),
30 | params,
31 | )
32 | me_optimizer = MemoryEfficientFP16Optimizer(
33 | argparse.Namespace(
34 | fp16_init_scale=1,
35 | fp16_scale_window=1,
36 | fp16_scale_tolerance=1,
37 | threshold_loss_scale=1,
38 | ),
39 | params,
40 | optimizer,
41 | )
42 |
43 | # optimizer state is created in the first step
44 | loss = model(torch.rand(5).cuda().half()).sum()
45 | me_optimizer.backward(loss)
46 | me_optimizer.step()
47 |
48 | # reload state
49 | state = me_optimizer.state_dict()
50 | me_optimizer.load_state_dict(state)
51 | for k, v in me_optimizer.optimizer.state.items():
52 | self.assertTrue(k.dtype == torch.float16)
53 | for v_i in v.values():
54 | if torch.is_tensor(v_i):
55 | self.assertTrue(v_i.dtype == torch.float32)
56 |
57 |
58 | if __name__ == '__main__':
59 | unittest.main()
60 |
--------------------------------------------------------------------------------
/tests/test_sparse_multihead_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | import unittest
8 | from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention
9 |
10 |
11 | class TestSparseMultiheadAttention(unittest.TestCase):
12 | def test_sparse_multihead_attention(self):
13 | attn_weights = torch.randn(1, 8, 8)
14 | bidirectional_sparse_mask = torch.tensor([
15 | [0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0],
16 | [0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0],
17 | [0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0],
18 | [0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0],
19 | [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0],
20 | [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0],
21 | [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0],
22 | [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0]
23 | ])
24 |
25 | bidirectional_attention = SparseMultiheadAttention(16, 1, stride=4, expressivity=1, is_bidirectional=True)
26 | bidirectional_attention_sparse_mask = bidirectional_attention.buffered_sparse_mask(attn_weights, 8, 8)
27 | torch.all(torch.eq(bidirectional_attention_sparse_mask, bidirectional_sparse_mask))
28 |
29 | sparse_mask = torch.tensor([
30 | [0, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf'),
31 | float('-inf'), float('-inf')],
32 | [0, 0, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf')],
33 | [0, 0, 0, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf')],
34 | [0, 0, 0, 0, float('-inf'), float('-inf'), float('-inf'), float('-inf')],
35 | [0, 0, 0, 0, 0, float('-inf'), float('-inf'), float('-inf')],
36 | [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, float('-inf'), float('-inf')],
37 | [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, float('-inf')],
38 | [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0],
39 | ])
40 |
41 | attention = SparseMultiheadAttention(16, 1, stride=4, expressivity=1, is_bidirectional=False)
42 | attention_sparse_mask = attention.buffered_sparse_mask(attn_weights, 8, 8)
43 |
44 | torch.all(torch.eq(attention_sparse_mask, sparse_mask))
45 |
46 |
47 | if __name__ == '__main__':
48 | unittest.main()
49 |
--------------------------------------------------------------------------------
/tests/test_token_block_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import unittest
7 |
8 | import torch
9 |
10 | from fairseq.data import TokenBlockDataset
11 |
12 | import tests.utils as test_utils
13 |
14 |
15 | class TestTokenBlockDataset(unittest.TestCase):
16 |
17 | def _build_dataset(self, data, **kwargs):
18 | sizes = [len(x) for x in data]
19 | underlying_ds = test_utils.TestDataset(data)
20 | return TokenBlockDataset(underlying_ds, sizes, **kwargs)
21 |
22 | def test_eos_break_mode(self):
23 | data = [
24 | torch.tensor([5, 4, 3, 2, 1], dtype=torch.long),
25 | torch.tensor([1], dtype=torch.long),
26 | torch.tensor([8, 7, 6, 1], dtype=torch.long),
27 | ]
28 | ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode='eos')
29 | self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
30 | self.assertEqual(ds[1].tolist(), [1])
31 | self.assertEqual(ds[2].tolist(), [8, 7, 6, 1])
32 |
33 | data = [
34 | torch.tensor([5, 4, 3, 2, 1], dtype=torch.long),
35 | torch.tensor([8, 7, 6, 1], dtype=torch.long),
36 | torch.tensor([1], dtype=torch.long),
37 | ]
38 | ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode='eos')
39 | self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
40 | self.assertEqual(ds[1].tolist(), [8, 7, 6, 1])
41 | self.assertEqual(ds[2].tolist(), [1])
42 |
43 | def test_block_break_mode(self):
44 | data = [
45 | torch.tensor([5, 4, 3, 2, 1], dtype=torch.long),
46 | torch.tensor([8, 7, 6, 1], dtype=torch.long),
47 | torch.tensor([9, 1], dtype=torch.long),
48 | ]
49 | ds = self._build_dataset(data, block_size=3, pad=0, eos=1, break_mode='none')
50 | self.assertEqual(ds[0].tolist(), [5, 4, 3])
51 | self.assertEqual(ds[1].tolist(), [2, 1, 8])
52 | self.assertEqual(ds[2].tolist(), [7, 6, 1])
53 | self.assertEqual(ds[3].tolist(), [9, 1])
54 |
55 | def test_complete_break_mode(self):
56 | data = [
57 | torch.tensor([5, 4, 3, 2, 1], dtype=torch.long),
58 | torch.tensor([8, 7, 6, 1], dtype=torch.long),
59 | torch.tensor([9, 1], dtype=torch.long),
60 | ]
61 | ds = self._build_dataset(data, block_size=6, pad=0, eos=1, break_mode='complete')
62 | self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
63 | self.assertEqual(ds[1].tolist(), [8, 7, 6, 1, 9, 1])
64 |
65 | data = [
66 | torch.tensor([4, 3, 2, 1], dtype=torch.long),
67 | torch.tensor([5, 1], dtype=torch.long),
68 | torch.tensor([1], dtype=torch.long),
69 | torch.tensor([6, 1], dtype=torch.long),
70 | ]
71 | ds = self._build_dataset(data, block_size=3, pad=0, eos=1, break_mode='complete')
72 | self.assertEqual(ds[0].tolist(), [4, 3, 2, 1])
73 | self.assertEqual(ds[1].tolist(), [5, 1, 1])
74 | self.assertEqual(ds[2].tolist(), [6, 1])
75 |
76 |
77 | if __name__ == "__main__":
78 | unittest.main()
79 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import unittest
7 |
8 | import torch
9 |
10 | from fairseq import utils
11 |
12 |
13 | class TestUtils(unittest.TestCase):
14 |
15 | def test_convert_padding_direction(self):
16 | pad = 1
17 | left_pad = torch.LongTensor([
18 | [2, 3, 4, 5, 6],
19 | [1, 7, 8, 9, 10],
20 | [1, 1, 1, 11, 12],
21 | ])
22 | right_pad = torch.LongTensor([
23 | [2, 3, 4, 5, 6],
24 | [7, 8, 9, 10, 1],
25 | [11, 12, 1, 1, 1],
26 | ])
27 |
28 | self.assertAlmostEqual(
29 | right_pad,
30 | utils.convert_padding_direction(
31 | left_pad,
32 | pad,
33 | left_to_right=True,
34 | ),
35 | )
36 | self.assertAlmostEqual(
37 | left_pad,
38 | utils.convert_padding_direction(
39 | right_pad,
40 | pad,
41 | right_to_left=True,
42 | ),
43 | )
44 |
45 | def test_make_positions(self):
46 | pad = 1
47 | left_pad_input = torch.LongTensor([
48 | [9, 9, 9, 9, 9],
49 | [1, 9, 9, 9, 9],
50 | [1, 1, 1, 9, 9],
51 | ])
52 | left_pad_output = torch.LongTensor([
53 | [2, 3, 4, 5, 6],
54 | [1, 2, 3, 4, 5],
55 | [1, 1, 1, 2, 3],
56 | ])
57 | right_pad_input = torch.LongTensor([
58 | [9, 9, 9, 9, 9],
59 | [9, 9, 9, 9, 1],
60 | [9, 9, 1, 1, 1],
61 | ])
62 | right_pad_output = torch.LongTensor([
63 | [2, 3, 4, 5, 6],
64 | [2, 3, 4, 5, 1],
65 | [2, 3, 1, 1, 1],
66 | ])
67 |
68 | self.assertAlmostEqual(
69 | left_pad_output,
70 | utils.make_positions(left_pad_input, pad),
71 | )
72 | self.assertAlmostEqual(
73 | right_pad_output,
74 | utils.make_positions(right_pad_input, pad),
75 | )
76 |
77 | def assertAlmostEqual(self, t1, t2):
78 | self.assertEqual(t1.size(), t2.size(), "size mismatch")
79 | self.assertLess(utils.item((t1 - t2).abs().max()), 1e-4)
80 |
81 |
82 | if __name__ == '__main__':
83 | unittest.main()
84 |
--------------------------------------------------------------------------------