├── .gitignore ├── .gitmodules ├── LICENSE ├── LICENSE.fairseq ├── README.md ├── fairseq ├── __init__.py ├── benchmark │ ├── __init__.py │ ├── dummy_dataset.py │ ├── dummy_lm.py │ ├── dummy_masked_lm.py │ ├── dummy_model.py │ └── dummy_mt.py ├── binarizer.py ├── checkpoint_utils.py ├── clib │ ├── cuda │ │ ├── ngram_repeat_block_cuda.cpp │ │ └── ngram_repeat_block_cuda_kernel.cu │ ├── libbase │ │ └── balanced_assignment.cpp │ ├── libbleu │ │ ├── libbleu.cpp │ │ └── module.cpp │ ├── libnat │ │ └── edit_dist.cpp │ └── libnat_cuda │ │ ├── binding.cpp │ │ ├── edit_dist.cu │ │ └── edit_dist.h ├── config │ ├── __init__.py │ ├── config.yaml │ └── model │ │ ├── transformer_lm │ │ ├── transformer_lm_baevski_gbw.yaml │ │ ├── transformer_lm_baevski_wiki103.yaml │ │ ├── transformer_lm_big.yaml │ │ ├── transformer_lm_gbw.yaml │ │ ├── transformer_lm_gpt.yaml │ │ ├── transformer_lm_gpt2_big.yaml │ │ ├── transformer_lm_gpt2_medium.yaml │ │ ├── transformer_lm_gpt2_small.yaml │ │ └── transformer_lm_wiki103.yaml │ │ ├── wav2vec │ │ └── vq_wav2vec_gumbel.yaml │ │ └── wav2vec2 │ │ ├── wav2vec2_base.yaml │ │ └── wav2vec2_large.yaml ├── criterions │ ├── __init__.py │ ├── adaptive_loss.py │ ├── composite_loss.py │ ├── cross_entropy.py │ ├── ctc.py │ ├── fairseq_criterion.py │ ├── fastspeech2_loss.py │ ├── hubert_criterion.py │ ├── label_smoothed_cross_entropy.py │ ├── label_smoothed_cross_entropy_latency_augmented.py │ ├── label_smoothed_cross_entropy_with_alignment.py │ ├── legacy_masked_lm.py │ ├── masked_lm.py │ ├── model_criterion.py │ ├── nat_loss.py │ ├── sentence_prediction.py │ ├── sentence_ranking.py │ ├── speech_to_speech_criterion.py │ ├── speech_ulm_criterion.py │ ├── tacotron2_loss.py │ └── wav2vec_criterion.py ├── data │ ├── __init__.py │ ├── add_target_dataset.py │ ├── append_token_dataset.py │ ├── audio │ │ ├── __init__.py │ │ ├── audio_utils.py │ │ ├── data_cfg.py │ │ ├── feature_transforms │ │ │ ├── __init__.py │ │ │ ├── delta_deltas.py │ │ │ ├── global_cmvn.py │ │ │ ├── 
specaugment.py │ │ │ └── utterance_cmvn.py │ │ ├── frm_text_to_speech_dataset.py │ │ ├── hubert_dataset.py │ │ ├── multi_modality_dataset.py │ │ ├── raw_audio_dataset.py │ │ ├── speech_to_speech_dataset.py │ │ ├── speech_to_text_dataset.py │ │ ├── speech_to_text_joint_dataset.py │ │ └── text_to_speech_dataset.py │ ├── backtranslation_dataset.py │ ├── base_wrapper_dataset.py │ ├── bucket_pad_length_dataset.py │ ├── codedataset.py │ ├── colorize_dataset.py │ ├── concat_dataset.py │ ├── concat_sentences_dataset.py │ ├── data_utils.py │ ├── data_utils_fast.pyx │ ├── denoising_dataset.py │ ├── dictionary.py │ ├── encoders │ │ ├── __init__.py │ │ ├── byte_bpe.py │ │ ├── byte_utils.py │ │ ├── bytes.py │ │ ├── characters.py │ │ ├── fastbpe.py │ │ ├── gpt2_bpe.py │ │ ├── gpt2_bpe_utils.py │ │ ├── hf_bert_bpe.py │ │ ├── hf_byte_bpe.py │ │ ├── moses_tokenizer.py │ │ ├── nltk_tokenizer.py │ │ ├── sentencepiece_bpe.py │ │ ├── space_tokenizer.py │ │ ├── subword_nmt_bpe.py │ │ └── utils.py │ ├── fairseq_dataset.py │ ├── fasta_dataset.py │ ├── huffman │ │ ├── __init__.py │ │ ├── huffman_coder.py │ │ └── huffman_mmap_indexed_dataset.py │ ├── id_dataset.py │ ├── indexed_dataset.py │ ├── iterators.py │ ├── language_pair_dataset.py │ ├── legacy │ │ ├── __init__.py │ │ ├── block_pair_dataset.py │ │ ├── masked_lm_dataset.py │ │ └── masked_lm_dictionary.py │ ├── list_dataset.py │ ├── lm_context_window_dataset.py │ ├── lru_cache_dataset.py │ ├── mask_tokens_dataset.py │ ├── monolingual_dataset.py │ ├── multi_corpus_dataset.py │ ├── multi_corpus_sampled_dataset.py │ ├── multilingual │ │ ├── __init__.py │ │ ├── multilingual_data_manager.py │ │ ├── multilingual_utils.py │ │ ├── sampled_multi_dataset.py │ │ ├── sampled_multi_epoch_dataset.py │ │ └── sampling_method.py │ ├── nested_dictionary_dataset.py │ ├── noising.py │ ├── num_samples_dataset.py │ ├── numel_dataset.py │ ├── offset_tokens_dataset.py │ ├── pad_dataset.py │ ├── plasma_utils.py │ ├── prepend_dataset.py │ ├── 
prepend_token_dataset.py │ ├── raw_label_dataset.py │ ├── replace_dataset.py │ ├── resampling_dataset.py │ ├── roll_dataset.py │ ├── round_robin_zip_datasets.py │ ├── shorten_dataset.py │ ├── sort_dataset.py │ ├── strip_token_dataset.py │ ├── subsample_dataset.py │ ├── text_compressor.py │ ├── token_block_dataset.py │ ├── token_block_utils_fast.pyx │ ├── transform_eos_concat_langpair_dataset.py │ ├── transform_eos_dataset.py │ └── transform_eos_lang_pair_dataset.py ├── dataclass │ ├── __init__.py │ ├── configs.py │ ├── constants.py │ ├── initialize.py │ └── utils.py ├── distributed │ ├── __init__.py │ ├── distributed_timeout_wrapper.py │ ├── fully_sharded_data_parallel.py │ ├── legacy_distributed_data_parallel.py │ ├── module_proxy_wrapper.py │ ├── tpu_distributed_data_parallel.py │ └── utils.py ├── file_chunker_utils.py ├── file_io.py ├── file_utils.py ├── hub_utils.py ├── incremental_decoding_utils.py ├── iterative_refinement_generator.py ├── logging │ ├── __init__.py │ ├── meters.py │ ├── metrics.py │ └── progress_bar.py ├── model_parallel │ ├── __init__.py │ ├── criterions │ │ ├── __init__.py │ │ └── vocab_parallel_cross_entropy.py │ ├── megatron_trainer.py │ ├── models │ │ ├── __init__.py │ │ ├── pipeline_parallel_transformer │ │ │ ├── __init__.py │ │ │ ├── layers.py │ │ │ └── model.py │ │ ├── roberta │ │ │ ├── __init__.py │ │ │ └── model.py │ │ ├── transformer.py │ │ └── transformer_lm.py │ └── modules │ │ ├── __init__.py │ │ ├── multihead_attention.py │ │ └── transformer_layer.py ├── models │ ├── __init__.py │ ├── bart │ │ ├── __init__.py │ │ ├── hub_interface.py │ │ └── model.py │ ├── composite_encoder.py │ ├── distributed_fairseq_model.py │ ├── ema │ │ ├── __init__.py │ │ └── ema.py │ ├── fairseq_decoder.py │ ├── fairseq_encoder.py │ ├── fairseq_incremental_decoder.py │ ├── fairseq_model.py │ ├── fconv.py │ ├── fconv_lm.py │ ├── fconv_self_att.py │ ├── hubert │ │ ├── __init__.py │ │ ├── hubert.py │ │ └── hubert_asr.py │ ├── huggingface │ │ ├── __init__.py 
│ │ └── hf_gpt2.py │ ├── lightconv.py │ ├── lightconv_lm.py │ ├── lstm.py │ ├── lstm_lm.py │ ├── masked_lm.py │ ├── model_utils.py │ ├── multilingual_transformer.py │ ├── nat │ │ ├── __init__.py │ │ ├── cmlm_transformer.py │ │ ├── fairseq_nat_model.py │ │ ├── insertion_transformer.py │ │ ├── iterative_nonautoregressive_transformer.py │ │ ├── levenshtein_transformer.py │ │ ├── levenshtein_utils.py │ │ ├── nat_crf_transformer.py │ │ ├── nonautoregressive_ensembles.py │ │ └── nonautoregressive_transformer.py │ ├── roberta │ │ ├── __init__.py │ │ ├── alignment_utils.py │ │ ├── enc_dec.py │ │ ├── hub_interface.py │ │ ├── model.py │ │ ├── model_camembert.py │ │ ├── model_gottbert.py │ │ └── model_xlmr.py │ ├── speech_to_speech │ │ ├── __init__.py │ │ ├── modules.py │ │ └── s2s_transformer.py │ ├── speech_to_text │ │ ├── __init__.py │ │ ├── berard.py │ │ ├── convtransformer.py │ │ ├── hub_interface.py │ │ ├── modules │ │ │ ├── augmented_memory_attention.py │ │ │ └── emformer.py │ │ ├── s2t_conformer.py │ │ ├── s2t_transformer.py │ │ ├── utils.py │ │ └── xm_transformer.py │ ├── text_to_speech │ │ ├── __init__.py │ │ ├── codehifigan.py │ │ ├── fastspeech2.py │ │ ├── hifigan.py │ │ ├── hub_interface.py │ │ ├── tacotron2.py │ │ ├── tts_transformer.py │ │ └── vocoder.py │ ├── transformer │ │ ├── __init__.py │ │ ├── transformer_base.py │ │ ├── transformer_config.py │ │ ├── transformer_decoder.py │ │ ├── transformer_encoder.py │ │ └── transformer_legacy.py │ ├── transformer_align.py │ ├── transformer_from_pretrained_xlm.py │ ├── transformer_lm.py │ ├── transformer_ulm.py │ └── wav2vec │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── wav2vec.py │ │ ├── wav2vec2.py │ │ └── wav2vec2_asr.py ├── modules │ ├── __init__.py │ ├── adaptive_input.py │ ├── adaptive_softmax.py │ ├── base_layer.py │ ├── beamable_mm.py │ ├── character_token_embedder.py │ ├── checkpoint_activations.py │ ├── conformer_layer.py │ ├── conv_tbc.py │ ├── cross_entropy.py │ ├── cuda_utils.cu │ ├── 
downsampled_multihead_attention.py │ ├── dynamic_convolution.py │ ├── dynamic_crf_layer.py │ ├── dynamicconv_layer │ │ ├── __init__.py │ │ ├── cuda_function_gen.py │ │ ├── dynamicconv_cuda.cpp │ │ ├── dynamicconv_cuda.cuh │ │ ├── dynamicconv_cuda_kernel.cu │ │ ├── dynamicconv_layer.py │ │ ├── dynamiconv_cpu.cpp │ │ └── setup.py │ ├── ema_module.py │ ├── espnet_multihead_attention.py │ ├── fairseq_dropout.py │ ├── fp32_batch_norm.py │ ├── fp32_group_norm.py │ ├── fp32_instance_norm.py │ ├── gelu.py │ ├── grad_multiply.py │ ├── gumbel_vector_quantizer.py │ ├── kmeans_attention.py │ ├── kmeans_vector_quantizer.py │ ├── layer_drop.py │ ├── layer_norm.py │ ├── learned_positional_embedding.py │ ├── lightconv_layer │ │ ├── __init__.py │ │ ├── cuda_function_gen.py │ │ ├── lightconv_cuda.cpp │ │ ├── lightconv_cuda.cuh │ │ ├── lightconv_cuda_kernel.cu │ │ ├── lightconv_layer.py │ │ └── setup.py │ ├── lightweight_convolution.py │ ├── linearized_convolution.py │ ├── location_attention.py │ ├── lstm_cell_with_zoneout.py │ ├── multihead_attention.py │ ├── positional_embedding.py │ ├── positional_encoding.py │ ├── quant_noise.py │ ├── quantization │ │ ├── __init__.py │ │ ├── pq │ │ │ ├── __init__.py │ │ │ ├── em.py │ │ │ ├── modules │ │ │ │ ├── __init__.py │ │ │ │ ├── qconv.py │ │ │ │ ├── qemb.py │ │ │ │ └── qlinear.py │ │ │ ├── pq.py │ │ │ └── utils.py │ │ ├── quantization_options.py │ │ └── scalar │ │ │ ├── __init__.py │ │ │ ├── modules │ │ │ ├── __init__.py │ │ │ ├── qact.py │ │ │ ├── qconv.py │ │ │ ├── qemb.py │ │ │ └── qlinear.py │ │ │ ├── ops.py │ │ │ └── utils.py │ ├── rotary_positional_embedding.py │ ├── same_pad.py │ ├── scalar_bias.py │ ├── sinusoidal_positional_embedding.py │ ├── sparse_multihead_attention.py │ ├── sparse_transformer_sentence_encoder.py │ ├── sparse_transformer_sentence_encoder_layer.py │ ├── transformer_layer.py │ ├── transformer_sentence_encoder.py │ ├── transformer_sentence_encoder_layer.py │ ├── transpose_last.py │ ├── unfold.py │ └── vggblock.py 
├── nan_detector.py ├── ngram_repeat_block.py ├── optim │ ├── __init__.py │ ├── adadelta.py │ ├── adafactor.py │ ├── adagrad.py │ ├── adam.py │ ├── adamax.py │ ├── amp_optimizer.py │ ├── bmuf.py │ ├── composite.py │ ├── cpu_adam.py │ ├── dynamic_loss_scaler.py │ ├── fairseq_optimizer.py │ ├── fp16_optimizer.py │ ├── fused_adam.py │ ├── fused_lamb.py │ ├── lr_scheduler │ │ ├── __init__.py │ │ ├── cosine_lr_scheduler.py │ │ ├── fairseq_lr_scheduler.py │ │ ├── fixed_schedule.py │ │ ├── inverse_square_root_schedule.py │ │ ├── manual_lr_scheduler.py │ │ ├── pass_through.py │ │ ├── polynomial_decay_schedule.py │ │ ├── reduce_lr_on_plateau.py │ │ ├── step_lr_scheduler.py │ │ ├── tri_stage_lr_scheduler.py │ │ └── triangular_lr_scheduler.py │ ├── nag.py │ ├── sgd.py │ └── shard.py ├── options.py ├── pdb.py ├── quantization_utils.py ├── registry.py ├── scoring │ ├── __init__.py │ ├── bertscore.py │ ├── bleu.py │ ├── chrf.py │ ├── meteor.py │ ├── tokenizer.py │ └── wer.py ├── search.py ├── sequence_generator.py ├── sequence_scorer.py ├── speech_generator.py ├── tasks │ ├── __init__.py │ ├── audio_finetuning.py │ ├── audio_pretraining.py │ ├── cross_lingual_lm.py │ ├── denoising.py │ ├── fairseq_task.py │ ├── frm_text_to_speech.py │ ├── hubert_pretraining.py │ ├── language_modeling.py │ ├── legacy_masked_lm.py │ ├── masked_lm.py │ ├── multilingual_denoising.py │ ├── multilingual_language_modeling.py │ ├── multilingual_masked_lm.py │ ├── multilingual_translation.py │ ├── online_backtranslation.py │ ├── semisupervised_translation.py │ ├── sentence_prediction.py │ ├── sentence_ranking.py │ ├── simultaneous_translation.py │ ├── speech_to_speech.py │ ├── speech_to_text.py │ ├── speech_ulm_task.py │ ├── text_to_speech.py │ ├── translation.py │ ├── translation_from_pretrained_bart.py │ ├── translation_from_pretrained_xlm.py │ ├── translation_lev.py │ └── translation_multi_simple_epoch.py ├── token_generation_constraints.py ├── tokenizer.py ├── trainer.py ├── utils.py └── version.txt 
├── fairseq_cli ├── __init__.py ├── eval_lm.py ├── generate.py ├── hydra_train.py ├── interactive.py ├── preprocess.py ├── score.py ├── train.py └── validate.py ├── fs_plugins ├── __init__.py ├── criterions │ ├── __init__.py │ ├── nat_dag_loss.py │ ├── nat_dag_loss_ngram.py │ ├── pass_prob.py │ └── utilities.py ├── custom_ops │ ├── __init__.py │ ├── dag_best_alignment.cu │ ├── dag_loss.cpp │ ├── dag_loss.cu │ ├── dag_loss.py │ ├── logsoftmax_gather.cu │ └── utilities.h ├── models │ ├── __init__.py │ ├── glat_decomposed_with_link.py │ ├── ls_glat_decomposed_with_link.py │ ├── ls_nat_decoder.py │ └── ls_transformer.py ├── optimizer │ ├── __init__.py │ └── ls_adam.py ├── scripts │ ├── average_checkpoints.py │ ├── convert_ls_to_fairseq.py │ └── test_tradeoff.py └── tasks │ ├── __init__.py │ └── translation_lev_modified.py ├── hubconf.py ├── model.png ├── scripts ├── __init__.py ├── average_checkpoints.py ├── build_sym_alignment.py ├── compare_namespaces.py ├── compound_split_bleu.sh ├── constraints │ ├── extract.py │ └── validate.py ├── convert_dictionary.lua ├── convert_model.lua ├── count_docs.py ├── read_binarized.py ├── rm_pt.py ├── sacrebleu.sh ├── shard_docs.py ├── split_train_valid_docs.py ├── spm_decode.py ├── spm_encode.py ├── spm_train.py └── test_fsdp.sh ├── setup.cfg ├── setup.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # JetBrains PyCharm IDE 2 | .idea/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # macOS dir files 13 | .DS_Store 14 | 15 | # Distribution / packaging 16 | .Python 17 | env/ 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | 34 | # Checkpoints 35 | checkpoints 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python 
script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | .hypothesis/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # SageMath parsed files 89 | *.sage.py 90 | 91 | # dotenv 92 | .env 93 | 94 | # virtualenv 95 | .venv 96 | venv/ 97 | ENV/ 98 | 99 | # Spyder project settings 100 | .spyderproject 101 | .spyproject 102 | 103 | # Rope project settings 104 | .ropeproject 105 | 106 | # mkdocs documentation 107 | /site 108 | 109 | # mypy 110 | .mypy_cache/ 111 | 112 | # Generated files 113 | /fairseq/temporal_convolution_tbc 114 | /fairseq/modules/*_layer/*_forward.cu 115 | /fairseq/modules/*_layer/*_backward.cu 116 | /fairseq/version.py 117 | 118 | # data 119 | data-bin/ 120 | 121 | # reranking 122 | /examples/reranking/rerank_data 123 | 124 | # Cython-generated C++ source files 125 | /fairseq/data/data_utils_fast.cpp 126 | /fairseq/data/token_block_utils_fast.cpp 127 | 128 | # VSCODE 129 | .vscode/ftp-sync.json 130 | .vscode/settings.json 131 | 132 | # Experimental Folder 133 | experimental/* 134 | 135 | # Weights and Biases logs 136 | wandb/ 137 | 138 | git_info 139 | current_revision 140 | .vscode 141 | 142 | /cod/tmp -------------------------------------------------------------------------------- /.gitmodules: 
-------------------------------------------------------------------------------- 1 | [submodule "dag_search"] 2 | path = dag_search 3 | url = https://github.com/thu-coai/DAG-Search 4 | 5 | [submodule "lightseq"] 6 | path = lightseq 7 | url = https://github.com/thu-coai/lightseq-nat 8 | 9 | [submodule "cub"] 10 | path = cub 11 | url = https://github.com/NVIDIA/cub -------------------------------------------------------------------------------- /LICENSE.fairseq: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Facebook, Inc. and its affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /fairseq/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | """isort:skip_file""" 6 | 7 | import os 8 | import sys 9 | 10 | try: 11 | from .version import __version__ # noqa 12 | except ImportError: 13 | version_txt = os.path.join(os.path.dirname(__file__), "version.txt") 14 | with open(version_txt) as f: 15 | __version__ = f.read().strip() 16 | 17 | __all__ = ["pdb"] 18 | 19 | # backwards compatibility to support `from fairseq.X import Y` 20 | from fairseq.distributed import utils as distributed_utils 21 | from fairseq.logging import meters, metrics, progress_bar # noqa 22 | 23 | sys.modules["fairseq.distributed_utils"] = distributed_utils 24 | sys.modules["fairseq.meters"] = meters 25 | sys.modules["fairseq.metrics"] = metrics 26 | sys.modules["fairseq.progress_bar"] = progress_bar 27 | 28 | # initialize hydra 29 | from fairseq.dataclass.initialize import hydra_init 30 | 31 | hydra_init() 32 | 33 | import fairseq.criterions # noqa 34 | import fairseq.distributed # noqa 35 | import fairseq.models # noqa 36 | import fairseq.modules # noqa 37 | import fairseq.optim # noqa 38 | import fairseq.optim.lr_scheduler # noqa 39 | import fairseq.pdb # noqa 40 | import fairseq.scoring # noqa 41 | import fairseq.tasks # noqa 42 | import fairseq.token_generation_constraints # noqa 43 | 44 | import fairseq.benchmark # noqa 45 | import fairseq.model_parallel # noqa 46 | -------------------------------------------------------------------------------- /fairseq/benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # import models/tasks to register them 7 | from . 
import dummy_dataset, dummy_lm, dummy_masked_lm, dummy_model, dummy_mt # noqa 8 | -------------------------------------------------------------------------------- /fairseq/benchmark/dummy_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from fairseq.data import FairseqDataset 3 | 4 | 5 | class DummyDataset(FairseqDataset): 6 | def __init__(self, batch, num_items, item_size): 7 | super().__init__() 8 | self.batch = batch 9 | self.num_items = num_items 10 | self.item_size = item_size 11 | 12 | def __getitem__(self, index): 13 | return index 14 | 15 | def __len__(self): 16 | return self.num_items 17 | 18 | def collater(self, samples): 19 | return self.batch 20 | 21 | @property 22 | def sizes(self): 23 | return np.array([self.item_size] * self.num_items) 24 | 25 | def num_tokens(self, index): 26 | return self.item_size 27 | 28 | def size(self, index): 29 | return self.item_size 30 | 31 | def ordered_indices(self): 32 | return np.arange(self.num_items) 33 | 34 | @property 35 | def supports_prefetch(self): 36 | return False 37 | -------------------------------------------------------------------------------- /fairseq/clib/cuda/ngram_repeat_block_cuda.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT License. 
4 | */ 5 | 6 | #include 7 | #include 8 | 9 | /* 10 | CPP Binding for CUDA OP 11 | */ 12 | 13 | // CUDA forward declarations 14 | torch::Tensor ngram_repeat_block_cuda_forward( 15 | torch::Tensor tokens, 16 | torch::Tensor lprobs, 17 | int bsz, 18 | int step, 19 | int beam_size, 20 | int no_repeat_ngram_size); 21 | 22 | #define CHECK_CUDA(x) \ 23 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 24 | #define CHECK_CONTIGUOUS(x) \ 25 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 26 | #define CHECK_INPUT(x) \ 27 | CHECK_CUDA(x); \ 28 | CHECK_CONTIGUOUS(x) 29 | 30 | // Input check and call to CUDA OP 31 | // Backward method not required 32 | torch::Tensor ngram_repeat_block_forward( 33 | torch::Tensor tokens, 34 | torch::Tensor lprobs, 35 | int bsz, 36 | int step, 37 | int beam_size, 38 | int no_repeat_ngram_size) { 39 | CHECK_INPUT(tokens); 40 | CHECK_INPUT(lprobs); 41 | assert(bsz > 0); 42 | assert(step >= 0); 43 | assert(beam_size > 0); 44 | assert(no_repeat_ngram_size > 0); 45 | 46 | return ngram_repeat_block_cuda_forward( 47 | tokens, lprobs, bsz, step, beam_size, no_repeat_ngram_size); 48 | } 49 | 50 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 51 | m.def( 52 | "forward", 53 | &ngram_repeat_block_forward, 54 | "No Repeat Ngram Block forward (CUDA)"); 55 | } 56 | -------------------------------------------------------------------------------- /fairseq/clib/libbleu/module.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2017-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include 10 | 11 | static PyMethodDef method_def[] = {{NULL, NULL, 0, NULL}}; // NOLINT 12 | 13 | static struct PyModuleDef module_def = { 14 | PyModuleDef_HEAD_INIT, 15 | "libbleu", /* name of module */ 16 | // NOLINTNEXTLINE 17 | NULL, /* module documentation, may be NULL */ 18 | -1, /* size of per-interpreter state of the module, 19 | or -1 if the module keeps state in global variables. */ 20 | method_def}; // NOLINT 21 | 22 | #if PY_MAJOR_VERSION == 2 23 | PyMODINIT_FUNC init_libbleu() 24 | #else 25 | PyMODINIT_FUNC PyInit_libbleu() 26 | #endif 27 | { 28 | PyObject* m = PyModule_Create(&module_def); 29 | if (!m) { 30 | return NULL; 31 | } 32 | return m; 33 | } 34 | -------------------------------------------------------------------------------- /fairseq/clib/libnat_cuda/binding.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2017-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | /* 10 | This code is partially adpoted from 11 | https://github.com/1ytic/pytorch-edit-distance 12 | */ 13 | 14 | #include 15 | #include "edit_dist.h" 16 | 17 | #ifndef TORCH_CHECK 18 | #define TORCH_CHECK AT_CHECK 19 | #endif 20 | 21 | #define CHECK_CUDA(x) \ 22 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 23 | #define CHECK_CONTIGUOUS(x) \ 24 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 25 | #define CHECK_INPUT(x) \ 26 | CHECK_CUDA(x); \ 27 | CHECK_CONTIGUOUS(x) 28 | 29 | torch::Tensor LevenshteinDistance( 30 | torch::Tensor source, 31 | torch::Tensor target, 32 | torch::Tensor source_length, 33 | torch::Tensor target_length) { 34 | CHECK_INPUT(source); 35 | CHECK_INPUT(target); 36 | CHECK_INPUT(source_length); 37 | CHECK_INPUT(target_length); 38 | return LevenshteinDistanceCuda(source, target, source_length, target_length); 39 | } 40 | 41 | torch::Tensor GenerateDeletionLabel( 42 | torch::Tensor source, 43 | torch::Tensor operations) { 44 | CHECK_INPUT(source); 45 | CHECK_INPUT(operations); 46 | return GenerateDeletionLabelCuda(source, operations); 47 | } 48 | 49 | std::pair GenerateInsertionLabel( 50 | torch::Tensor target, 51 | torch::Tensor operations) { 52 | CHECK_INPUT(target); 53 | CHECK_INPUT(operations); 54 | return GenerateInsertionLabelCuda(target, operations); 55 | } 56 | 57 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 58 | m.def("levenshtein_distance", &LevenshteinDistance, "Levenshtein distance"); 59 | m.def( 60 | "generate_deletion_labels", 61 | &GenerateDeletionLabel, 62 | "Generate Deletion Label"); 63 | m.def( 64 | "generate_insertion_labels", 65 | &GenerateInsertionLabel, 66 | "Generate Insertion Label"); 67 | } 68 | -------------------------------------------------------------------------------- /fairseq/clib/libnat_cuda/edit_dist.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2017-present, Facebook, Inc. 3 | * All rights reserved. 
4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | 13 | torch::Tensor LevenshteinDistanceCuda( 14 | torch::Tensor source, 15 | torch::Tensor target, 16 | torch::Tensor source_length, 17 | torch::Tensor target_length); 18 | 19 | torch::Tensor GenerateDeletionLabelCuda( 20 | torch::Tensor source, 21 | torch::Tensor operations); 22 | 23 | std::pair GenerateInsertionLabelCuda( 24 | torch::Tensor source, 25 | torch::Tensor operations); 26 | -------------------------------------------------------------------------------- /fairseq/config/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /fairseq/config/config.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | hydra: 4 | run: 5 | dir: . 
6 | 7 | defaults: 8 | - _self_ 9 | - task: null 10 | - model: null 11 | - criterion: cross_entropy 12 | - optimizer: null 13 | - lr_scheduler: fixed 14 | - bpe: null 15 | - tokenizer: null 16 | - scoring: null 17 | - generation: null 18 | - common_eval: null 19 | - eval_lm: null 20 | -------------------------------------------------------------------------------- /fairseq/config/model/transformer_lm/transformer_lm_baevski_gbw.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | activation_fn: "relu" 3 | dropout: 0.1 4 | attention_dropout: 0.1 5 | activation_dropout: 0.0 6 | relu_dropout: 0.0 7 | decoder_embed_dim: 512 8 | decoder_output_dim: 512 9 | decoder_input_dim: 512 10 | decoder_ffn_embed_dim: 4096 11 | decoder_layers: 12 12 | decoder_attention_heads: 16 13 | decoder_normalize_before: true 14 | no_decoder_final_norm: true 15 | adaptive_softmax_cutoff: null 16 | adaptive_softmax_dropout: 0 17 | adaptive_softmax_factor: 4 18 | no_token_positional_embeddings: false 19 | share_decoder_input_output_embed: false 20 | character_embeddings: false 21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" 22 | character_embedding_dim: 4 23 | char_embedder_highway_layers: 2 24 | adaptive_input: false 25 | adaptive_input_factor: 4 26 | adaptive_input_cutoff: null 27 | tie_adaptive_weights: false 28 | tie_adaptive_proj: false 29 | decoder_learned_pos: false 30 | decoder_layerdrop: 0 31 | decoder_layers_to_keep: null 32 | layernorm_embedding: false 33 | no_scale_embedding: false 34 | quant_noise_pq: 0 35 | quant_noise_pq_block_size: 8 36 | quant_noise_scalar: 0 37 | -------------------------------------------------------------------------------- /fairseq/config/model/transformer_lm/transformer_lm_baevski_wiki103.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | activation_fn: "relu" 3 | dropout: 0.3 4 | 
attention_dropout: 0.1 5 | activation_dropout: 0.1 6 | relu_dropout: 0.1 7 | decoder_embed_dim: 1024 8 | decoder_output_dim: 1024 9 | decoder_input_dim: 1024 10 | decoder_ffn_embed_dim: 4096 11 | decoder_layers: 16 12 | decoder_attention_heads: 8 13 | decoder_normalize_before: true 14 | no_decoder_final_norm: true 15 | adaptive_softmax_cutoff: "20000,60000" 16 | adaptive_softmax_dropout: 0.2 17 | adaptive_softmax_factor: 4 18 | no_token_positional_embeddings: false 19 | share_decoder_input_output_embed: false 20 | character_embeddings: false 21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" 22 | character_embedding_dim: 4 23 | char_embedder_highway_layers: 2 24 | adaptive_input: true 25 | adaptive_input_factor: 4 26 | adaptive_input_cutoff: "20000,60000" 27 | tie_adaptive_weights: true 28 | tie_adaptive_proj: true 29 | decoder_learned_pos: false 30 | decoder_layerdrop: 0 31 | decoder_layers_to_keep: null 32 | layernorm_embedding: false 33 | no_scale_embedding: false 34 | quant_noise_pq: 0 35 | quant_noise_pq_block_size: 8 36 | quant_noise_scalar: 0 37 | -------------------------------------------------------------------------------- /fairseq/config/model/transformer_lm/transformer_lm_big.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | activation_fn: "relu" 3 | dropout: 0.1 4 | attention_dropout: 0.0 5 | activation_dropout: 0.0 6 | relu_dropout: 0.0 7 | decoder_embed_dim: 1024 8 | decoder_output_dim: 1024 9 | decoder_input_dim: 1024 10 | decoder_ffn_embed_dim: 4096 11 | decoder_layers: 12 12 | decoder_attention_heads: 16 13 | decoder_normalize_before: true 14 | no_decoder_final_norm: false 15 | adaptive_softmax_cutoff: null 16 | adaptive_softmax_dropout: 0 17 | adaptive_softmax_factor: 4 18 | no_token_positional_embeddings: false 19 | share_decoder_input_output_embed: false 20 | character_embeddings: false 21 | character_filters: "[(1, 64), (2, 128), (3, 
192), (4, 256), (5, 256), (6, 256), (7, 256)]" 22 | character_embedding_dim: 4 23 | char_embedder_highway_layers: 2 24 | adaptive_input: false 25 | adaptive_input_factor: 4 26 | adaptive_input_cutoff: null 27 | tie_adaptive_weights: false 28 | tie_adaptive_proj: false 29 | decoder_learned_pos: false 30 | decoder_layerdrop: 0 31 | decoder_layers_to_keep: null 32 | layernorm_embedding: false 33 | no_scale_embedding: false 34 | quant_noise_pq: 0 35 | quant_noise_pq_block_size: 8 36 | quant_noise_scalar: 0 37 | -------------------------------------------------------------------------------- /fairseq/config/model/transformer_lm/transformer_lm_gbw.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | activation_fn: "relu" 3 | dropout: 0.1 4 | attention_dropout: 0.1 5 | activation_dropout: 0.0 6 | relu_dropout: 0.0 7 | decoder_embed_dim: 512 8 | decoder_output_dim: 512 9 | decoder_input_dim: 512 10 | decoder_ffn_embed_dim: 4096 11 | decoder_layers: 12 12 | decoder_attention_heads: 16 13 | decoder_normalize_before: true 14 | no_decoder_final_norm: true 15 | adaptive_softmax_cutoff: null 16 | adaptive_softmax_dropout: 0 17 | adaptive_softmax_factor: 4 18 | no_token_positional_embeddings: false 19 | share_decoder_input_output_embed: false 20 | character_embeddings: false 21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" 22 | character_embedding_dim: 4 23 | char_embedder_highway_layers: 2 24 | adaptive_input: false 25 | adaptive_input_factor: 4 26 | adaptive_input_cutoff: null 27 | tie_adaptive_weights: false 28 | tie_adaptive_proj: false 29 | decoder_learned_pos: false 30 | decoder_layerdrop: 0 31 | decoder_layers_to_keep: null 32 | layernorm_embedding: false 33 | no_scale_embedding: false 34 | quant_noise_pq: 0 35 | quant_noise_pq_block_size: 8 36 | quant_noise_scalar: 0 37 | -------------------------------------------------------------------------------- 
/fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | activation_fn: "gelu" 3 | dropout: 0.1 4 | attention_dropout: 0.1 5 | activation_dropout: 0.0 6 | relu_dropout: 0.0 7 | decoder_embed_dim: 768 8 | decoder_output_dim: 768 9 | decoder_input_dim: 768 10 | decoder_ffn_embed_dim: 3072 11 | decoder_layers: 12 12 | decoder_attention_heads: 12 13 | decoder_normalize_before: true 14 | no_decoder_final_norm: false 15 | adaptive_softmax_cutoff: null 16 | adaptive_softmax_dropout: 0 17 | adaptive_softmax_factor: 4 18 | no_token_positional_embeddings: false 19 | share_decoder_input_output_embed: false 20 | character_embeddings: false 21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" 22 | character_embedding_dim: 4 23 | char_embedder_highway_layers: 2 24 | adaptive_input: false 25 | adaptive_input_factor: 4 26 | adaptive_input_cutoff: null 27 | tie_adaptive_weights: false 28 | tie_adaptive_proj: false 29 | decoder_learned_pos: false 30 | decoder_layerdrop: 0 31 | decoder_layers_to_keep: null 32 | layernorm_embedding: false 33 | no_scale_embedding: false 34 | quant_noise_pq: 0 35 | quant_noise_pq_block_size: 8 36 | quant_noise_scalar: 0 37 | -------------------------------------------------------------------------------- /fairseq/config/model/transformer_lm/transformer_lm_gpt2_big.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | activation_fn: "gelu" 3 | dropout: 0.1 4 | attention_dropout: 0.1 5 | activation_dropout: 0.0 6 | relu_dropout: 0.0 7 | decoder_embed_dim: 1600 8 | decoder_output_dim: 1600 9 | decoder_input_dim: 1600 10 | decoder_ffn_embed_dim: 6400 11 | decoder_layers: 48 12 | decoder_attention_heads: 25 13 | decoder_normalize_before: true 14 | no_decoder_final_norm: false 15 | adaptive_softmax_cutoff: null 16 | adaptive_softmax_dropout: 0 17 
| adaptive_softmax_factor: 4 18 | no_token_positional_embeddings: false 19 | share_decoder_input_output_embed: false 20 | character_embeddings: false 21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" 22 | character_embedding_dim: 4 23 | char_embedder_highway_layers: 2 24 | adaptive_input: false 25 | adaptive_input_factor: 4 26 | adaptive_input_cutoff: null 27 | tie_adaptive_weights: false 28 | tie_adaptive_proj: false 29 | decoder_learned_pos: false 30 | decoder_layerdrop: 0 31 | decoder_layers_to_keep: null 32 | layernorm_embedding: false 33 | no_scale_embedding: false 34 | quant_noise_pq: 0 35 | quant_noise_pq_block_size: 8 36 | quant_noise_scalar: 0 37 | -------------------------------------------------------------------------------- /fairseq/config/model/transformer_lm/transformer_lm_gpt2_medium.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | activation_fn: "gelu" 3 | dropout: 0.1 4 | attention_dropout: 0.1 5 | activation_dropout: 0.0 6 | relu_dropout: 0.0 7 | decoder_embed_dim: 1280 8 | decoder_output_dim: 1280 9 | decoder_input_dim: 1280 10 | decoder_ffn_embed_dim: 5120 11 | decoder_layers: 36 12 | decoder_attention_heads: 20 13 | decoder_normalize_before: true 14 | no_decoder_final_norm: false 15 | adaptive_softmax_cutoff: null 16 | adaptive_softmax_dropout: 0 17 | adaptive_softmax_factor: 4 18 | no_token_positional_embeddings: false 19 | share_decoder_input_output_embed: false 20 | character_embeddings: false 21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" 22 | character_embedding_dim: 4 23 | char_embedder_highway_layers: 2 24 | adaptive_input: false 25 | adaptive_input_factor: 4 26 | adaptive_input_cutoff: null 27 | tie_adaptive_weights: false 28 | tie_adaptive_proj: false 29 | decoder_learned_pos: false 30 | decoder_layerdrop: 0 31 | decoder_layers_to_keep: null 32 | layernorm_embedding: false 33 | 
no_scale_embedding: false 34 | quant_noise_pq: 0 35 | quant_noise_pq_block_size: 8 36 | quant_noise_scalar: 0 37 | -------------------------------------------------------------------------------- /fairseq/config/model/transformer_lm/transformer_lm_gpt2_small.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | activation_fn: "gelu" 3 | dropout: 0.1 4 | attention_dropout: 0.1 5 | activation_dropout: 0.0 6 | relu_dropout: 0.0 7 | decoder_embed_dim: 1024 8 | decoder_output_dim: 1024 9 | decoder_input_dim: 1024 10 | decoder_ffn_embed_dim: 4096 11 | decoder_layers: 24 12 | decoder_attention_heads: 16 13 | decoder_normalize_before: true 14 | no_decoder_final_norm: false 15 | adaptive_softmax_cutoff: null 16 | adaptive_softmax_dropout: 0 17 | adaptive_softmax_factor: 4 18 | no_token_positional_embeddings: false 19 | share_decoder_input_output_embed: false 20 | character_embeddings: false 21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" 22 | character_embedding_dim: 4 23 | char_embedder_highway_layers: 2 24 | adaptive_input: false 25 | adaptive_input_factor: 4 26 | adaptive_input_cutoff: null 27 | tie_adaptive_weights: false 28 | tie_adaptive_proj: false 29 | decoder_learned_pos: false 30 | decoder_layerdrop: 0 31 | decoder_layers_to_keep: null 32 | layernorm_embedding: false 33 | no_scale_embedding: false 34 | quant_noise_pq: 0 35 | quant_noise_pq_block_size: 8 36 | quant_noise_scalar: 0 37 | -------------------------------------------------------------------------------- /fairseq/config/model/transformer_lm/transformer_lm_wiki103.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | activation_fn: "relu" 3 | dropout: 0.3 4 | attention_dropout: 0.1 5 | activation_dropout: 0.1 6 | relu_dropout: 0.1 7 | decoder_embed_dim: 1024 8 | decoder_output_dim: 1024 9 | decoder_input_dim: 1024 10 | 
decoder_ffn_embed_dim: 4096 11 | decoder_layers: 16 12 | decoder_attention_heads: 8 13 | decoder_normalize_before: true 14 | no_decoder_final_norm: true 15 | adaptive_softmax_cutoff: "20000,60000" 16 | adaptive_softmax_dropout: 0.2 17 | adaptive_softmax_factor: 4 18 | no_token_positional_embeddings: false 19 | share_decoder_input_output_embed: false 20 | character_embeddings: false 21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" 22 | character_embedding_dim: 4 23 | char_embedder_highway_layers: 2 24 | adaptive_input: true 25 | adaptive_input_factor: 4 26 | adaptive_input_cutoff: "20000,60000" 27 | tie_adaptive_weights: true 28 | tie_adaptive_proj: true 29 | decoder_learned_pos: false 30 | decoder_layerdrop: 0 31 | decoder_layers_to_keep: null 32 | layernorm_embedding: false 33 | no_scale_embedding: false 34 | quant_noise_pq: 0 35 | quant_noise_pq_block_size: 8 36 | quant_noise_scalar: 0 37 | -------------------------------------------------------------------------------- /fairseq/config/model/wav2vec/vq_wav2vec_gumbel.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | activation: gelu 3 | vq_type: gumbel 4 | vq_depth: 2 5 | combine_groups: true 6 | -------------------------------------------------------------------------------- /fairseq/config/model/wav2vec2/wav2vec2_base.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | quantize_targets: true 4 | final_dim: 256 5 | encoder_layerdrop: 0.05 6 | dropout_input: 0.1 7 | dropout_features: 0.1 8 | feature_grad_mult: 0.1 9 | -------------------------------------------------------------------------------- /fairseq/config/model/wav2vec2/wav2vec2_large.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | quantize_targets: true 4 | extractor_mode: layer_norm 5 | layer_norm_first: true 6 
| final_dim: 768 7 | latent_temp: [2.0,0.1,0.999995] 8 | encoder_layerdrop: 0.0 9 | dropout_input: 0.0 10 | dropout_features: 0.0 11 | dropout: 0.0 12 | attention_dropout: 0.0 13 | conv_bias: true 14 | 15 | encoder_layers: 24 16 | encoder_embed_dim: 1024 17 | encoder_ffn_embed_dim: 4096 18 | encoder_attention_heads: 16 19 | 20 | feature_grad_mult: 1.0 21 | -------------------------------------------------------------------------------- /fairseq/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | """isort:skip_file""" 6 | 7 | import importlib 8 | import os 9 | 10 | from fairseq import registry 11 | from fairseq.criterions.fairseq_criterion import ( # noqa 12 | FairseqCriterion, 13 | LegacyFairseqCriterion, 14 | ) 15 | from omegaconf import DictConfig 16 | 17 | 18 | ( 19 | build_criterion_, 20 | register_criterion, 21 | CRITERION_REGISTRY, 22 | CRITERION_DATACLASS_REGISTRY, 23 | ) = registry.setup_registry( 24 | "--criterion", base_class=FairseqCriterion, default="cross_entropy" 25 | ) 26 | 27 | 28 | def build_criterion(cfg: DictConfig, task): 29 | return build_criterion_(cfg, task) 30 | 31 | 32 | # automatically import any Python files in the criterions/ directory 33 | for file in sorted(os.listdir(os.path.dirname(__file__))): 34 | if file.endswith(".py") and not file.startswith("_"): 35 | file_name = file[: file.find(".py")] 36 | importlib.import_module("fairseq.criterions." + file_name) 37 | -------------------------------------------------------------------------------- /fairseq/data/append_token_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from . import BaseWrapperDataset 10 | 11 | 12 | class AppendTokenDataset(BaseWrapperDataset): 13 | def __init__(self, dataset, token=None): 14 | super().__init__(dataset) 15 | self.token = token 16 | if token is not None: 17 | self._sizes = np.array(dataset.sizes) + 1 18 | else: 19 | self._sizes = dataset.sizes 20 | 21 | def __getitem__(self, idx): 22 | item = self.dataset[idx] 23 | if self.token is not None: 24 | item = torch.cat([item, item.new([self.token])]) 25 | return item 26 | 27 | @property 28 | def sizes(self): 29 | return self._sizes 30 | 31 | def num_tokens(self, index): 32 | n = self.dataset.num_tokens(index) 33 | if self.token is not None: 34 | n += 1 35 | return n 36 | 37 | def size(self, index): 38 | n = self.dataset.size(index) 39 | if self.token is not None: 40 | n += 1 41 | return n 42 | -------------------------------------------------------------------------------- /fairseq/data/audio/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/FA-DAT/3599a3023c89ef16d2b33bf7c3a2b6f85aedaa4b/fairseq/data/audio/__init__.py -------------------------------------------------------------------------------- /fairseq/data/audio/feature_transforms/delta_deltas.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from fairseq.data.audio.feature_transforms import ( 4 | AudioFeatureTransform, 5 | register_audio_feature_transform, 6 | ) 7 | 8 | 9 | @register_audio_feature_transform("delta_deltas") 10 | class DeltaDeltas(AudioFeatureTransform): 11 | """Expand delta-deltas features from spectrum.""" 12 | 13 | @classmethod 14 | def from_config_dict(cls, config=None): 15 | _config = {} if config is None else config 16 | return 
DeltaDeltas(_config.get("win_length", 5)) 17 | 18 | def __init__(self, win_length=5): 19 | self.win_length = win_length 20 | 21 | def __repr__(self): 22 | return self.__class__.__name__ 23 | 24 | def __call__(self, spectrogram): 25 | from torchaudio.functional import compute_deltas 26 | 27 | assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor." 28 | # spectrogram is T x F, while compute_deltas takes (…, F, T) 29 | spectrogram = torch.from_numpy(spectrogram).transpose(0, 1) 30 | delta = compute_deltas(spectrogram) 31 | delta_delta = compute_deltas(delta) 32 | 33 | out_feat = np.concatenate( 34 | [spectrogram, delta.numpy(), delta_delta.numpy()], axis=0 35 | ) 36 | out_feat = np.transpose(out_feat) 37 | return out_feat 38 | -------------------------------------------------------------------------------- /fairseq/data/audio/feature_transforms/global_cmvn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from fairseq.data.audio.feature_transforms import ( 3 | AudioFeatureTransform, 4 | register_audio_feature_transform, 5 | ) 6 | 7 | 8 | @register_audio_feature_transform("global_cmvn") 9 | class GlobalCMVN(AudioFeatureTransform): 10 | """Global CMVN (cepstral mean and variance normalization). 
The global mean 11 | and variance need to be pre-computed and stored in NumPy format (.npz).""" 12 | 13 | @classmethod 14 | def from_config_dict(cls, config=None): 15 | _config = {} if config is None else config 16 | return GlobalCMVN(_config.get("stats_npz_path")) 17 | 18 | def __init__(self, stats_npz_path): 19 | self.stats_npz_path = stats_npz_path 20 | stats = np.load(stats_npz_path) 21 | self.mean, self.std = stats["mean"], stats["std"] 22 | 23 | def __repr__(self): 24 | return self.__class__.__name__ + f'(stats_npz_path="{self.stats_npz_path}")' 25 | 26 | def __call__(self, x): 27 | x = np.subtract(x, self.mean) 28 | x = np.divide(x, self.std) 29 | return x 30 | -------------------------------------------------------------------------------- /fairseq/data/audio/feature_transforms/utterance_cmvn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from fairseq.data.audio.feature_transforms import ( 4 | AudioFeatureTransform, 5 | register_audio_feature_transform, 6 | ) 7 | 8 | 9 | @register_audio_feature_transform("utterance_cmvn") 10 | class UtteranceCMVN(AudioFeatureTransform): 11 | """Utterance-level CMVN (cepstral mean and variance normalization)""" 12 | 13 | @classmethod 14 | def from_config_dict(cls, config=None): 15 | _config = {} if config is None else config 16 | return UtteranceCMVN( 17 | _config.get("norm_means", True), 18 | _config.get("norm_vars", True), 19 | ) 20 | 21 | def __init__(self, norm_means=True, norm_vars=True): 22 | self.norm_means, self.norm_vars = norm_means, norm_vars 23 | 24 | def __repr__(self): 25 | return ( 26 | self.__class__.__name__ 27 | + f"(norm_means={self.norm_means}, norm_vars={self.norm_vars})" 28 | ) 29 | 30 | def __call__(self, x): 31 | mean = x.mean(axis=0) 32 | square_sums = (x**2).sum(axis=0) 33 | 34 | if self.norm_means: 35 | x = np.subtract(x, mean) 36 | if self.norm_vars: 37 | var = square_sums / x.shape[0] - mean**2 38 | std = np.sqrt(np.maximum(var, 
1e-10)) 39 | x = np.divide(x, std) 40 | 41 | return x 42 | -------------------------------------------------------------------------------- /fairseq/data/base_wrapper_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from torch.utils.data.dataloader import default_collate 7 | 8 | from . import FairseqDataset 9 | 10 | 11 | class BaseWrapperDataset(FairseqDataset): 12 | def __init__(self, dataset): 13 | super().__init__() 14 | self.dataset = dataset 15 | 16 | def __getitem__(self, index): 17 | return self.dataset[index] 18 | 19 | def __len__(self): 20 | return len(self.dataset) 21 | 22 | def collater(self, samples): 23 | if hasattr(self.dataset, "collater"): 24 | return self.dataset.collater(samples) 25 | else: 26 | return default_collate(samples) 27 | 28 | @property 29 | def sizes(self): 30 | return self.dataset.sizes 31 | 32 | def num_tokens(self, index): 33 | return self.dataset.num_tokens(index) 34 | 35 | def size(self, index): 36 | return self.dataset.size(index) 37 | 38 | def ordered_indices(self): 39 | return self.dataset.ordered_indices() 40 | 41 | @property 42 | def supports_prefetch(self): 43 | return getattr(self.dataset, "supports_prefetch", False) 44 | 45 | def attr(self, attr: str, index: int): 46 | return self.dataset.attr(attr, index) 47 | 48 | def prefetch(self, indices): 49 | self.dataset.prefetch(indices) 50 | 51 | def get_batch_shapes(self): 52 | return self.dataset.get_batch_shapes() 53 | 54 | def batch_by_size( 55 | self, 56 | indices, 57 | max_tokens=None, 58 | max_sentences=None, 59 | required_batch_size_multiple=1, 60 | ): 61 | return self.dataset.batch_by_size( 62 | indices, 63 | max_tokens=max_tokens, 64 | max_sentences=max_sentences, 65 | required_batch_size_multiple=required_batch_size_multiple, 66 
| ) 67 | 68 | def filter_indices_by_size(self, indices, max_sizes): 69 | return self.dataset.filter_indices_by_size(indices, max_sizes) 70 | 71 | @property 72 | def can_reuse_epoch_itr_across_epochs(self): 73 | return self.dataset.can_reuse_epoch_itr_across_epochs 74 | 75 | def set_epoch(self, epoch): 76 | super().set_epoch(epoch) 77 | if hasattr(self.dataset, "set_epoch"): 78 | self.dataset.set_epoch(epoch) 79 | -------------------------------------------------------------------------------- /fairseq/data/colorize_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | from . import BaseWrapperDataset 9 | 10 | 11 | class ColorizeDataset(BaseWrapperDataset): 12 | """Adds 'colors' property to net input that is obtained from the provided color getter for use by models""" 13 | 14 | def __init__(self, dataset, color_getter): 15 | super().__init__(dataset) 16 | self.color_getter = color_getter 17 | 18 | def collater(self, samples): 19 | base_collate = super().collater(samples) 20 | if len(base_collate) > 0: 21 | base_collate["net_input"]["colors"] = torch.tensor( 22 | list(self.color_getter(self.dataset, s["id"]) for s in samples), 23 | dtype=torch.long, 24 | ) 25 | return base_collate 26 | -------------------------------------------------------------------------------- /fairseq/data/concat_sentences_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | from . 
import FairseqDataset 9 | 10 | 11 | class ConcatSentencesDataset(FairseqDataset): 12 | def __init__(self, *datasets): 13 | super().__init__() 14 | self.datasets = datasets 15 | assert all( 16 | len(ds) == len(datasets[0]) for ds in datasets 17 | ), "datasets must have the same length" 18 | 19 | def __getitem__(self, index): 20 | return torch.cat([ds[index] for ds in self.datasets]) 21 | 22 | def __len__(self): 23 | return len(self.datasets[0]) 24 | 25 | def collater(self, samples): 26 | return self.datasets[0].collater(samples) 27 | 28 | @property 29 | def sizes(self): 30 | return sum(ds.sizes for ds in self.datasets) 31 | 32 | def num_tokens(self, index): 33 | return sum(ds.num_tokens(index) for ds in self.datasets) 34 | 35 | def size(self, index): 36 | return sum(ds.size(index) for ds in self.datasets) 37 | 38 | def ordered_indices(self): 39 | return self.datasets[0].ordered_indices() 40 | 41 | @property 42 | def supports_prefetch(self): 43 | return any(getattr(ds, "supports_prefetch", False) for ds in self.datasets) 44 | 45 | def prefetch(self, indices): 46 | for ds in self.datasets: 47 | if getattr(ds, "supports_prefetch", False): 48 | ds.prefetch(indices) 49 | 50 | def set_epoch(self, epoch): 51 | super().set_epoch(epoch) 52 | for ds in self.datasets: 53 | if hasattr(ds, "set_epoch"): 54 | ds.set_epoch(epoch) 55 | -------------------------------------------------------------------------------- /fairseq/data/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | import importlib 8 | import os 9 | 10 | from fairseq import registry 11 | 12 | 13 | build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY, _ = registry.setup_registry( 14 | "--tokenizer", 15 | default=None, 16 | ) 17 | 18 | 19 | build_bpe, register_bpe, BPE_REGISTRY, _ = registry.setup_registry( 20 | "--bpe", 21 | default=None, 22 | ) 23 | 24 | 25 | # automatically import any Python files in the encoders/ directory 26 | for file in sorted(os.listdir(os.path.dirname(__file__))): 27 | if file.endswith(".py") and not file.startswith("_"): 28 | module = file[: file.find(".py")] 29 | importlib.import_module("fairseq.data.encoders." + module) 30 | -------------------------------------------------------------------------------- /fairseq/data/encoders/byte_bpe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | from dataclasses import dataclass, field 8 | 9 | from fairseq import file_utils 10 | from fairseq.data.encoders import register_bpe 11 | from fairseq.data.encoders.byte_utils import ( 12 | SPACE, 13 | SPACE_ESCAPE, 14 | byte_encode, 15 | smart_byte_decode, 16 | ) 17 | from fairseq.dataclass import FairseqDataclass 18 | 19 | 20 | @dataclass 21 | class ByteBpeConfig(FairseqDataclass): 22 | sentencepiece_model_path: str = field( 23 | default="???", metadata={"help": "path to sentencepiece model"} 24 | ) 25 | 26 | 27 | @register_bpe("byte_bpe", dataclass=ByteBpeConfig) 28 | class ByteBPE(object): 29 | def __init__(self, cfg): 30 | vocab = file_utils.cached_path(cfg.sentencepiece_model_path) 31 | try: 32 | import sentencepiece as spm 33 | 34 | self.sp = spm.SentencePieceProcessor() 35 | self.sp.Load(vocab) 36 | except ImportError: 37 | raise ImportError( 38 | "Please install sentencepiece with: pip install sentencepiece" 39 | ) 40 | 41 | def encode(self, x: str) -> str: 42 | byte_encoded = byte_encode(x) 43 | return SPACE.join(self.sp.EncodeAsPieces(byte_encoded)) 44 | 45 | @staticmethod 46 | def decode(x: str) -> str: 47 | unescaped = x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE) 48 | return smart_byte_decode(unescaped) 49 | -------------------------------------------------------------------------------- /fairseq/data/encoders/byte_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import re 7 | 8 | 9 | WHITESPACE_NORMALIZER = re.compile(r"\s+") 10 | SPACE = chr(32) 11 | SPACE_ESCAPE = chr(9601) 12 | # excluding non-breaking space (160) here 13 | PRINTABLE_LATIN = set( 14 | list(range(32, 126 + 1)) + list(range(161, 172 + 1)) + list(range(174, 255 + 1)) 15 | ) 16 | BYTE_TO_BCHAR = { 17 | b: chr(b) if b in PRINTABLE_LATIN else chr(256 + b) for b in range(256) 18 | } 19 | BCHAR_TO_BYTE = {bc: b for b, bc in BYTE_TO_BCHAR.items()} 20 | 21 | 22 | def byte_encode(x: str) -> str: 23 | normalized = WHITESPACE_NORMALIZER.sub(SPACE, x) 24 | return "".join([BYTE_TO_BCHAR[b] for b in normalized.encode("utf-8")]) 25 | 26 | 27 | def byte_decode(x: str) -> str: 28 | try: 29 | return bytes([BCHAR_TO_BYTE[bc] for bc in x]).decode("utf-8") 30 | except ValueError: 31 | return "" 32 | 33 | 34 | def smart_byte_decode(x: str) -> str: 35 | output = byte_decode(x) 36 | if output == "": 37 | # DP the best recovery (max valid chars) if it's broken 38 | n_bytes = len(x) 39 | f = [0 for _ in range(n_bytes + 1)] 40 | pt = [0 for _ in range(n_bytes + 1)] 41 | for i in range(1, n_bytes + 1): 42 | f[i], pt[i] = f[i - 1], i - 1 43 | for j in range(1, min(4, i) + 1): 44 | if f[i - j] + 1 > f[i] and len(byte_decode(x[i - j : i])) > 0: 45 | f[i], pt[i] = f[i - j] + 1, i - j 46 | cur_pt = n_bytes 47 | while cur_pt > 0: 48 | if f[cur_pt] == f[pt[cur_pt]] + 1: 49 | output = byte_decode(x[pt[cur_pt] : cur_pt]) + output 50 | cur_pt = pt[cur_pt] 51 | return output 52 | -------------------------------------------------------------------------------- /fairseq/data/encoders/bytes.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | from fairseq.data.encoders import register_bpe 8 | from fairseq.data.encoders.byte_utils import ( 9 | SPACE, 10 | SPACE_ESCAPE, 11 | byte_encode, 12 | smart_byte_decode, 13 | ) 14 | 15 | 16 | @register_bpe("bytes") 17 | class Bytes(object): 18 | def __init__(self, *unused): 19 | pass 20 | 21 | @staticmethod 22 | def add_args(parser): 23 | pass 24 | 25 | @staticmethod 26 | def encode(x: str) -> str: 27 | encoded = byte_encode(x) 28 | escaped = encoded.replace(SPACE, SPACE_ESCAPE) 29 | return SPACE.join(list(escaped)) 30 | 31 | @staticmethod 32 | def decode(x: str) -> str: 33 | unescaped = x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE) 34 | return smart_byte_decode(unescaped) 35 | -------------------------------------------------------------------------------- /fairseq/data/encoders/characters.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | from fairseq.data.encoders import register_bpe 8 | 9 | 10 | SPACE = chr(32) 11 | SPACE_ESCAPE = chr(9601) 12 | 13 | 14 | @register_bpe("characters") 15 | class Characters(object): 16 | def __init__(self, *unused): 17 | pass 18 | 19 | @staticmethod 20 | def add_args(parser): 21 | pass 22 | 23 | @staticmethod 24 | def encode(x: str) -> str: 25 | escaped = x.replace(SPACE, SPACE_ESCAPE) 26 | return SPACE.join(list(escaped)) 27 | 28 | @staticmethod 29 | def decode(x: str) -> str: 30 | return x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE) 31 | -------------------------------------------------------------------------------- /fairseq/data/encoders/fastbpe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from dataclasses import dataclass, field 7 | 8 | from fairseq import file_utils 9 | from fairseq.data.encoders import register_bpe 10 | from fairseq.dataclass import FairseqDataclass 11 | 12 | 13 | @dataclass 14 | class fastBPEConfig(FairseqDataclass): 15 | bpe_codes: str = field(default="???", metadata={"help": "path to fastBPE BPE"}) 16 | 17 | 18 | @register_bpe("fastbpe", dataclass=fastBPEConfig) 19 | class fastBPE(object): 20 | def __init__(self, cfg): 21 | if cfg.bpe_codes is None: 22 | raise ValueError("--bpe-codes is required for --bpe=fastbpe") 23 | codes = file_utils.cached_path(cfg.bpe_codes) 24 | try: 25 | import fastBPE 26 | 27 | self.bpe = fastBPE.fastBPE(codes) 28 | self.bpe_symbol = "@@ " 29 | except ImportError: 30 | raise ImportError("Please install fastBPE with: pip install fastBPE") 31 | 32 | def encode(self, x: str) -> str: 33 | return self.bpe.apply([x])[0] 34 | 35 | def decode(self, x: str) -> str: 36 | return (x + " ").replace(self.bpe_symbol, "").rstrip() 37 | -------------------------------------------------------------------------------- /fairseq/data/encoders/gpt2_bpe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from dataclasses import dataclass, field 7 | 8 | from fairseq import file_utils 9 | from fairseq.data.encoders import register_bpe 10 | from fairseq.dataclass import FairseqDataclass 11 | 12 | from .gpt2_bpe_utils import get_encoder 13 | 14 | 15 | DEFAULT_ENCODER_JSON = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json" 16 | DEFAULT_VOCAB_BPE = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe" 17 | 18 | 19 | @dataclass 20 | class GPT2BPEConfig(FairseqDataclass): 21 | gpt2_encoder_json: str = field( 22 | default=DEFAULT_ENCODER_JSON, metadata={"help": "path to encoder.json"} 23 | ) 24 | gpt2_vocab_bpe: str = field( 25 | default=DEFAULT_VOCAB_BPE, metadata={"help": "path to vocab.bpe"} 26 | ) 27 | 28 | 29 | @register_bpe("gpt2", dataclass=GPT2BPEConfig) 30 | class GPT2BPE(object): 31 | def __init__(self, cfg): 32 | encoder_json = file_utils.cached_path(cfg.gpt2_encoder_json) 33 | vocab_bpe = file_utils.cached_path(cfg.gpt2_vocab_bpe) 34 | self.bpe = get_encoder(encoder_json, vocab_bpe) 35 | 36 | def encode(self, x: str) -> str: 37 | return " ".join(map(str, self.bpe.encode(x))) 38 | 39 | def decode(self, x: str) -> str: 40 | return self.bpe.decode( 41 | [int(tok) if tok not in {"<unk>", "<mask>"} else tok for tok in x.split()] 42 | ) 43 | 44 | def is_beginning_of_word(self, x: str) -> bool: 45 | return self.decode(x).startswith(" ") 46 | -------------------------------------------------------------------------------- /fairseq/data/encoders/hf_bert_bpe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree.
5 | 6 | from dataclasses import dataclass, field 7 | from typing import Optional 8 | 9 | from fairseq.data.encoders import register_bpe 10 | from fairseq.dataclass import FairseqDataclass 11 | 12 | 13 | @dataclass 14 | class BertBPEConfig(FairseqDataclass): 15 | bpe_cased: bool = field(default=False, metadata={"help": "set for cased BPE"}) 16 | bpe_vocab_file: Optional[str] = field( 17 | default=None, metadata={"help": "bpe vocab file"} 18 | ) 19 | 20 | 21 | @register_bpe("bert", dataclass=BertBPEConfig) 22 | class BertBPE(object): 23 | def __init__(self, cfg): 24 | try: 25 | from transformers import BertTokenizer 26 | except ImportError: 27 | raise ImportError( 28 | "Please install transformers with: pip install transformers" 29 | ) 30 | 31 | if cfg.bpe_vocab_file: 32 | self.bert_tokenizer = BertTokenizer( 33 | cfg.bpe_vocab_file, do_lower_case=not cfg.bpe_cased 34 | ) 35 | else: 36 | vocab_file_name = ( 37 | "bert-base-cased" if cfg.bpe_cased else "bert-base-uncased" 38 | ) 39 | self.bert_tokenizer = BertTokenizer.from_pretrained(vocab_file_name) 40 | 41 | def encode(self, x: str) -> str: 42 | return " ".join(self.bert_tokenizer.tokenize(x)) 43 | 44 | def decode(self, x: str) -> str: 45 | return self.bert_tokenizer.clean_up_tokenization( 46 | self.bert_tokenizer.convert_tokens_to_string(x.split(" ")) 47 | ) 48 | 49 | def is_beginning_of_word(self, x: str) -> bool: 50 | return not x.startswith("##") 51 | -------------------------------------------------------------------------------- /fairseq/data/encoders/hf_byte_bpe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
from dataclasses import dataclass, field

from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass
from fairseq import file_utils


@dataclass
class HuggingFaceByteLevelBPEConfig(FairseqDataclass):
    """Files and options for a HuggingFace ``tokenizers`` byte-level BPE."""

    # "???" marks a value that must be supplied (OmegaConf's missing-value
    # marker, assuming the config is resolved through OmegaConf).
    bpe_merges: str = field(default="???", metadata={"help": "path to merges.txt"})
    bpe_vocab: str = field(default="???", metadata={"help": "path to vocab.json"})
    bpe_add_prefix_space: bool = field(
        default=False, metadata={"help": "add prefix space before encoding"}
    )


@register_bpe("hf_byte_bpe", dataclass=HuggingFaceByteLevelBPEConfig)
class HuggingFaceByteLevelBPE(object):
    """Byte-level BPE backed by ``tokenizers.ByteLevelBPETokenizer``.

    Text is represented as a string of space-separated integer token ids.
    """

    def __init__(self, cfg):
        try:
            from tokenizers import ByteLevelBPETokenizer
        except ImportError:
            raise ImportError(
                "Please install huggingface/tokenizers with: pip install tokenizers"
            )

        # Resolve both files (local paths or URLs) before building the tokenizer.
        bpe_vocab = file_utils.cached_path(cfg.bpe_vocab)
        bpe_merges = file_utils.cached_path(cfg.bpe_merges)

        self.bpe = ByteLevelBPETokenizer(
            bpe_vocab,
            bpe_merges,
            add_prefix_space=cfg.bpe_add_prefix_space,
        )

    def encode(self, x: str) -> str:
        """Return ``x`` as a string of space-separated token ids."""
        return " ".join(map(str, self.bpe.encode(x).ids))

    def decode(self, x: str) -> str:
        """Convert a string of space-separated token ids back into text.

        ``<unk>`` and ``<mask>`` are not numeric, so they are forwarded as
        strings instead of being passed to ``int``. NOTE(review): confirm that
        the installed ``tokenizers`` ``decode`` accepts these string entries.
        """
        return self.bpe.decode(
            [int(tok) if tok not in {"<unk>", "<mask>"} else tok for tok in x.split()]
        )

    def is_beginning_of_word(self, x: str) -> bool:
        """Return True if ``x`` decodes to text that starts with a space."""
        return self.decode(x).startswith(" ")
# --------------------------------------------------------------------------------
# /fairseq/data/encoders/moses_tokenizer.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field

from fairseq.data.encoders import register_tokenizer
from fairseq.dataclass import FairseqDataclass


@dataclass
class MosesTokenizerConfig(FairseqDataclass):
    """Options for the sacremoses-based ``moses`` tokenizer."""

    # Language handed to the tokenizer used by encode().
    source_lang: str = field(default="en", metadata={"help": "source language"})
    # Language handed to the detokenizer used by decode().
    target_lang: str = field(default="en", metadata={"help": "target language"})
    # When True, encode() passes aggressive_dash_splits=False.
    moses_no_dash_splits: bool = field(
        default=False, metadata={"help": "don't apply dash split rules"}
    )
    # When True, encode() passes escape=False.
    moses_no_escape: bool = field(
        default=False,
        metadata={"help": "don't perform HTML escaping on apostrophe, quotes, etc."},
    )


@register_tokenizer("moses", dataclass=MosesTokenizerConfig)
class MosesTokenizer(object):
    """Moses tokenization (encode) and detokenization (decode) via sacremoses."""

    def __init__(self, cfg: MosesTokenizerConfig):
        self.cfg = cfg

        try:
            # Aliased so the imported classes do not shadow this class's name.
            from sacremoses import MosesDetokenizer as _SacreDetokenizer
            from sacremoses import MosesTokenizer as _SacreTokenizer

            self.tok = _SacreTokenizer(cfg.source_lang)
            self.detok = _SacreDetokenizer(cfg.target_lang)
        except ImportError:
            raise ImportError(
                "Please install Moses tokenizer with: pip install sacremoses"
            )

    def encode(self, x: str) -> str:
        """Tokenize ``x`` and return the result as a single string."""
        cfg = self.cfg
        return self.tok.tokenize(
            x,
            return_str=True,
            aggressive_dash_splits=not cfg.moses_no_dash_splits,
            escape=not cfg.moses_no_escape,
        )

    def decode(self, x: str) -> str:
        """Detokenize a whitespace-separated token string."""
        tokens = x.split()
        return self.detok.detokenize(tokens)
# --------------------------------------------------------------------------------
# /fairseq/data/encoders/nltk_tokenizer.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.data.encoders import register_tokenizer
from fairseq.dataclass import FairseqDataclass


@register_tokenizer("nltk", dataclass=FairseqDataclass)
class NLTKTokenizer(object):
    """Word tokenizer based on ``nltk.tokenize.word_tokenize``.

    ``encode`` joins the NLTK word tokens with single spaces; ``decode`` is
    the identity (no detokenization is attempted).
    """

    def __init__(self, *unused):
        try:
            from nltk.tokenize import word_tokenize
        except ImportError:
            raise ImportError("Please install nltk with: pip install nltk")
        self.word_tokenize = word_tokenize

    def encode(self, x: str) -> str:
        words = self.word_tokenize(x)
        return " ".join(words)

    def decode(self, x: str) -> str:
        # Identity: the space-joined tokens are returned unchanged.
        return x
# --------------------------------------------------------------------------------
# /fairseq/data/encoders/space_tokenizer.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import re

from fairseq.data.encoders import register_tokenizer
from fairseq.dataclass import FairseqDataclass


@register_tokenizer("space", dataclass=FairseqDataclass)
class SpaceTokenizer(object):
    """Tokenizer that only collapses each run of whitespace into one space."""

    def __init__(self, *unused):
        # Matches any run of one or more whitespace characters.
        self.space_tok = re.compile(r"\s+")

    def encode(self, x: str) -> str:
        collapse = self.space_tok.sub
        return collapse(" ", x)

    def decode(self, x: str) -> str:
        # Identity: whitespace collapsing is not reversible.
        return x
# --------------------------------------------------------------------------------
# /fairseq/data/encoders/subword_nmt_bpe.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field

from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass


@dataclass
class SubwordNMTBPEConfig(FairseqDataclass):
    """Options for the ``subword_nmt`` BPE."""

    # "???" marks a value that must be supplied (OmegaConf's missing-value
    # marker, assuming the config is resolved through OmegaConf).
    bpe_codes: str = field(default="???", metadata={"help": "path to subword NMT BPE"})
    # Continuation marker; decode() strips every occurrence of it followed by
    # a space.
    bpe_separator: str = field(default="@@", metadata={"help": "BPE separator"})


@register_bpe("subword_nmt", dataclass=SubwordNMTBPEConfig)
class SubwordNMTBPE(object):
    """BPE segmentation through the ``subword_nmt`` package's ``apply_bpe``."""

    def __init__(self, cfg):
        # Only an explicit None is rejected here.
        if cfg.bpe_codes is None:
            raise ValueError("--bpe-codes is required for --bpe=subword_nmt")
        codes = file_utils.cached_path(cfg.bpe_codes)
        try:
            from subword_nmt import apply_bpe

            # Reuse apply_bpe's own CLI parser so its remaining options
            # (merges, glossaries) get their default values.
            cli = ["--codes", codes, "--separator", cfg.bpe_separator]
            parsed = apply_bpe.create_parser().parse_args(cli)
            self.bpe = apply_bpe.BPE(
                parsed.codes,
                parsed.merges,
                parsed.separator,
                None,
                parsed.glossaries,
            )
            # Separator followed by a space marks a word-internal split.
            self.bpe_symbol = parsed.separator + " "
        except ImportError:
            raise ImportError(
                "Please install subword_nmt with: pip install subword-nmt"
            )

    def encode(self, x: str) -> str:
        """Segment ``x`` into BPE units."""
        return self.bpe.process_line(x)

    def decode(self, x: str) -> str:
        """Undo BPE by deleting every ``separator + " "`` occurrence."""
        # The extra trailing space lets a separator at the very end match too.
        padded = x + " "
        return padded.replace(self.bpe_symbol, "").rstrip()
# --------------------------------------------------------------------------------
# /fairseq/data/encoders/utils.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from fairseq.data import encoders


def get_whole_word_mask(args, dictionary):
    """Build a per-symbol "begins a word" mask for ``dictionary``.

    Args:
        args: configuration passed to ``encoders.build_bpe``.
        dictionary: token dictionary; every index in ``range(len(dictionary))``
            is classified.

    Returns:
        A ``torch.ByteTensor`` with one entry per symbol (1 where the symbol
        begins a word), or ``None`` when ``build_bpe`` returns ``None``.
    """
    bpe = encoders.build_bpe(args)
    if bpe is None:
        return None

    def is_beginning_of_word(i):
        if i < dictionary.nspecial:
            # special elements are always considered beginnings
            return True
        tok = dictionary[i]
        # Symbols named "madeupword..." are treated as beginnings as well.
        if tok.startswith("madeupword"):
            return True
        try:
            return bpe.is_beginning_of_word(tok)
        except ValueError:
            # The BPE could not interpret this symbol (e.g. a non-numeric
            # token reaching int() while decoding); default to a word start.
            return True

    return torch.ByteTensor(list(map(is_beginning_of_word, range(len(dictionary)))))
# --------------------------------------------------------------------------------
# /fairseq/data/huffman/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .huffman_coder import HuffmanCodeBuilder, HuffmanCoder
from .huffman_mmap_indexed_dataset import (
    HuffmanMMapIndex,
    HuffmanMMapIndexedDataset,
    HuffmanMMapIndexedDatasetBuilder,
    vocab_file_path,
)

__all__ = [
    "HuffmanCoder",
    "HuffmanCodeBuilder",
    "HuffmanMMapIndexedDatasetBuilder",
    "HuffmanMMapIndexedDataset",
    "HuffmanMMapIndex",
    "vocab_file_path",
]
# --------------------------------------------------------------------------------
# /fairseq/data/id_dataset.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch

from . import FairseqDataset


class IdDataset(FairseqDataset):
    """Dataset whose item at ``index`` is ``index`` itself.

    ``__len__`` reports 0; presumably this dataset is only used as a component
    of a larger dataset that defines the length — verify callers.
    """

    def __getitem__(self, index):
        return index

    def __len__(self):
        return 0

    def collater(self, samples):
        """Return the collected indices as a tensor."""
        return torch.tensor(samples)
# --------------------------------------------------------------------------------
# /fairseq/data/legacy/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .block_pair_dataset import BlockPairDataset
from .masked_lm_dataset import MaskedLMDataset
from .masked_lm_dictionary import BertDictionary, MaskedLMDictionary


__all__ = [
    "BertDictionary",
    "BlockPairDataset",
    "MaskedLMDataset",
    "MaskedLMDictionary",
]
# --------------------------------------------------------------------------------
# /fairseq/data/legacy/masked_lm_dictionary.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.data import Dictionary


class MaskedLMDictionary(Dictionary):
    """
    Dictionary for Masked Language Modelling tasks. This extends Dictionary by
    adding the mask symbol.

    The default special symbols are distinct strings so that each one can get
    its own index (presumably ``add_symbol`` returns the existing index for a
    repeated string — verify in Dictionary).
    """

    def __init__(
        self,
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
        mask="<mask>",
    ):
        super().__init__(pad=pad, eos=eos, unk=unk)
        self.mask_word = mask
        self.mask_index = self.add_symbol(mask)
        # Every symbol added so far, including the mask, counts as special.
        self.nspecial = len(self.symbols)

    def mask(self):
        """Helper to get index of mask symbol"""
        return self.mask_index


class BertDictionary(MaskedLMDictionary):
    """
    Dictionary for BERT task. This extends MaskedLMDictionary by adding support
    for cls and sep symbols.
    """

    def __init__(
        self,
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
        mask="<mask>",
        cls="<cls>",
        sep="<sep>",
    ):
        super().__init__(pad=pad, eos=eos, unk=unk, mask=mask)
        self.cls_word = cls
        self.sep_word = sep
        self.cls_index = self.add_symbol(cls)
        self.sep_index = self.add_symbol(sep)
        # Recount so cls and sep are also treated as special symbols.
        self.nspecial = len(self.symbols)

    def cls(self):
        """Helper to get index of cls symbol"""
        return self.cls_index

    def sep(self):
        """Helper to get index of sep symbol"""
        return self.sep_index
# --------------------------------------------------------------------------------
# /fairseq/data/list_dataset.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import BaseWrapperDataset


class ListDataset(BaseWrapperDataset):
    """Wrap ``dataset`` with explicit ``sizes`` and a pass-through collater."""

    def __init__(self, dataset, sizes=None):
        super().__init__(dataset)
        self._sizes = sizes

    def __iter__(self):
        yield from self.dataset

    def collater(self, samples):
        """Return ``samples`` unchanged (no padding or stacking)."""
        return samples

    @property
    def sizes(self):
        return self._sizes

    def num_tokens(self, index):
        return self.sizes[index]

    def size(self, index):
        return self.sizes[index]

    def set_epoch(self, epoch):
        # Nothing epoch-dependent to update.
        pass
# --------------------------------------------------------------------------------
# /fairseq/data/lru_cache_dataset.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from functools import lru_cache

from . import BaseWrapperDataset


class LRUCacheDataset(BaseWrapperDataset):
    """Memoize recent ``__getitem__`` results of the wrapped dataset.

    A repeated read of a cached index returns the identical cached object.
    """

    def __init__(self, dataset, token=None):
        # ``token`` is accepted for signature compatibility and is unused.
        super().__init__(dataset)

    # NOTE: lru_cache on a method keys on (self, index). The cache belongs to
    # the decorated function, so its 8 slots are shared by all instances and
    # keep cached instances alive until evicted.
    @lru_cache(maxsize=8)
    def __getitem__(self, index):
        return self.dataset[index]

    def collater(self, samples):
        # Deliberately not memoized: ``samples`` is normally a list, which is
        # unhashable, so an lru_cache wrapper here would raise TypeError.
        return self.dataset.collater(samples)
# --------------------------------------------------------------------------------
# /fairseq/data/multilingual/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------------------------------
# /fairseq/data/multilingual/multilingual_utils.py:
# --------------------------------------------------------------------------------
from enum import Enum
from typing import Dict, List, Optional, Sequence

import torch
from fairseq.data import Dictionary


class EncoderLangtok(Enum):
    """
    Prepend to the beginning of source sentence either the
    source or target language token. (src/tgt).
    """

    src = "src"
    tgt = "tgt"


class LangTokSpec(Enum):
    """Kind of language token; passed as ``spec`` to ``get_lang_tok``."""

    main = "main"
    mono_dae = "mono_dae"


class LangTokStyle(Enum):
    """Surface format of a language token; see ``get_lang_tok``."""

    multilingual = "multilingual"
    mbart = "mbart"


@torch.jit.export
def get_lang_tok(
    lang: str, lang_tok_style: str, spec: str = LangTokSpec.main.value
) -> str:
    """Return the language token for ``lang``.

    Args:
        lang: language code.
        lang_tok_style: a ``LangTokStyle`` value; ``"mbart"`` formats as
            ``"[lang]"`` and ``"multilingual"`` as ``"__lang__"``.
        spec: specs ending in ``"dae"`` / ``"mined"`` append ``"_dae"`` /
            ``"_mined"`` to ``lang`` before formatting.

    Raises:
        KeyError: if ``lang_tok_style`` is not one of the two styles.

    Example: ``get_lang_tok("en", "multilingual")`` returns ``"__en__"``.
    """
    # TOKEN_STYLES can't be defined outside this fn since it needs to be
    # TorchScriptable.
    TOKEN_STYLES: Dict[str, str] = {
        LangTokStyle.mbart.value: "[{}]",
        LangTokStyle.multilingual.value: "__{}__",
    }

    if spec.endswith("dae"):
        lang = f"{lang}_dae"
    elif spec.endswith("mined"):
        lang = f"{lang}_mined"
    style = TOKEN_STYLES[lang_tok_style]
    return style.format(lang)


def augment_dictionary(
    dictionary: Dictionary,
    language_list: List[str],
    lang_tok_style: str,
    langtoks_specs: Sequence[str] = (LangTokSpec.main.value,),
    extra_data: Optional[Dict[str, str]] = None,
) -> None:
    """Add language tokens, and when needed the mask symbol, to ``dictionary``.

    One token is added per (spec, language) pair. ``"<mask>"`` is also added
    for the mBART style, or when ``extra_data`` has a ``"mono_dae"`` key.
    The dictionary is modified in place.
    """
    for spec in langtoks_specs:
        for language in language_list:
            dictionary.add_symbol(
                get_lang_tok(lang=language, lang_tok_style=lang_tok_style, spec=spec)
            )

    if lang_tok_style == LangTokStyle.mbart.value or (
        extra_data is not None and LangTokSpec.mono_dae.value in extra_data
    ):
        dictionary.add_symbol("<mask>")
# --------------------------------------------------------------------------------
# /fairseq/data/multilingual/sampling_method.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import List


logger = logging.getLogger(__name__)


def uniform(dataset_sizes: List[int]):
    """Return an unnormalized weight of 1.0 for every dataset."""
    return [1.0] * len(dataset_sizes)


def temperature_sampling(dataset_sizes, temp):
    """Return unnormalized weights ``(size / total) ** (1 / temp)``.

    ``temp == 1`` gives size-proportional weights; ``temp > 1`` flattens them
    toward uniform. The weights do not sum to 1 in general.
    """
    total_size = sum(dataset_sizes)
    return [(size / total_size) ** (1.0 / temp) for size in dataset_sizes]


def make_temperature_sampling(temp=1.0):
    """Return a ``sampling_func(dataset_sizes)`` with ``temp`` bound."""

    def sampling_func(dataset_sizes):
        return temperature_sampling(dataset_sizes, temp)

    return sampling_func


def make_ratio_sampling(ratios):
    """Return a sampling function that ignores sizes and returns ``ratios``."""

    def sampling_func(dataset_sizes):
        return ratios

    return sampling_func


class SamplingMethod:
    """Chooses how per-language-pair datasets are weighted when sampling."""

    @staticmethod
    def add_arguments(parser):
        """Register ``--sampling-method`` and ``--sampling-temperature``."""
        parser.add_argument(
            "--sampling-method",
            choices=[
                "uniform",
                "temperature",
                "concat",
                "RoundRobin",
            ],
            type=str,
            default="concat",
            help="The method to sample data per language pairs",
        )
        parser.add_argument(
            "--sampling-temperature",
            default=1.5,
            type=float,
            help="only work with --sampling-method temperature",
        )

    @staticmethod
    def build_sampler(args, task):
        return SamplingMethod(args, task)

    def __init__(self, args, task):
        self.args = args
        self.task = task

    def is_adaptive(self):
        # False here; when True, the selector falls back to temperature
        # sampling for any method other than "uniform".
        return False

    def sampling_method_selector(self):
        """Return a ``sampling_func(dataset_sizes)`` or None.

        Returns ``uniform`` for "uniform", and a temperature sampler for
        "temperature" (or for any other method when ``is_adaptive()`` is
        True). Otherwise it returns None, which per the inline comment means
        all datasets are concatenated.
        """
        args = self.args
        logger.info(f"selected sampler: {args.sampling_method}")
        if args.sampling_method == "uniform":
            return uniform
        elif args.sampling_method == "temperature" or self.is_adaptive():
            return make_temperature_sampling(float(args.sampling_temperature))
        else:
            # default to concating all data set together
            return None
# --------------------------------------------------------------------------------
# /fairseq/data/num_samples_dataset.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import FairseqDataset


class NumSamplesDataset(FairseqDataset):
    """Yields 1 for every index, so ``collater`` returns the number of samples.

    ``__len__`` reports 0; presumably this is only used as a component of a
    larger dataset that defines the length — verify callers.
    """

    def __getitem__(self, index):
        return 1

    def __len__(self):
        return 0

    def collater(self, samples):
        return sum(samples)
# --------------------------------------------------------------------------------
# /fairseq/data/numel_dataset.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

from . import BaseWrapperDataset


class NumelDataset(BaseWrapperDataset):
    """Map each item of ``dataset`` to its number of elements.

    Args:
        dataset: wrapped dataset; items are tensors or anything ``np.size``
            accepts.
        reduce: if True, ``collater`` returns the summed count; otherwise a
            tensor of per-item counts.
    """

    def __init__(self, dataset, reduce=False):
        super().__init__(dataset)
        self.reduce = reduce

    def __getitem__(self, index):
        item = self.dataset[index]
        if torch.is_tensor(item):
            return torch.numel(item)
        else:
            return np.size(item)

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples):
        if self.reduce:
            return sum(samples)
        else:
            return torch.tensor(samples)
# --------------------------------------------------------------------------------
# /fairseq/data/offset_tokens_dataset.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import BaseWrapperDataset


class OffsetTokensDataset(BaseWrapperDataset):
    """Add ``offset`` to every item (elementwise when items are tensors)."""

    def __init__(self, dataset, offset):
        super().__init__(dataset)
        self.offset = offset

    def __getitem__(self, idx):
        return self.dataset[idx] + self.offset
# --------------------------------------------------------------------------------
# /fairseq/data/pad_dataset.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.data import data_utils

from . import BaseWrapperDataset


class PadDataset(BaseWrapperDataset):
    """Collate items into one padded batch with ``data_utils.collate_tokens``.

    Args:
        pad_idx: value used for padding.
        left_pad: pad on the left when True, otherwise on the right.
        pad_length: forwarded as ``pad_to_length``; None leaves the target
            length to ``collate_tokens``.
    """

    def __init__(self, dataset, pad_idx, left_pad, pad_length=None):
        super().__init__(dataset)
        self.pad_idx = pad_idx
        self.left_pad = left_pad
        self.pad_length = pad_length

    def collater(self, samples):
        return data_utils.collate_tokens(
            samples, self.pad_idx, left_pad=self.left_pad, pad_to_length=self.pad_length
        )


class LeftPadDataset(PadDataset):
    """``PadDataset`` that always pads on the left."""

    def __init__(self, dataset, pad_idx):
        super().__init__(dataset, pad_idx, left_pad=True)


class RightPadDataset(PadDataset):
    """``PadDataset`` that always pads on the right."""

    def __init__(self, dataset, pad_idx):
        super().__init__(dataset, pad_idx, left_pad=False)
# --------------------------------------------------------------------------------
# /fairseq/data/prepend_dataset.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

from . import BaseWrapperDataset


class PrependDataset(BaseWrapperDataset):
    """Overwrite the first token of each (source) item with a computed index.

    ``prepend_getter(dataset, idx)`` must return an ``int``. When
    ``ensure_first_token_is`` is given, the existing first token is asserted
    to equal it. For tuple items only element 0 is modified.

    NOTE: despite the name, position 0 is replaced rather than a token being
    inserted, and the write goes into the object returned by the wrapped
    dataset (in place).
    """

    def __init__(self, dataset, prepend_getter, ensure_first_token_is=None):
        super().__init__(dataset)
        self.prepend_getter = prepend_getter
        self.ensure_first_token = ensure_first_token_is

    def __getitem__(self, idx):
        item = self.dataset[idx]
        is_tuple = isinstance(item, tuple)
        src = item[0] if is_tuple else item

        assert self.ensure_first_token is None or src[0] == self.ensure_first_token
        prepend_idx = self.prepend_getter(self.dataset, idx)
        assert isinstance(prepend_idx, int)
        src[0] = prepend_idx
        item = tuple((src,) + item[1:]) if is_tuple else src
        return item
# --------------------------------------------------------------------------------
# /fairseq/data/prepend_token_dataset.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

from . import BaseWrapperDataset


class PrependTokenDataset(BaseWrapperDataset):
    """Prepend ``token`` to every item; a no-op when ``token`` is None.

    ``sizes``, ``num_tokens`` and ``size`` are increased by one accordingly.
    """

    def __init__(self, dataset, token=None):
        super().__init__(dataset)
        self.token = token
        if token is not None:
            self._sizes = np.array(dataset.sizes) + 1
        else:
            self._sizes = dataset.sizes

    def __getitem__(self, idx):
        item = self.dataset[idx]
        if self.token is not None:
            # item.new(...) builds a one-element tensor of the item's type.
            item = torch.cat([item.new([self.token]), item])
        return item

    @property
    def sizes(self):
        return self._sizes

    def num_tokens(self, index):
        n = self.dataset.num_tokens(index)
        if self.token is not None:
            n += 1
        return n

    def size(self, index):
        n = self.dataset.size(index)
        if self.token is not None:
            n += 1
        return n
# --------------------------------------------------------------------------------
# /fairseq/data/raw_label_dataset.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch

from . import FairseqDataset


class RawLabelDataset(FairseqDataset):
    """Serve ``labels[index]`` and collate the labels into a tensor."""

    def __init__(self, labels):
        super().__init__()
        self.labels = labels

    def __getitem__(self, index):
        return self.labels[index]

    def __len__(self):
        return len(self.labels)

    def collater(self, samples):
        return torch.tensor(samples)
# --------------------------------------------------------------------------------
# /fairseq/data/replace_dataset.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import BaseWrapperDataset


class ReplaceDataset(BaseWrapperDataset):
    """Replaces tokens found in the dataset by a specified replacement token.

    Args:
        dataset (~torch.utils.data.Dataset): dataset to replace tokens in
        replace_map (Dict[int, int]): map of token to replace -> replacement
            token; must not be empty.
        offsets (List[int]): one offset per object returned by the wrapped
            dataset's ``__getitem__``. A non-negative offset protects that many
            leading tokens; a negative offset protects ``-offset`` trailing
            tokens. Objects without a matching offset are left untouched.

    Replacement uses ``masked_fill_`` on a slice; assuming tensor items (whose
    slices are views), the item is modified in place.
    """

    def __init__(self, dataset, replace_map, offsets):
        super().__init__(dataset)
        assert len(replace_map) > 0
        self.replace_map = replace_map
        self.offsets = offsets

    def __getitem__(self, index):
        item = self.dataset[index]
        is_tuple = isinstance(item, tuple)
        srcs = item if is_tuple else [item]

        for offset, src in zip(self.offsets, srcs):
            for k, v in self.replace_map.items():
                src_off = src[offset:] if offset >= 0 else src[:offset]
                src_off.masked_fill_(src_off == k, v)

        item = srcs if is_tuple else srcs[0]
        return item
# --------------------------------------------------------------------------------
# /fairseq/data/roll_dataset.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch

from . import BaseWrapperDataset


class RollDataset(BaseWrapperDataset):
    """Apply ``torch.roll(item, shifts)`` to every item."""

    def __init__(self, dataset, shifts):
        super().__init__(dataset)
        self.shifts = shifts

    def __getitem__(self, index):
        item = self.dataset[index]
        return torch.roll(item, self.shifts)
# --------------------------------------------------------------------------------
# /fairseq/data/sort_dataset.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np

from . import BaseWrapperDataset


class SortDataset(BaseWrapperDataset):
    """Order examples by a lexicographic sort over ``sort_order``.

    ``sort_order`` is one key sequence or a list/tuple of them, each with one
    entry per example. ``np.lexsort`` treats the LAST key as the primary key.
    """

    def __init__(self, dataset, sort_order):
        super().__init__(dataset)
        if not isinstance(sort_order, (list, tuple)):
            sort_order = [sort_order]
        self.sort_order = sort_order

        assert all(len(so) == len(dataset) for so in sort_order)

    def ordered_indices(self):
        return np.lexsort(self.sort_order)
# --------------------------------------------------------------------------------
# /fairseq/data/strip_token_dataset.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import BaseWrapperDataset


class StripTokenDataset(BaseWrapperDataset):
    """Remove every leading and trailing occurrence of ``id_to_strip``."""

    def __init__(self, dataset, id_to_strip):
        super().__init__(dataset)
        self.id_to_strip = id_to_strip

    def __getitem__(self, index):
        item = self.dataset[index]
        while len(item) > 0 and item[-1] == self.id_to_strip:
            item = item[:-1]
        while len(item) > 0 and item[0] == self.id_to_strip:
            item = item[1:]
        return item
# --------------------------------------------------------------------------------
# /fairseq/data/subsample_dataset.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import numpy as np

from . import BaseWrapperDataset


logger = logging.getLogger(__name__)


class SubsampleDataset(BaseWrapperDataset):
    """Subsamples a given dataset by a specified ratio. Subsampling is done on
    the number of examples.

    Args:
        dataset (~torch.utils.data.Dataset): dataset to subsample
        size_ratio (float): the ratio to subsample to. must be between 0 and 1
            (exclusive)
        shuffle (bool): if True, ``ordered_indices`` breaks size ties randomly
            instead of by position.

    ``ceil(len(dataset) * size_ratio)`` examples are drawn without replacement
    using NumPy's global random state.
    """

    def __init__(self, dataset, size_ratio, shuffle=False):
        super().__init__(dataset)
        assert size_ratio < 1
        self.actual_size = np.ceil(len(dataset) * size_ratio).astype(int)
        # Passing the population size (not a materialized index list) draws
        # the same indices for the same RNG state, without an O(n) Python list.
        self.indices = np.random.choice(
            len(self.dataset), self.actual_size, replace=False
        )
        self.shuffle = shuffle
        logger.info(
            "subsampled dataset from {} to {} (ratio={})".format(
                len(self.dataset), self.actual_size, size_ratio
            )
        )

    def __getitem__(self, index):
        return self.dataset[self.indices[index]]

    def __len__(self):
        return self.actual_size

    def collater(self, samples):
        return self.dataset.collater(samples)

    @property
    def sizes(self):
        return self.dataset.sizes[self.indices]

    @property
    def name(self):
        return self.dataset.name

    def num_tokens(self, index):
        return self.dataset.num_tokens(self.indices[index])

    def size(self, index):
        return self.dataset.size(self.indices[index])

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]
        # Sizes are the last (primary) lexsort key; the first key breaks ties.
        order.append(self.sizes)
        return np.lexsort(order)

    def prefetch(self, indices):
        self.dataset.prefetch(self.indices[indices])
# --------------------------------------------------------------------------------
# /fairseq/data/text_compressor.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from enum import Enum


class TextCompressionLevel(Enum):
    none = 0
    low = 1
    high = 2


def _import_unishox2():
    """Import and return the optional ``unishox2`` module, with an install hint on failure."""
    try:
        import unishox2
    except ImportError as err:
        raise ImportError(
            "Please install unishox2 for the text compression feature: "
            "pip install unishox2-py3"
        ) from err
    return unishox2


class TextCompressor(object):
    """Compress/decompress text according to a :class:`TextCompressionLevel`.

    * ``none``: plain UTF-8 encoding.
    * ``low``: zlib (stdlib, fast) at ``zlib_level``.
    * ``high``: unishox2 (optional dependency, tuned for short strings but
      slower); inputs are limited to ``max_input_byte_length`` UTF-8 bytes.

    Args:
        level: which compression scheme to use.
        max_input_byte_length: maximum UTF-8 byte length accepted by the
            unishox2 path, also used as its decompression buffer size.
        zlib_level: zlib compression level for ``low`` (1 = fastest, 9 = best).
            Level 0 only stores the data, which makes the output larger than
            the input. Decompression works for data written at any level.
    """

    def __init__(
        self,
        level: TextCompressionLevel,
        max_input_byte_length: int = 2**16,
        zlib_level: int = 1,
    ):
        self.level = level
        self.max_input_length = max_input_byte_length
        self.zlib_level = zlib_level

    def compress(self, text: str) -> bytes:
        if self.level == TextCompressionLevel.low:
            import zlib

            # zlib: built-in, fast
            return zlib.compress(text.encode(), level=self.zlib_level)
        elif self.level == TextCompressionLevel.high:
            # unishox2: optimized for short text but slower
            unishox2 = _import_unishox2()
            assert len(text.encode()) <= self.max_input_length
            return unishox2.compress(text)[0]
        else:
            return text.encode()

    def decompress(self, compressed: bytes) -> str:
        if self.level == TextCompressionLevel.low:
            import zlib

            return zlib.decompress(compressed).decode()
        elif self.level == TextCompressionLevel.high:
            unishox2 = _import_unishox2()
            return unishox2.decompress(compressed, self.max_input_length)
        else:
            return compressed.decode()

# ---------- /fairseq/dataclass/__init__.py ----------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .configs import FairseqDataclass
from .constants import ChoiceEnum


__all__ = [
    "FairseqDataclass",
    "ChoiceEnum",
]

# ---------- /fairseq/dataclass/constants.py ----------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from enum import Enum, EnumMeta
from typing import List


class StrEnumMeta(EnumMeta):
    # this is workaround for submitit pickling leading to instance checks failing in hydra for StrEnum, see
    # https://github.com/facebookresearch/hydra/issues/1156
    # Any object whose type's string form contains "enum" passes the check.
    @classmethod
    def __instancecheck__(cls, other):
        return "enum" in str(type(other))


class StrEnum(Enum, metaclass=StrEnumMeta):
    """Enum whose members print, compare and hash as their ``value``."""

    def __str__(self):
        return self.value

    def __eq__(self, other: str):
        return self.value == other

    def __repr__(self):
        return self.value

    def __hash__(self):
        return hash(str(self))


def ChoiceEnum(choices: List[str]):
    """return the Enum class used to enforce list of choices"""
    return StrEnum("Choices", {k: k for k in choices})


LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"])
DDP_BACKEND_CHOICES = ChoiceEnum(
    [
        "c10d",  # alias for pytorch_ddp
        "fully_sharded",  # FullyShardedDataParallel from fairscale
        "legacy_ddp",
        "no_c10d",  # alias for legacy_ddp
        "pytorch_ddp",
        "slowmo",
    ]
)
DDP_COMM_HOOK_CHOICES = ChoiceEnum(["none", "fp16"])
DATASET_IMPL_CHOICES = ChoiceEnum(["raw", "lazy", "cached", "mmap", "fasta", "huffman"])
GENERATION_CONSTRAINTS_CHOICES = ChoiceEnum(["ordered", "unordered"])
GENERATION_DECODING_FORMAT_CHOICES = ChoiceEnum(
    ["unigram", "ensemble", "vote", "dp", "bs"]
)
ZERO_SHARDING_CHOICES = ChoiceEnum(["none", "os"])
PIPELINE_CHECKPOINT_CHOICES = ChoiceEnum(["always", "never", "except_last"])
PRINT_ALIGNMENT_CHOICES = ChoiceEnum(["hard", "soft"])

# ---------- /fairseq/dataclass/initialize.py ----------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""

import logging
from hydra.core.config_store import ConfigStore
from fairseq.dataclass.configs import FairseqConfig
from omegaconf import DictConfig, OmegaConf


logger = logging.getLogger(__name__)


def hydra_init(cfg_name="config") -> None:
    """Register FairseqConfig under ``cfg_name``, plus each field's default node, in hydra's ConfigStore."""

    cs = ConfigStore.instance()
    cs.store(name=f"{cfg_name}", node=FairseqConfig)

    for k in FairseqConfig.__dataclass_fields__:
        v = FairseqConfig.__dataclass_fields__[k].default
        try:
            cs.store(name=k, node=v)
        except BaseException:
            # Report which field could not be stored, then propagate.
            logger.error(f"{k} - {v}")
            raise


def add_defaults(cfg: DictConfig) -> None:
    """This function adds default values that are stored in dataclasses that hydra doesn't know about

    For every top-level field of FairseqConfig typed ``Any`` that is set in
    ``cfg``, the matching dataclass is looked up by its ``_name`` (task, model
    or another registry) and merged with the user values; ``cfg`` is
    modified in place (struct mode is switched off first).
    """

    from fairseq.registry import REGISTRIES
    from fairseq.tasks import TASK_DATACLASS_REGISTRY
    from fairseq.models import ARCH_MODEL_NAME_REGISTRY, MODEL_DATACLASS_REGISTRY
    from fairseq.dataclass.utils import merge_with_parent
    from typing import Any

    OmegaConf.set_struct(cfg, False)

    for k, v in FairseqConfig.__dataclass_fields__.items():
        field_cfg = cfg.get(k)
        if field_cfg is not None and v.type == Any:
            dc = None

            if isinstance(field_cfg, str):
                # A bare string is shorthand for {"_name": <string>}.
                # NOTE(review): this used to be followed by
                # `field_cfg.__dict__["_parent"] = field_cfg.__dict__["_parent"]`,
                # a self-assignment with no effect, so it was removed. If the
                # intent was to parent the new node under `cfg` (e.g. for
                # interpolation during merge_with_parent), that needs an
                # explicit change — confirm.
                field_cfg = DictConfig({"_name": field_cfg})

            name = getattr(field_cfg, "_name", None)

            if k == "task":
                dc = TASK_DATACLASS_REGISTRY.get(name)
            elif k == "model":
                # Architecture names map to the model name that owns the dataclass.
                name = ARCH_MODEL_NAME_REGISTRY.get(name, name)
                dc = MODEL_DATACLASS_REGISTRY.get(name)
            elif k in REGISTRIES:
                dc = REGISTRIES[k]["dataclass_registry"].get(name)

            if dc is not None:
                cfg[k] = merge_with_parent(dc, field_cfg)

# ---------- /fairseq/distributed/__init__.py ----------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .distributed_timeout_wrapper import DistributedTimeoutWrapper
from .fully_sharded_data_parallel import (
    fsdp_enable_wrap,
    fsdp_wrap,
    FullyShardedDataParallel,
)
from .legacy_distributed_data_parallel import LegacyDistributedDataParallel
from .module_proxy_wrapper import ModuleProxyWrapper
from .tpu_distributed_data_parallel import TPUDistributedDataParallel


__all__ = [
    "DistributedTimeoutWrapper",
    "fsdp_enable_wrap",
    "fsdp_wrap",
    "FullyShardedDataParallel",
    "LegacyDistributedDataParallel",
    "ModuleProxyWrapper",
    "TPUDistributedDataParallel",
]

# ---------- /fairseq/distributed/module_proxy_wrapper.py ----------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from torch import nn


class ModuleProxyWrapper(nn.Module):
    """
    Proxy for a DistributedDataParallel-style wrapper whose ``.module`` holds
    the real model. Attribute lookups that miss on the proxy fall through to
    the once-wrapped module and then to the twice-wrapped one, and
    :func:`state_dict` / :func:`load_state_dict` go straight to the
    twice-wrapped module.

    Usage::

        module.xyz = "hello world"
        wrapped_module = DistributedDataParallel(module, **ddp_args)
        wrapped_module = ModuleProxyWrapper(wrapped_module)
        assert wrapped_module.xyz == "hello world"
        assert wrapped_module.state_dict().keys() == module.state_dict().keys()

    Args:
        module (nn.Module): wrapper to proxy; must expose a ``module`` attribute
    """

    def __init__(self, module: nn.Module):
        super().__init__()
        assert hasattr(module, "module"), (
            "ModuleProxyWrapper expects input to wrap another module"
        )
        self.module = module

    def __getattr__(self, name):
        """Resolve ``name`` on the proxy, then the wrapper, then the inner module."""
        try:
            # nn.Module's own lookup (parameters, buffers, submodules)
            return super().__getattr__(name)
        except AttributeError:
            pass

        wrapper = self.module
        try:
            return getattr(wrapper, name)
        except AttributeError:
            return getattr(wrapper.module, name)

    def state_dict(self, *args, **kwargs):
        """Return the twice-wrapped module's state dict."""
        inner = self.module.module
        return inner.state_dict(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        """Load into the twice-wrapped module."""
        inner = self.module.module
        return inner.load_state_dict(*args, **kwargs)

    def forward(self, *args, **kwargs):
        wrapper = self.module
        return wrapper(*args, **kwargs)

# ---------- /fairseq/distributed/tpu_distributed_data_parallel.py ----------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn

from fairseq.distributed import utils


class TPUDistributedDataParallel(nn.Module):
    """Thin data-parallel wrapper whose gradient sync is an XLA all-reduce.

    ``forward`` simply delegates to the wrapped module; gradients are
    averaged across ``world_size`` replicas only when :meth:`all_reduce_grads`
    is called explicitly.

    NOTE(review): ``process_group`` is presumably a pair whose element ``[1]``
    is the replica-group list for ``xm.all_reduce`` — confirm against callers.
    """

    def __init__(self, module, process_group):
        super().__init__()
        self.module = module
        self.process_group = process_group
        self.world_size = utils.get_world_size(self.process_group)

    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)

    def all_reduce_grads(self):
        """Average the gradients of all trainable parameters across replicas.

        Trainable parameters that have no gradient get a zero tensor first.

        Raises:
            RuntimeError: if a gradient itself requires grad.
        """
        grads = []
        trainable = (param for param in self.parameters() if param.requires_grad)
        for param in trainable:
            if param.grad is None:
                param.grad = torch.zeros_like(param)
            if param.grad.requires_grad:
                raise RuntimeError(
                    "TPUDistributedDataParallel only works with gradients that don't "
                    "require grad"
                )
            grads.append(param.grad)

        # Imported lazily so the module can be loaded without torch_xla.
        import torch_xla.core.xla_model as xm

        xm.all_reduce(
            "sum",
            grads,
            scale=1.0 / self.world_size,
            groups=self.process_group[1],
        )

# ---------- /fairseq/incremental_decoding_utils.py ----------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import uuid
from typing import Dict, Optional

from torch import Tensor


class FairseqIncrementalState(object):
    """Mixin that namespaces an object's entries in a shared incremental-state dict.

    Each instance gets a random UUID at construction; its keys are stored as
    ``"<uuid>.<key>"`` so that several modules can share one dictionary.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.init_incremental_state()

    def init_incremental_state(self):
        self._incremental_state_id = str(uuid.uuid4())

    def _get_full_incremental_state_key(self, key: str) -> str:
        return "{}.{}".format(self._incremental_state_id, key)

    def get_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
    ) -> Optional[Dict[str, Optional[Tensor]]]:
        """Helper for getting incremental state for an nn.Module.

        Returns None when there is no state dict or no entry for ``key``.
        """
        full_key = self._get_full_incremental_state_key(key)
        if incremental_state is not None and full_key in incremental_state:
            return incremental_state[full_key]
        return None

    def set_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
        value: Dict[str, Optional[Tensor]],
    ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
        """Helper for setting incremental state for an nn.Module.

        Stores ``value`` in place and returns the same dict (or None, unchanged).
        """
        if incremental_state is None:
            return incremental_state
        incremental_state[self._get_full_incremental_state_key(key)] = value
        return incremental_state


def with_incremental_state(cls):
    """Class decorator that makes :class:`FairseqIncrementalState` the first base of ``cls``."""
    remaining = [base for base in cls.__bases__ if base != FairseqIncrementalState]
    cls.__bases__ = (FairseqIncrementalState, *remaining)
    return cls

# ---------- /fairseq/logging/__init__.py ----------
# https://raw.githubusercontent.com/ictnlp/FA-DAT/3599a3023c89ef16d2b33bf7c3a2b6f85aedaa4b/fairseq/logging/__init__.py
# ---------- /fairseq/model_parallel/__init__.py ----------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# Importing these subpackages runs their package-level code as a side effect
# (e.g. the auto-import loop in criterions/__init__.py below).
from . import criterions, models, modules  # noqa

# ---------- /fairseq/model_parallel/criterions/__init__.py ----------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import importlib
import os


# automatically import any Python files in the criterions/ directory
# Files are visited in sorted order so the import order is deterministic;
# names starting with "_" (such as __init__.py) are skipped. Note that the
# loop variables `file` and `module` remain as module-level globals.
for file in sorted(os.listdir(os.path.dirname(__file__))):
    if file.endswith(".py") and not file.startswith("_"):
        module = file[: file.find(".py")]
        importlib.import_module("fairseq.model_parallel.criterions." + module)

# ---------- /fairseq/model_parallel/models/__init__.py ----------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os


# automatically import any Python files in the models/ directory
# (both plain modules and sub-packages; names starting with "_" or "." are
# skipped). The loop variables stay as module-level globals, as before.
models_dir = os.path.dirname(__file__)
# os.listdir order is arbitrary; sorting makes the import (and therefore
# model registration) order deterministic, matching criterions/__init__.py.
for file in sorted(os.listdir(models_dir)):
    path = os.path.join(models_dir, file)
    if (
        not file.startswith("_")
        and not file.startswith(".")
        and (file.endswith(".py") or os.path.isdir(path))
    ):
        model_name = file[: file.find(".py")] if file.endswith(".py") else file
        module = importlib.import_module("fairseq.model_parallel.models." + model_name)

# ---------- /fairseq/model_parallel/models/pipeline_parallel_transformer/__init__.py ----------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .model import *  # noqa

# ---------- /fairseq/model_parallel/models/roberta/__init__.py ----------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .model import *  # noqa

# ---------- /fairseq/model_parallel/modules/__init__.py ----------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""

from .multihead_attention import ModelParallelMultiheadAttention
from .transformer_layer import (
    ModelParallelTransformerEncoderLayer,
    ModelParallelTransformerDecoderLayer,
)

__all__ = [
    "ModelParallelMultiheadAttention",
    "ModelParallelTransformerEncoderLayer",
    "ModelParallelTransformerDecoderLayer",
]

# ---------- /fairseq/models/bart/__init__.py ----------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .hub_interface import *  # noqa
from .model import *  # noqa

# ---------- /fairseq/models/composite_encoder.py ----------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .fairseq_encoder import FairseqEncoder


class CompositeEncoder(FairseqEncoder):
    """
    Runs several :class:`FairseqEncoder` objects side by side.

    Every encoder receives the same inputs and the outputs are returned as a
    dictionary keyed like ``encoders``. The dictionary of the first encoder
    (in ``encoders`` iteration order) is used for initialization.

    Args:
        encoders (dict): a dictionary of :class:`FairseqEncoder` objects.
    """

    def __init__(self, encoders):
        first_encoder = next(iter(encoders.values()))
        super().__init__(first_encoder.dictionary)
        self.encoders = encoders
        # Register each encoder as a named submodule so its parameters are tracked.
        for name, encoder in self.encoders.items():
            self.add_module(name, encoder)

    def forward(self, src_tokens, src_lengths):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (LongTensor): lengths of each source sentence of shape
                `(batch)`

        Returns:
            dict:
                the outputs from each Encoder
        """
        return {
            name: encoder(src_tokens, src_lengths)
            for name, encoder in self.encoders.items()
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder encoder output according to new_order (updates ``encoder_out`` in place)."""
        for name, encoder in self.encoders.items():
            encoder_out[name] = encoder.reorder_encoder_out(encoder_out[name], new_order)
        return encoder_out

    def max_positions(self):
        """Return the smallest ``max_positions()`` over all encoders."""
        return min(encoder.max_positions() for encoder in self.encoders.values())

    def upgrade_state_dict(self, state_dict):
        for encoder in self.encoders.values():
            encoder.upgrade_state_dict(state_dict)
        return state_dict

# ---------- /fairseq/models/ema/__init__.py ----------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os

from .ema import EMA


def build_ema(model, cfg, device):
    """Construct an :class:`EMA` for ``model`` from ``cfg`` on ``device``."""
    return EMA(model, cfg, device)


# automatically import any Python files in the models/ema/ directory
for file in sorted(os.listdir(os.path.dirname(__file__))):
    if file.endswith(".py") and not file.startswith("_"):
        file_name = file[: file.find(".py")]
        importlib.import_module("fairseq.models.ema." + file_name)

# ---------- /fairseq/models/hubert/__init__.py ----------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .hubert import *  # noqa
from .hubert_asr import *  # noqa

# ---------- /fairseq/models/huggingface/__init__.py ----------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import importlib
import os


# automatically import any Python files in the models/huggingface/ directory
# (plain modules and sub-packages; names starting with "_" or "." skipped)
models_dir = os.path.dirname(__file__)
# os.listdir order is arbitrary; sorting keeps import/registration order
# deterministic, like the models/ema/ loop above.
for file in sorted(os.listdir(models_dir)):
    path = os.path.join(models_dir, file)
    if (
        not file.startswith("_")
        and not file.startswith(".")
        and (file.endswith(".py") or os.path.isdir(path))
    ):
        model_name = file[: file.find(".py")] if file.endswith(".py") else file
        module = importlib.import_module("fairseq.models.huggingface." + model_name)

# ---------- /fairseq/models/nat/__init__.py ----------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""

from .fairseq_nat_model import *
from .nonautoregressive_transformer import *
from .nat_crf_transformer import *
from .iterative_nonautoregressive_transformer import *
from .cmlm_transformer import *
from .levenshtein_transformer import *
from .insertion_transformer import *

# ---------- /fairseq/models/roberta/__init__.py ----------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .hub_interface import *  # noqa
from .model import *  # noqa
from .enc_dec import *  # noqa
from .model_camembert import *  # noqa
from .model_gottbert import *  # noqa
from .model_xlmr import *  # noqa

# ---------- /fairseq/models/roberta/model_camembert.py ----------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
CamemBERT: a Tasty French Language Model
"""

from fairseq.models import register_model

from .hub_interface import RobertaHubInterface
from .model import RobertaModel


@register_model("camembert")
class CamembertModel(RobertaModel):
    @classmethod
    def hub_models(cls):
        """Map pretrained CamemBERT model names to their download archives.

        The archives are fetched over HTTPS (previously plain HTTP): the
        downloaded checkpoints are presumably deserialized by
        ``hub_utils.from_pretrained``, so an unauthenticated transport would
        allow tampering in transit. A changed URL may cause a one-time
        re-download if the local cache is keyed by URL.
        """
        return {
            "camembert": "https://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz",
            "camembert.v0": "https://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz",
            "camembert-base": "https://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz",
            "camembert-large": "https://dl.fbaipublicfiles.com/fairseq/models/camembert-large.tar.gz",
            "camembert-base-ccnet": "https://dl.fbaipublicfiles.com/fairseq/models/camembert-base-ccnet.tar.gz",
            "camembert-base-ccnet-4gb": "https://dl.fbaipublicfiles.com/fairseq/models/camembert-base-ccnet-4gb.tar.gz",
            "camembert-base-wikipedia-4gb": "https://dl.fbaipublicfiles.com/fairseq/models/camembert-base-wikipedia-4gb.tar.gz",
            "camembert-base-oscar-4gb": "https://dl.fbaipublicfiles.com/fairseq/models/camembert-base-oscar-4gb.tar.gz",
        }

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path,
        checkpoint_file="model.pt",
        data_name_or_path=".",
        bpe="sentencepiece",
        **kwargs
    ):
        """Load a pretrained CamemBERT and wrap it in a :class:`RobertaHubInterface`.

        ``model_name_or_path`` may be a key of :meth:`hub_models` or a local
        path; extra ``kwargs`` are forwarded to ``hub_utils.from_pretrained``.
        """
        from fairseq import hub_utils

        x = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            bpe=bpe,
            load_checkpoint_heads=True,
            **kwargs,
        )
        return RobertaHubInterface(x["args"], x["task"], x["models"][0])

# ---------- /fairseq/models/roberta/model_gottbert.py ----------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
GottBERT: a pure German Language Model
"""

from fairseq.models import register_model

from .hub_interface import RobertaHubInterface
from .model import RobertaModel


@register_model("gottbert")
class GottbertModel(RobertaModel):
    @classmethod
    def hub_models(cls):
        """Map the pretrained GottBERT model name to its download archive."""
        base_archive = "https://dl.gottbert.de/fairseq/models/gottbert-base.tar.gz"
        return {"gottbert-base": base_archive}

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path,
        checkpoint_file="model.pt",
        data_name_or_path=".",
        bpe="hf_byte_bpe",
        bpe_vocab="vocab.json",
        bpe_merges="merges.txt",
        bpe_add_prefix_space=False,
        **kwargs
    ):
        """Load a pretrained GottBERT and wrap it in a :class:`RobertaHubInterface`.

        The byte-level BPE settings (vocab, merges, prefix-space handling) are
        passed through to ``hub_utils.from_pretrained`` together with any
        extra ``kwargs``.
        """
        from fairseq import hub_utils

        bpe_options = dict(
            bpe=bpe,
            load_checkpoint_heads=True,
            bpe_vocab=bpe_vocab,
            bpe_merges=bpe_merges,
            bpe_add_prefix_space=bpe_add_prefix_space,
        )
        loaded = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            **bpe_options,
            **kwargs,
        )
        first_model = loaded["models"][0]
        return RobertaHubInterface(loaded["args"], loaded["task"], first_model)

# ---------- /fairseq/models/roberta/model_xlmr.py ----------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Unsupervised Cross-lingual Representation Learning at Scale
"""

from fairseq.models import register_model

from .hub_interface import RobertaHubInterface
from .model import RobertaModel


@register_model("xlmr")
class XLMRModel(RobertaModel):
    @classmethod
    def hub_models(cls):
        """Map pretrained XLM-R model names to their download archives.

        The archives are fetched over HTTPS (previously plain HTTP) so the
        downloaded checkpoints cannot be silently altered in transit. A changed
        URL may cause a one-time re-download if the local cache is keyed by URL.
        """
        return {
            "xlmr.base": "https://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.tar.gz",
            "xlmr.large": "https://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.tar.gz",
            "xlmr.xl": "https://dl.fbaipublicfiles.com/fairseq/models/xlmr/xlmr.xl.tar.gz",
            "xlmr.xxl": "https://dl.fbaipublicfiles.com/fairseq/models/xlmr/xlmr.xxl.tar.gz",
        }

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path,
        checkpoint_file="model.pt",
        data_name_or_path=".",
        bpe="sentencepiece",
        **kwargs
    ):
        """Load a pretrained XLM-R and wrap it in a :class:`RobertaHubInterface`.

        ``model_name_or_path`` may be a key of :meth:`hub_models` or a local
        path; extra ``kwargs`` are forwarded to ``hub_utils.from_pretrained``.
        """
        from fairseq import hub_utils

        x = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            bpe=bpe,
            load_checkpoint_heads=True,
            **kwargs,
        )
        return RobertaHubInterface(x["args"], x["task"], x["models"][0])

# ---------- /fairseq/models/speech_to_speech/__init__.py ----------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .modules import *  # noqa
from .s2s_transformer import *  # noqa

# ---------- /fairseq/models/speech_to_speech/modules.py ----------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn

from fairseq.models import FairseqEncoder
from fairseq.models.transformer import Linear


class CTCDecoder(FairseqEncoder):
    """Linear head mapping features to ``len(dictionary)`` output scores.

    Args:
        dictionary: output dictionary; its length sets the projection width.
        in_dim: size of the last dimension of the incoming features.
    """

    def __init__(self, dictionary, in_dim):
        super().__init__(dictionary)
        self.proj = nn.Linear(in_dim, len(dictionary))

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        """Project ``src_tokens`` (features, despite the name) to scores.

        ``src_lengths`` and extra keyword arguments are accepted for interface
        compatibility and are ignored.
        """
        scores = self.proj(src_tokens)
        return {"encoder_out": scores}


class StackedEmbedding(nn.Embedding):
    """Embedding module that supports stacked units -> single embedding.

    With ``num_stacked == 1`` this is a plain embedding. Otherwise every index
    ``>= offset`` is split into ``num_stacked`` base-``vocab_size`` digits
    (most significant first). Each digit is embedded, and the concatenated
    digit embeddings are projected back to ``embed_dim``. Indices below
    ``offset`` (special symbols) are embedded unchanged at every stacked
    position. For ``num_stacked > 1`` the input must be 2-D (batch, time), as
    required by the final ``view``.
    """

    def __init__(self, num_embeddings, embed_dim, padding_idx, num_stacked=1):
        super().__init__(num_embeddings, embed_dim, padding_idx)
        # follow transformer.Embedding: N(0, embed_dim**-0.5), zeroed padding row
        nn.init.normal_(self.weight, mean=0, std=embed_dim**-0.5)
        nn.init.constant_(self.weight[padding_idx], 0)

        # skip <s>, <pad>, </s>, <unk>, specific to fairseq dictionary
        self.offset = 4
        self.vocab_size = num_embeddings - self.offset
        self.num_stacked = num_stacked

        if self.num_stacked > 1:
            self.project_in_dim = Linear(embed_dim * num_stacked, embed_dim, bias=False)

    def forward(self, input):
        """Embed ``input``; stacked units are decomposed into digits first."""
        if self.num_stacked == 1:
            return super().forward(input)

        is_unit = input >= self.offset
        digits = []
        consumed = input.new_zeros(input.shape)
        for level in range(1, self.num_stacked + 1):
            modulus = pow(self.vocab_size, level)
            digit = torch.remainder(input - self.offset - consumed, modulus)
            consumed += digit
            digit = torch.floor_divide(digit, modulus // self.vocab_size)
            # special symbols (below offset) pass through untouched
            digits.append((digit + self.offset) * is_unit + input * ~is_unit)

        # digits were produced least significant first; stack most significant first
        stacked = torch.stack(list(reversed(digits)), dim=2)
        flat = super().forward(stacked).view(input.size(0), input.size(1), -1)
        return self.project_in_dim(flat)


# --------------------------------------------------------------------------------
# /fairseq/models/speech_to_text/__init__.py
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .berard import *  # noqa
from .convtransformer import *  # noqa
from .s2t_transformer import *  # noqa
from .xm_transformer import *  # noqa
from .s2t_conformer import *  # noqa

# --------------------------------------------------------------------------------
# /fairseq/models/text_to_speech/__init__.py
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .tacotron2 import *  # noqa
from .tts_transformer import *  # noqa
from .fastspeech2 import *  # noqa

# --------------------------------------------------------------------------------
# /fairseq/models/transformer/__init__.py
# --------------------------------------------------------------------------------
# Copyright (c) Facebook Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
5 | """isort:skip_file""" 6 | 7 | from .transformer_config import ( 8 | TransformerConfig, 9 | DEFAULT_MAX_SOURCE_POSITIONS, 10 | DEFAULT_MAX_TARGET_POSITIONS, 11 | DEFAULT_MIN_PARAMS_TO_WRAP, 12 | ) 13 | from .transformer_decoder import TransformerDecoder, TransformerDecoderBase, Linear 14 | from .transformer_encoder import TransformerEncoder, TransformerEncoderBase 15 | from .transformer_legacy import ( 16 | TransformerModel, 17 | base_architecture, 18 | tiny_architecture, 19 | transformer_iwslt_de_en, 20 | transformer_wmt_en_de, 21 | transformer_vaswani_wmt_en_de_big, 22 | transformer_vaswani_wmt_en_fr_big, 23 | transformer_wmt_en_de_big, 24 | transformer_wmt_en_de_big_t2t, 25 | ) 26 | from .transformer_base import TransformerModelBase, Embedding 27 | 28 | 29 | __all__ = [ 30 | "TransformerModelBase", 31 | "TransformerConfig", 32 | "TransformerDecoder", 33 | "TransformerDecoderBase", 34 | "TransformerEncoder", 35 | "TransformerEncoderBase", 36 | "TransformerModel", 37 | "Embedding", 38 | "Linear", 39 | "base_architecture", 40 | "tiny_architecture", 41 | "transformer_iwslt_de_en", 42 | "transformer_wmt_en_de", 43 | "transformer_vaswani_wmt_en_de_big", 44 | "transformer_vaswani_wmt_en_fr_big", 45 | "transformer_wmt_en_de_big", 46 | "transformer_wmt_en_de_big_t2t", 47 | "DEFAULT_MAX_SOURCE_POSITIONS", 48 | "DEFAULT_MAX_TARGET_POSITIONS", 49 | "DEFAULT_MIN_PARAMS_TO_WRAP", 50 | ] 51 | -------------------------------------------------------------------------------- /fairseq/models/wav2vec/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
from .wav2vec import *  # noqa
from .wav2vec2 import *  # noqa
from .wav2vec2_asr import *  # noqa

# --------------------------------------------------------------------------------
# /fairseq/models/wav2vec/utils.py
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import torch.nn.functional as F


def pad_to_multiple(x, multiple, dim=-1, value=0):
    """Right-pad ``x`` along ``dim`` so that its size is a multiple of ``multiple``.

    Args:
        x: tensor to pad, or ``None``.
        multiple: positive integer the padded size must be divisible by.
        dim: dimension to pad; both negative and non-negative indices work.
        value: fill value for the appended positions.

    Returns:
        ``(padded, remainder)``, where ``remainder`` is the number of positions
        appended at the end of ``dim``. Returns ``(x, 0)`` when no padding is
        needed and ``(None, 0)`` when ``x`` is ``None``.
    """
    # Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41
    if x is None:
        return None, 0
    tsz = x.size(dim)
    m = tsz / multiple
    remainder = math.ceil(m) * multiple - tsz
    if m.is_integer():
        return x, 0
    # F.pad takes (left, right) pairs starting from the LAST dimension, so one
    # (0, 0) pair is needed for every dimension after ``dim``. Normalise ``dim``
    # to a negative index first: with a non-negative ``dim`` the count
    # ``-1 - dim`` would be negative and the last dimension would be padded.
    if dim >= 0:
        dim -= x.dim()
    pad_offset = (0,) * (-1 - dim) * 2

    return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder


# --------------------------------------------------------------------------------
# /fairseq/modules/beamable_mm.py
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn


class BeamableMM(nn.Module):
    """This module provides an optimized MM for beam decoding with attention.

    It leverages the fact that the source-side of the input is replicated beam
    times and the target-side of the input is of width one. This layer speeds up
    inference by replacing the inputs {(bsz x 1 x nhu), (bsz x nhu x sz2)}
    with smaller inputs {(bsz/beam x beam x nhu), (bsz/beam x nhu x sz2)}.
    """

    def __init__(self, beam_size=None):
        super(BeamableMM, self).__init__()
        self.beam_size = beam_size

    def forward(self, input1, input2):
        """Compute ``input1.bmm(input2)``, with a fast path for beam decoding.

        Args:
            input1: (bsz x 1 x nhu) for a single decoding step, otherwise any
                batch accepted by ``Tensor.bmm``.
            input2: (bsz x nhu x sz2). In the fast path only the first row of
                every group of ``beam`` consecutive rows is used, so the rows
                are assumed to be replicated per beam and ``bsz`` is assumed to
                be a multiple of ``beam_size``.

        Returns:
            (bsz x 1 x sz2) in the fast path, otherwise ``input1.bmm(input2)``.
        """
        if (
            not self.training  # test mode
            and self.beam_size is not None  # beam size is set
            and input1.dim() == 3  # only support batched input
            and input1.size(1) == 1  # single time step update
        ):
            bsz, beam = input1.size(0), self.beam_size

            # bsz x 1 x nhu --> bsz/beam x beam x nhu
            input1 = input1[:, 0, :].unfold(0, beam, beam).transpose(2, 1)

            # bsz x nhu x sz2 --> bsz/beam x nhu x sz2 (first copy per beam group)
            input2 = input2.unfold(0, beam, beam)[:, :, :, 0]

            # use non batched operation if bsz = beam
            if input1.size(0) == 1:
                output = torch.mm(input1[0, :, :], input2[0, :, :])
            else:
                output = input1.bmm(input2)
            return output.view(bsz, 1, -1)
        else:
            return input1.bmm(input2)

    def set_beam_size(self, beam_size):
        """Set the beam size used by the decoding fast path."""
        self.beam_size = beam_size


# --------------------------------------------------------------------------------
# /fairseq/modules/conv_tbc.py
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn
from torch.nn.modules.utils import _single
from torch import Tensor


class ConvTBC(torch.nn.Module):
    """1D convolution over an input of shape (time x batch x channel)

    The implementation uses gemm to perform the convolution. This implementation
    is faster than cuDNN for small kernel sizes.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution width (an int or a 1-tuple).
        padding: padding passed to ``torch.conv_tbc`` (an int or a 1-tuple).
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding=0):
        super(ConvTBC, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _single(kernel_size)
        self.padding = _single(padding)

        # weight layout: (kernel_size, in_channels, out_channels)
        self.weight = torch.nn.Parameter(
            torch.Tensor(self.kernel_size[0], in_channels, out_channels)
        )
        self.bias = torch.nn.Parameter(torch.Tensor(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-normal weight, zero bias."""
        nn.init.xavier_normal_(self.weight)
        nn.init.zeros_(self.bias)

    def conv_tbc(self, input: Tensor):
        # torch.conv_tbc is called on a contiguous copy/view of the input
        return torch.conv_tbc(
            input.contiguous(), self.weight, self.bias, self.padding[0]
        )

    def forward(self, input: Tensor):
        return self.conv_tbc(input)

    def __repr__(self):
        s = (
            "{name}({in_channels}, {out_channels}, kernel_size={kernel_size}"
            ", padding={padding}"
        )
        if self.bias is None:
            s += ", bias=False"
        s += ")"
        return s.format(name=self.__class__.__name__, **self.__dict__)


# --------------------------------------------------------------------------------
# /fairseq/modules/cross_entropy.py
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging

import torch
import torch.nn.functional as F

logger = logging.getLogger(__name__)


def _cross_entropy_pytorch(logits, target, ignore_index=-100, reduction="mean"):
    """Reference cross entropy: fp32 log-softmax over the last dim, then NLL.

    Args:
        logits: unnormalised scores; the class dimension is the last one.
        target: class indices.
        ignore_index: target value that contributes no loss. Defaults to -100,
            matching the public ``cross_entropy`` wrappers and ``F.nll_loss``.
            ``None`` is treated as -100, because ``F.nll_loss`` requires an int.
        reduction: ``"mean"``, ``"sum"`` or ``"none"``, as for ``F.nll_loss``.
    """
    if ignore_index is None:
        ignore_index = -100
    lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    return F.nll_loss(
        lprobs,
        target,
        ignore_index=ignore_index,
        reduction=reduction,
    )


try:
    import xentropy_cuda
    from apex.contrib import xentropy

    def cross_entropy(logits, target, ignore_index=-100, reduction="mean"):
        """Cross entropy that uses apex's fused kernel for non-CPU tensors.

        CPU inputs fall back to :func:`_cross_entropy_pytorch`. Half-precision
        logits are passed to the fused kernel with ``half_to_float`` set.
        """
        if logits.device == torch.device("cpu"):
            return _cross_entropy_pytorch(logits, target, ignore_index, reduction)
        else:
            if not getattr(cross_entropy, "_has_logged_once", False):
                logger.info("using fused cross entropy")
                cross_entropy._has_logged_once = True

            half_to_float = logits.dtype == torch.half
            losses = xentropy.SoftmaxCrossEntropyLoss.apply(
                logits,
                target,
                0.0,
                ignore_index,
                half_to_float,
            )
            if reduction == "sum":
                return losses.sum()
            elif reduction == "mean":
                if ignore_index >= 0:
                    # average over the targets that are not ignored
                    return losses.sum() / target.ne(ignore_index).sum()
                else:
                    return losses.mean()
            elif reduction == "none":
                return losses
            else:
                raise NotImplementedError

except ImportError:

    def cross_entropy(logits, target, ignore_index=-100, reduction="mean"):
        """Cross entropy via :func:`_cross_entropy_pytorch` (apex not available)."""
        return _cross_entropy_pytorch(logits, target, ignore_index, reduction)


# --------------------------------------------------------------------------------
# /fairseq/modules/dynamicconv_layer/__init__.py
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
5 | 6 | from .dynamicconv_layer import DynamicconvLayer # noqa 7 | -------------------------------------------------------------------------------- /fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | /* Python bindings for the dynamicconv CUDA implementation: each entry point validates its tensors, then delegates to the dynamicconv_cuda_* function declared below. NOTE(review): the #include targets and the std::vector template arguments are missing in this copy (angle-bracket text appears stripped); restore them from upstream before building. */ 8 | #include 9 | #include 10 | 11 | std::vector 12 | dynamicconv_cuda_forward(at::Tensor input, at::Tensor filters, int padding_l); 13 | 14 | std::vector dynamicconv_cuda_backward( 15 | at::Tensor gradOutput, 16 | int padding_l, 17 | at::Tensor input, 18 | at::Tensor filters); 19 | /* Input checks: CHECK_INPUT requires a CUDA tensor that is contiguous. NOTE(review): AT_ASSERTM and Tensor::type() appear deprecated in recent PyTorch releases; TORCH_CHECK and x.is_cuda() are the usual replacements. */ 20 | #define CHECK_CUDA(x) \ 21 | AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 22 | #define CHECK_CONTIGUOUS(x) \ 23 | AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 24 | #define CHECK_INPUT(x) \ 25 | CHECK_CUDA(x); \ 26 | CHECK_CONTIGUOUS(x) 27 | /* Entry points: check every tensor argument with CHECK_INPUT, then call the CUDA implementation with the same arguments. */ 28 | std::vector 29 | dynamicconv_forward(at::Tensor input, at::Tensor filters, int padding_l) { 30 | CHECK_INPUT(input); 31 | CHECK_INPUT(filters); 32 | 33 | return dynamicconv_cuda_forward(input, filters, padding_l); 34 | } 35 | 36 | std::vector dynamicconv_backward( 37 | at::Tensor gradOutput, 38 | int padding_l, 39 | at::Tensor input, 40 | at::Tensor filters) { 41 | CHECK_INPUT(gradOutput); 42 | CHECK_INPUT(input); 43 | CHECK_INPUT(filters); 44 | 45 | return dynamicconv_cuda_backward(gradOutput, padding_l, input, filters); 46 | } 47 | /* Extension module exposing the entry points to Python as forward and backward. */ 48 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 49 | m.def("forward", &dynamicconv_forward, "dynamicconv forward (CUDA)"); 50 | m.def("backward", &dynamicconv_backward, "dynamicconv backward (CUDA)"); 51 | } 52 | -------------------------------------------------------------------------------- /fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cuh:
-------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | /* Shared declarations for the dynamicconv CUDA kernels. NOTE(review): the #include targets below and the template parameter lists after each template keyword are missing in this copy (angle-bracket text appears stripped); restore them from upstream before building. */ 8 | #include 9 | #include 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | #include 23 | #include 24 | #include 25 | /* SHFL_MASK is presumably the full-warp lane mask passed to warp-shuffle intrinsics in the kernel source, which is not shown here. */ 26 | #define SHFL_MASK 0xffffffff 27 | /* Kernel declarations only; the definitions presumably live in dynamicconv_cuda_kernel.cu, which setup.py lists as a source. */ 28 | template 29 | __global__ void dynamicconv_forward_kernel( 30 | const scalar_t* input, 31 | const scalar_t* weight, 32 | int minibatch, 33 | int sequenceLength, 34 | int numFeatures, 35 | int numFiltersInBlock, 36 | int numHeads, 37 | scalar_t* output); 38 | 39 | template 40 | __global__ void dynamicconv_backward_kernel( 41 | const scalar_t* gradOutput, // B * C * T 42 | const scalar_t* input, // B * C * T 43 | const scalar_t* weight, 44 | int minibatch, 45 | int sequenceLength, 46 | int numFeatures, 47 | int numFiltersInBlock, 48 | int numHeads, 49 | scalar_t* gradWeight, 50 | scalar_t* gradInput); // B * H * k * T 51 | -------------------------------------------------------------------------------- /fairseq/modules/dynamicconv_layer/dynamiconv_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | /* CPU counterparts of the CUDA bindings. NOTE(review): dynamicconv_cpu_forward and dynamicconv_cpu_backward are only declared here, with no definition visible, and setup.py builds only the CUDA sources, so this file appears unused. The file name is also spelled dynamiconv_ (missing c). */ 4 | std::vector 5 | dynamicconv_cpu_forward(float* input, float* filters, int padding_l); 6 | 7 | std::vector dynamicconv_cpu_backward( 8 | float* gradOutput, 9 | int padding_l, 10 | float* input, 11 | float* filters); 12 | 13 | std::vector 14 | dynamicconv_forward(float* input, float* filters, int padding_l) { 15 | return dynamicconv_cpu_forward(input, filters, padding_l); 16 | } 17 | 18 | std::vector dynamicconv_backward( 19 | float* gradOutput, 20 | int padding_l, 21 | float* input, 22 | float* filters) { 23 | return dynamicconv_cpu_backward(gradOutput, padding_l,
input, filters); 24 | } 25 | /* Extension module exposing the CPU entry points to Python as forward and backward. */ 26 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 27 | m.def("forward", &dynamicconv_forward, "dynamicconv forward (CPU)"); 28 | m.def("backward", &dynamicconv_backward, "dynamicconv backward (CPU)"); 29 | } 30 | -------------------------------------------------------------------------------- /fairseq/modules/dynamicconv_layer/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from setuptools import setup 8 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 9 | 10 | 11 | setup( 12 | name="dynamicconv_layer", 13 | ext_modules=[ 14 | CUDAExtension( 15 | name="dynamicconv_cuda", 16 | sources=[ 17 | "dynamicconv_cuda.cpp", 18 | "dynamicconv_cuda_kernel.cu", 19 | ], 20 | ), 21 | ], 22 | cmdclass={"build_ext": BuildExtension}, 23 | ) 24 | -------------------------------------------------------------------------------- /fairseq/modules/fairseq_dropout.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree.
import logging
from typing import List, Optional

import torch.nn as nn
import torch.nn.functional as F


logger = logging.getLogger(__name__)


class FairseqDropout(nn.Module):
    """Dropout that can optionally remain active at inference time.

    Args:
        p: dropout probability; with ``p <= 0`` the module is a no-op.
        module_name: name matched against ``retain_dropout_modules`` in
            :meth:`make_generation_fast_`.
    """

    def __init__(self, p, module_name=None):
        super().__init__()
        self.p = p
        self.module_name = module_name
        # flipped on by make_generation_fast_ to keep dropout during inference
        self.apply_during_inference = False

    def forward(self, x, inplace: bool = False):
        """Apply dropout when training or when ``apply_during_inference`` is set."""
        enabled = self.training or self.apply_during_inference
        if not (self.p > 0 and enabled):
            return x
        return F.dropout(x, p=self.p, training=True, inplace=inplace)

    def make_generation_fast_(
        self,
        name: str,
        retain_dropout: bool = False,
        retain_dropout_modules: Optional[List[str]] = None,
        **kwargs
    ):
        """Decide whether dropout stays enabled for generation.

        When ``retain_dropout`` is false, nothing changes. Otherwise dropout is
        kept at inference if ``retain_dropout_modules`` is ``None`` (all
        modules) or contains this module's ``module_name``. If a module list is
        given but ``module_name`` is unset, a warning is logged instead.

        Args:
            name: module path, used only in log messages.
            retain_dropout: whether dropout should be kept at all.
            retain_dropout_modules: optional allow-list of module names.
        """
        if not retain_dropout:
            return
        if retain_dropout_modules is not None and self.module_name is None:
            logger.warning(
                "Cannot enable dropout during inference for module {} "
                "because module_name was not set".format(name)
            )
        elif retain_dropout_modules is None or self.module_name in retain_dropout_modules:
            logger.info(
                "Enabling dropout during inference for module: {}".format(name)
            )
            self.apply_during_inference = True
        else:
            logger.info("Disabling dropout for module: {}".format(name))


# --------------------------------------------------------------------------------
# /fairseq/modules/fp32_batch_norm.py
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
batch norm done in fp32 (for fp16 training)
"""
import torch
import torch.nn as nn


class Fp32BatchNorm(nn.Module):
    """Batch norm whose statistics, parameters and arithmetic are kept in fp32.

    Args:
        sync: use ``nn.SyncBatchNorm`` when the global world size is greater
            than 1; otherwise ``nn.BatchNorm1d`` is used. ``sync`` is the FIRST
            positional parameter, so ``num_features`` etc. must follow it or be
            passed as keywords.
        *args, **kwargs: forwarded to the wrapped batch-norm constructor.
    """

    def __init__(self, sync=False, *args, **kwargs):
        super().__init__()

        if sync:
            from fairseq.distributed import utils

            if utils.get_global_world_size() == 1:
                sync = False

        if sync:
            self.bn = nn.SyncBatchNorm(*args, **kwargs)
        else:
            self.bn = nn.BatchNorm1d(*args, **kwargs)

        self.sync = sync

    def forward(self, input):
        """Normalise ``input`` in fp32 and cast the result back to its dtype."""
        bn = self.bn
        # Detect a lower-precision module (e.g. after .half()). running_mean is
        # None when track_running_stats=False, so fall back to the affine weight.
        probe = bn.running_mean if bn.running_mean is not None else bn.weight
        if probe is not None and probe.dtype != torch.float:
            if self.sync:
                if bn.running_mean is not None:
                    bn.running_mean = bn.running_mean.float()
                    bn.running_var = bn.running_var.float()
                if bn.affine:
                    try:
                        bn.weight = bn.weight.float()
                        bn.bias = bn.bias.float()
                    except Exception:
                        # Assigning a plain tensor to a registered Parameter
                        # raises (TypeError); cast the whole module instead.
                        # Unlike a bare ``except``, this does not swallow
                        # KeyboardInterrupt/SystemExit.
                        bn.float()
            else:
                bn.float()

        output = bn(input.float())
        return output.type_as(input)


# --------------------------------------------------------------------------------
# /fairseq/modules/fp32_group_norm.py
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Group norm done in fp32 (for fp16 training)
"""

import torch.nn as nn
import torch.nn.functional as F


class Fp32GroupNorm(nn.GroupNorm):
    """``nn.GroupNorm`` computed in fp32, with the result cast back to the input dtype."""

    def forward(self, input):
        output = F.group_norm(
            input.float(),
            self.num_groups,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)


# --------------------------------------------------------------------------------
# /fairseq/modules/fp32_instance_norm.py
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Instance norm done in fp32 (for fp16 training)
"""

import torch.nn as nn
import torch.nn.functional as F


class Fp32InstanceNorm(nn.InstanceNorm1d):
    """``nn.InstanceNorm1d`` computed in fp32, cast back to the input dtype.

    Accepts an extra keyword ``transpose_last``. When it is truthy, dims 1 and 2
    of the input are swapped before normalising and swapped back afterwards,
    so callers can pass the channel dimension last.
    """

    def __init__(self, *args, **kwargs):
        # consume our own option; nn.InstanceNorm1d does not accept it
        self.transpose_last = kwargs.pop("transpose_last", False)
        super().__init__(*args, **kwargs)

    def forward(self, input):
        if self.transpose_last:
            input = input.transpose(1, 2)
        output = F.instance_norm(
            input.float(),
            running_mean=self.running_mean,
            running_var=self.running_var,
            weight=self.weight.float() if self.weight is not None else None,
            bias=self.bias.float() if self.bias is not None else None,
            use_input_stats=self.training or not self.track_running_stats,
            momentum=self.momentum,
            eps=self.eps,
        )
        if self.transpose_last:
            output = output.transpose(1, 2)
        return output.type_as(input)
# --------------------------------------------------------------------------------
# /fairseq/modules/gelu.py
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
See "Gaussian Error Linear Units (GELUs)" by Dan Hendrycks and Kevin Gimpel with
the corresponding GitHub repo: https://github.com/hendrycks/GELUs
"""

import math

import torch
import torch.nn as nn


def gelu_accurate(x):
    """GELU via the tanh formula 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x**3))).

    sqrt(2/pi) is computed on first use and cached as ``gelu_accurate._a``.
    """
    try:
        coeff = gelu_accurate._a
    except AttributeError:
        coeff = gelu_accurate._a = math.sqrt(2 / math.pi)
    inner = coeff * (x + 0.044715 * torch.pow(x, 3))
    return 0.5 * x * (1 + torch.tanh(inner))


def gelu(x: torch.Tensor) -> torch.Tensor:
    """``torch.nn.functional.gelu`` evaluated in fp32, cast back to ``x``'s dtype."""
    as_fp32 = x.float()
    return torch.nn.functional.gelu(as_fp32).type_as(x)


# --------------------------------------------------------------------------------
# /fairseq/modules/grad_multiply.py
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch


class GradMultiply(torch.autograd.Function):
    """Identity in the forward pass; multiplies the incoming gradient by ``scale``."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x.new(x)

    @staticmethod
    def backward(ctx, grad):
        # ``scale`` is a plain value, so it receives no gradient
        return grad * ctx.scale, None


# --------------------------------------------------------------------------------
# /fairseq/modules/layer_drop.py
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
LayerDrop as described in https://arxiv.org/abs/1909.11556.
"""

import torch
import torch.nn as nn


class LayerDropModuleList(nn.ModuleList):
    """
    A LayerDrop implementation based on :class:`torch.nn.ModuleList`.

    We refresh the choice of which layers to drop every time we iterate
    over the LayerDropModuleList instance. During evaluation we always
    iterate over all layers.

    Usage::

        layers = LayerDropModuleList(p=0.5, modules=[layer1, layer2, layer3])
        for layer in layers:  # this might iterate over layers 1 and 3
            x = layer(x)
        for layer in layers:  # this might iterate over all layers
            x = layer(x)
        for layer in layers:  # this might not iterate over any layers
            x = layer(x)

    Args:
        p (float): probability of dropping out each layer
        modules (iterable, optional): an iterable of modules to add
    """

    def __init__(self, p, modules=None):
        super().__init__(modules)
        self.p = p

    def __iter__(self):
        # One uniform draw per layer on every iteration (also in eval mode,
        # where the draws are ignored); a layer is kept if its draw exceeds p.
        dropout_probs = torch.empty(len(self)).uniform_()
        for i, m in enumerate(super().__iter__()):
            if not self.training or (dropout_probs[i] > self.p):
                yield m


# --------------------------------------------------------------------------------
# /fairseq/modules/layer_norm.py
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from apex.normalization import FusedLayerNorm as _FusedLayerNorm

    has_fused_layernorm = True

    class FusedLayerNorm(_FusedLayerNorm):
        """apex fused layer norm; CUDA inputs run with their device made current."""

        @torch.jit.unused
        def forward(self, x):
            if x.is_cuda:
                with torch.cuda.device(x.device):
                    return super().forward(x)
            return super().forward(x)

except ImportError:
    has_fused_layernorm = False


def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
    """Build a layer norm, preferring apex's fused implementation when usable.

    The fused version is used only if CUDA is available, apex imported
    successfully, and we are not exporting, scripting or tracing. Otherwise a
    plain ``torch.nn.LayerNorm`` is returned.
    """
    exporting = export or torch.jit.is_scripting() or torch.jit.is_tracing()
    use_fused = not exporting and torch.cuda.is_available() and has_fused_layernorm
    layer_cls = FusedLayerNorm if use_fused else torch.nn.LayerNorm
    return layer_cls(normalized_shape, eps, elementwise_affine)


class Fp32LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` computed in fp32, with the result cast back to the input dtype."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        weight = None if self.weight is None else self.weight.float()
        bias = None if self.bias is None else self.bias.float()
        output = F.layer_norm(
            input.float(), self.normalized_shape, weight, bias, self.eps
        )
        return output.type_as(input)


# --------------------------------------------------------------------------------
# /fairseq/modules/learned_positional_embedding.py
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
5 | 6 | from typing import Dict, Optional 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from fairseq import utils 12 | from torch import Tensor 13 | 14 | 15 | class LearnedPositionalEmbedding(nn.Embedding): 16 | """ 17 | This module learns positional embeddings up to a fixed maximum size. 18 | Padding ids are ignored by either offsetting based on padding_idx 19 | or by setting padding_idx to None and ensuring that the appropriate 20 | position ids are passed to the forward function. When padding_idx is set, max_positions = num_embeddings - padding_idx - 1, i.e. the first padding_idx + 1 rows are not counted as usable positions. 21 | """ 22 | 23 | def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int): 24 | super().__init__(num_embeddings, embedding_dim, padding_idx) 25 | self.onnx_trace = False 26 | if self.padding_idx is not None: 27 | self.max_positions = self.num_embeddings - self.padding_idx - 1 28 | else: 29 | self.max_positions = self.num_embeddings 30 | 31 | def forward( 32 | self, 33 | input: Tensor, 34 | incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, 35 | positions: Optional[Tensor] = None, 36 | ): 37 | """Input is expected to be of size [bsz x seqlen]. With incremental_state, a single 1 x 1 position equal to padding_idx + seqlen is embedded, so padding_idx must be set on that path; otherwise positions are derived from input via utils.make_positions unless passed explicitly (only allowed when padding_idx is None).""" 38 | assert (positions is None) or ( 39 | self.padding_idx is None 40 | ), "If positions is pre-computed then padding_idx should not be set."  # NOTE(review): enforced with assert, so this check is skipped under python -O
41 | 42 | if positions is None: 43 | if incremental_state is not None: 44 | # positions is the same for every token when decoding a single step 45 | # Without the int() cast, it doesn't work in some cases when exporting to ONNX 46 | positions = torch.zeros( 47 | (1, 1), device=input.device, dtype=input.dtype 48 | ).fill_(int(self.padding_idx + input.size(1)))  # raises TypeError if padding_idx is None 49 | else: 50 | positions = utils.make_positions(  # position ids from input and padding_idx; helper lives in fairseq.utils (not shown) 51 | input, self.padding_idx, onnx_trace=self.onnx_trace 52 | ) 53 | return F.embedding( 54 | positions, 55 | self.weight, 56 | self.padding_idx, 57 | self.max_norm, 58 | self.norm_type, 59 | self.scale_grad_by_freq, 60 | self.sparse, 61 | ) 62 | -------------------------------------------------------------------------------- /fairseq/modules/lightconv_layer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .lightconv_layer import LightconvLayer # noqa 7 | -------------------------------------------------------------------------------- /fairseq/modules/lightconv_layer/lightconv_cuda.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree.
6 | */ 7 | /* Python bindings for the lightconv CUDA implementation: each entry point checks that its tensors are contiguous CUDA tensors, then calls the lightconv_cuda_* function declared below. NOTE(review): the #include targets and the std::vector template arguments are missing in this copy (angle-bracket text appears stripped); restore them from upstream before building. */ 8 | #include 9 | #include 10 | 11 | std::vector 12 | lightconv_cuda_forward(at::Tensor input, at::Tensor filters, int padding_l); 13 | 14 | std::vector lightconv_cuda_backward( 15 | at::Tensor gradOutput, 16 | int padding_l, 17 | at::Tensor input, 18 | at::Tensor filters); 19 | /* Input checks: CHECK_INPUT requires a CUDA tensor that is contiguous. NOTE(review): AT_ASSERTM and Tensor::type() appear deprecated in recent PyTorch releases; TORCH_CHECK and x.is_cuda() are the usual replacements. */ 20 | #define CHECK_CUDA(x) \ 21 | AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 22 | #define CHECK_CONTIGUOUS(x) \ 23 | AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 24 | #define CHECK_INPUT(x) \ 25 | CHECK_CUDA(x); \ 26 | CHECK_CONTIGUOUS(x) 27 | 28 | std::vector 29 | lightconv_forward(at::Tensor input, at::Tensor filters, int padding_l) { 30 | CHECK_INPUT(input); 31 | CHECK_INPUT(filters); 32 | 33 | return lightconv_cuda_forward(input, filters, padding_l); 34 | } 35 | 36 | std::vector lightconv_backward( 37 | at::Tensor gradOutput, 38 | int padding_l, 39 | at::Tensor input, 40 | at::Tensor filters) { 41 | CHECK_INPUT(gradOutput); 42 | CHECK_INPUT(input); 43 | CHECK_INPUT(filters); 44 | 45 | return lightconv_cuda_backward(gradOutput, padding_l, input, filters); 46 | } 47 | /* Extension module exposing the entry points to Python as forward and backward. */ 48 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 49 | m.def("forward", &lightconv_forward, "lightconv forward (CUDA)"); 50 | m.def("backward", &lightconv_backward, "lightconv backward (CUDA)"); 51 | } 52 | -------------------------------------------------------------------------------- /fairseq/modules/lightconv_layer/lightconv_cuda.cuh: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree.
 */

// NOTE(review): this listing lost every angle-bracketed span. The targets of
// the #include lines below and the parameter lists after each `template` are
// empty, so the header does not compile as shown. Restore them from the
// original lightconv_cuda.cuh rather than from this copy.
// NOTE(review): no include guard or `#pragma once` is visible in this header.
#include
#include

#include
#include

#include
#include
#include
#include
#include
#include

#include
#include

// All-lanes 32-bit mask; presumably passed to __shfl_*_sync warp intrinsics
// in the kernel definitions (not visible here) -- confirm.
#define SHFL_MASK 0xffffffff

// Kernel declarations only. The definitions are not in this header --
// presumably in lightconv_cuda_kernel.cu, which setup.py compiles.

// Forward kernel: reads `input` and `filters`, writes `output`.
template
__global__ void lightconv_forward_kernel(
    const scalar_t* input,
    const scalar_t* filters,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    scalar_t* output);

// Gradient with respect to the input; same argument layout as the forward
// kernel.
template
__global__ void lightconv_grad_wrt_input_kernel(
    const scalar_t* input,
    const scalar_t* filters,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    scalar_t* output);

// Filter-gradient kernels come in pairs: each "firstpass" kernel writes a
// float buffer (float* output) that the matching "secondpass" kernel
// presumably consumes (const float* input) to produce scalar_t output.
// The "_short" first pass additionally takes numHeads. NOTE(review): when
// the short vs. regular variants are chosen is not visible here.
template
__global__ void lightconv_grad_wrt_weights_firstpass_short_kernel(
    const scalar_t* input,
    const scalar_t* gradInput,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    int numHeads,
    float* output);

template
__global__ void lightconv_grad_wrt_weights_secondpass_short_kernel(
    const float* input,
    const int minibatch,
    const int numFiltersInBlock,
    scalar_t* output);

template
__global__ void lightconv_grad_wrt_weights_firstpass_kernel(
    const scalar_t* input,
    const scalar_t* gradInput,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    float* output);

template
__global__ void lightconv_grad_wrt_weights_secondpass_kernel(
    const float* input,
    const int minibatch,
    const int numFiltersInBlock,
    scalar_t* output);

// ---- /fairseq/modules/lightconv_layer/setup.py ----
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension


setup(
    name="lightconv_layer",
    ext_modules=[
        CUDAExtension(
            "lightconv_cuda",
            [
                "lightconv_cuda.cpp",
                "lightconv_cuda_kernel.cu",
            ],
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)

# ---- /fairseq/modules/lstm_cell_with_zoneout.py ----
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn


class LSTMCellWithZoneOut(nn.Module):
    """
    Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations
    https://arxiv.org/abs/1606.01305

    Wraps ``nn.LSTMCell``. In training mode each state element keeps its
    previous value with probability ``prob`` (a fresh Bernoulli mask per
    call); in eval mode the expected mixture
    ``prob * h + (1 - prob) * next_h`` is used instead.

    Raises:
        ValueError: if ``prob`` is not within [0.0, 1.0] (NaN included).
    """

    def __init__(
        self, prob: float, input_size: int, hidden_size: int, bias: bool = True
    ):
        super().__init__()
        # Validate before allocating the LSTM weights; the chained comparison
        # also rejects NaN, which ``prob > 1.0 or prob < 0.0`` let through.
        if not 0.0 <= prob <= 1.0:
            raise ValueError(
                "zoneout probability must be in the range from " "0.0 to 1.0."
            )
        self.lstm_cell = nn.LSTMCell(input_size, hidden_size, bias=bias)
        self.prob = prob

    def zoneout(self, h, next_h, prob):
        """Mix previous state ``h`` with new state ``next_h`` element-wise.

        Tuples (such as the ``(h, c)`` pair from ``nn.LSTMCell``) are handled
        pairwise and returned as a tuple.
        """
        if isinstance(h, tuple):
            return tuple(self.zoneout(old, new, prob) for old, new in zip(h, next_h))

        if self.training:
            # mask == 1 keeps the previous value, mask == 0 takes the new one.
            mask = h.new_zeros(*h.size()).bernoulli_(prob)
            return mask * h + (1 - mask) * next_h

        # Eval mode: the expectation of the training-time mixture.
        return prob * h + (1 - prob) * next_h

    def forward(self, x, h):
        """Run one LSTMCell step from state ``h`` and apply zoneout to it."""
        return self.zoneout(h, self.lstm_cell(x, h), self.prob)

# ---- /fairseq/modules/positional_embedding.py ----
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch.nn as nn

from .learned_positional_embedding import LearnedPositionalEmbedding
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding


def PositionalEmbedding(
    num_embeddings: int,
    embedding_dim: int,
    padding_idx: Optional[int],
    learned: bool = False,
):
    """Build a learned or sinusoidal positional embedding module.

    Args:
        num_embeddings: number of positions to support, before the padding
            offset is added.
        embedding_dim: size of each position vector.
        padding_idx: padding index that position ids are offset past; may be
            ``None``.
        learned: return a trainable ``LearnedPositionalEmbedding`` if True,
            otherwise a ``SinusoidalPositionalEmbedding``.

    Returns:
        The constructed embedding module.
    """
    if learned:
        # if padding_idx is specified then offset the embedding ids by
        # this index and adjust num_embeddings appropriately
        # TODO: The right place for this offset would be inside
        # LearnedPositionalEmbedding. Move this there for a cleaner implementation.
        if padding_idx is not None:
            num_embeddings = num_embeddings + padding_idx + 1
        m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)
        nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
        if padding_idx is not None:
            nn.init.constant_(m.weight[padding_idx], 0)
    else:
        # Reserve rows for the padding offset. The learned branch accepts
        # padding_idx=None, so avoid ``None + 1`` here as well.
        # NOTE(review): assumes SinusoidalPositionalEmbedding treats a None
        # padding_idx as an offset of 0 -- confirm against its implementation.
        offset = 0 if padding_idx is None else padding_idx
        m = SinusoidalPositionalEmbedding(
            embedding_dim,
            padding_idx,
            init_size=num_embeddings + offset + 1,
        )
    return m

# ---- /fairseq/modules/quantization/__init__.py ----
# https://raw.githubusercontent.com/ictnlp/FA-DAT/3599a3023c89ef16d2b33bf7c3a2b6f85aedaa4b/fairseq/modules/quantization/__init__.py

# ---- /fairseq/modules/quantization/pq/__init__.py ----
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .utils import SizeTracker, get_param, attrsetter, quantize_model_  # NOQA

# ---- /fairseq/modules/quantization/pq/modules/__init__.py ----
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .qconv import PQConv2d  # NOQA
from .qemb import PQEmbedding  # NOQA
from .qlinear import PQLinear  # NOQA

# ---- /fairseq/modules/quantization/quantization_options.py ----
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


def parse_config_yaml(yaml_data):
    """Return quantization options, overriding the defaults from ``yaml_data``.

    ``n_centroids`` and ``block_sizes`` in ``yaml_data`` map a layer type to a
    ``{"key": ..., "value": ...}`` mapping, converted here to ``(key, value)``
    tuples. ``layers_to_quantize`` is copied through unchanged. Any section
    absent from ``yaml_data`` keeps its default value.
    """
    # Defaults for every section.
    options = {
        "n_centroids": {
            "Linear": ["in_features", {"*": 256}],
            "Embedding": ["embedding_dim", {"*": 256}],
        },
        "block_sizes": {
            "Linear": ["fuzzy_name", {"fc": 8, "attn": 4, "emb": 4}],
            "Embedding": ["fuzzy_name", {"emb": 8}],
        },
        "layers_to_quantize": [
            "decoder\\.layers\\.\\d+\\.fc[12]",
            "decoder\\.embed_tokens\\.embeddings\\.[012]\\.[01]",
            "decoder\\.layers\\.\\d+\\.self_attn\\.(k_proj|v_proj|q_proj|out_proj)",
        ],
    }

    # Per-layer-type sections share the same {"key", "value"} conversion.
    for section in ("n_centroids", "block_sizes"):
        if section in yaml_data:
            options[section] = {
                layer: convert_yaml_to_tuple(layer_data)
                for layer, layer_data in yaml_data[section].items()
            }

    if "layers_to_quantize" in yaml_data:
        options["layers_to_quantize"] = yaml_data["layers_to_quantize"]

    return options


def convert_yaml_to_tuple(yaml_dictionary):
    """Return ``(yaml_dictionary["key"], yaml_dictionary["value"])``."""
    key = yaml_dictionary["key"]
    value = yaml_dictionary["value"]
    return key, value

# ---- /fairseq/modules/quantization/scalar/__init__.py ----
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .utils import quantize_model_  # NOQA

# ---- /fairseq/modules/quantization/scalar/modules/__init__.py ----
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .qact import ActivationQuantizer  # NOQA
from .qconv import IntConv2d  # NOQA
from .qemb import IntEmbedding  # NOQA
from .qlinear import IntLinear  # NOQA

# ---- /fairseq/modules/quantization/scalar/ops.py ----
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch

try:
    import torch.ao.quantization as quantization
except ImportError:
    import torch.quantization as quantization


def emulate_int(w, bits, method, scale=None, zero_point=None):
    """Fake-quantize ``w`` with the ``emulate_int8_<method>`` variant below.

    Args:
        w: tensor to quantize.
        bits: bit width forwarded to :func:`quantize`.
        method: ``"histogram"``, ``"channel"`` or ``"tensor"``.
        scale, zero_point: precomputed qparams; derived from ``w`` when
            ``scale`` is None.

    Returns:
        ``(quantized_w, scale, zero_point)``.

    Raises:
        KeyError: for an unknown ``method``.
    """
    q = globals()[f"emulate_int8_{method}"]
    return q(w, scale=scale, zero_point=zero_point, bits=bits)


def quantize(w, scale, zero_point, bits=8):
    """Round ``w`` onto the affine grid ``[0, 2**bits - 1]`` and map it back."""
    # In the default behavior, max_val = 255.
    max_val = 2**bits - 1
    return (
        torch.clamp(torch.round(w / scale + zero_point), 0, max_val) - zero_point
    ) * scale


def _match_weight(t, w):
    """Return qparam tensor ``t`` on ``w``'s device with ``w``'s dtype.

    Uses ``w.device`` rather than a hard-coded ``.cuda()``, so quantization
    works on CPU-only hosts and on whichever GPU holds ``w``.
    """
    return t.to(device=w.device, dtype=w.dtype)


def emulate_int8_histogram(w, scale=None, zero_point=None, bits=8):
    """Fake-quantize ``w`` with qparams from a ``HistogramObserver``."""
    if scale is None:
        obs = quantization.observer.HistogramObserver()
        obs.to(device=w.device)
        _ = obs(w.float())
        scale, zero_point = obs.calculate_qparams()
        scale = _match_weight(scale, w)
        zero_point = _match_weight(zero_point, w)
    return quantize(w, scale, zero_point, bits=bits), scale, zero_point


def emulate_int8_channel(w, scale=None, zero_point=None, bits=8):
    """Fake-quantize ``w`` with per-channel (last axis) symmetric qparams."""
    if scale is None:
        obs = quantization.observer.PerChannelMinMaxObserver(
            ch_axis=-1, qscheme=torch.per_channel_symmetric
        )
        obs.to(device=w.device)
        _ = obs(w)
        # NOTE(review): verify that the installed torch still provides
        # ``get_qparams`` on this observer; current releases expose
        # ``calculate_qparams()`` returning ``(scale, zero_point)``.
        scale, zero_point, _ = obs.get_qparams()
        scale = _match_weight(scale, w)
        zero_point = _match_weight(zero_point, w)
    return quantize(w, scale, zero_point, bits=bits), scale, zero_point


def emulate_int8_tensor(w, scale=None, zero_point=None, bits=8):
    """Fake-quantize ``w`` with per-tensor qparams from a ``MinMaxObserver``."""
    if scale is None:
        obs = quantization.observer.MinMaxObserver()
        obs.to(device=w.device)
        _ = obs(w)
        scale, zero_point = obs.calculate_qparams()
        scale = _match_weight(scale, w)
        zero_point = _match_weight(zero_point, w)
    return quantize(w, scale, zero_point, bits=bits), scale, zero_point

# ---- /fairseq/modules/rotary_positional_embedding.py ----
import torch


class RotaryPositionalEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000, precision=torch.half):
        """Rotary positional embedding
        Reference : https://blog.eleuther.ai/rotary-embeddings/
        Paper: https://arxiv.org/pdf/2104.09864.pdf
        Args:
            dim: Dimension of embedding
            base: Base value for exponential
            precision: precision to use for numerical values (stored as
                ``self.precision``; not read by ``forward``)
        """
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None
        self.precision = precision

    def forward(self, x, seq_len=None):
        """
        Args:
            x: Input x with T X B X C
            seq_len: Sequence length of input x; defaults to ``x.size(0)``
                (the T dimension).

        Returns:
            ``(cos, sin)`` tables shaped ``(seq_len, 1, 1, D)`` with
            ``D = 2 * len(inv_freq)``, cached until the length or the
            device of ``x`` changes.
        """
        # Without this default, seq_len=None would return (None, None) on the
        # first call and hit torch.arange(None) on later calls.
        if seq_len is None:
            seq_len = x.size(0)
        if (
            seq_len != self.seq_len_cached
            or self.cos_cached is None
            or self.cos_cached.device != x.device
        ):
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()[:, None, None, :]
            self.sin_cached = emb.sin()[:, None, None, :]
        return self.cos_cached, self.sin_cached


# rotary pos emb helpers:
def rotate_half(x):
    """Split the last dim into halves ``(x1, x2)`` and return ``cat(-x2, x1)``."""
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat(
        (-x2, x1), dim=x1.ndim - 1
    )  # dim=-1 triggers a bug in earlier torch versions


def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
    """Rotate ``q`` and ``k`` with ``cos``/``sin`` rows ``offset .. offset + len(q)``.

    The same dim-0 slice is applied to ``k`` (assumes ``k`` matches ``q``
    along dim 0 -- TODO confirm for callers).
    """
    cos, sin = (
        cos[offset : q.shape[0] + offset, ...],
        sin[offset : q.shape[0] + offset, ...],
    )
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

# ---- /fairseq/modules/same_pad.py ----
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from torch import nn


class SamePad(nn.Module):
    """Trim trailing positions from dimension 2 of the input.

    Args:
        kernel_size: kernel size that determines how much to trim.
        causal: when True, drop ``kernel_size - 1`` trailing positions;
            otherwise drop one position for even kernel sizes and none for
            odd ones.
    """

    def __init__(self, kernel_size, causal=False):
        super().__init__()
        if not causal:
            self.remove = int(kernel_size % 2 == 0)
        else:
            self.remove = kernel_size - 1

    def forward(self, x):
        # Guard is required: ``x[:, :, :-0]`` would be empty, not a no-op.
        return x[:, :, : -self.remove] if self.remove > 0 else x

# ---- /fairseq/modules/scalar_bias.py ----
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#

import torch


class ScalarBias(torch.autograd.Function):
    """Prepend one ``bias_init``-filled slice along ``dim``.

    The output is one larger than ``input`` along ``dim``: index 0 holds
    ``bias_init`` and the remaining indices hold ``input``. Gradients flow
    only to ``input``. Used in self-attention so the model can optionally
    attend to this extra slot instead of the past.
    """

    @staticmethod
    def forward(ctx, input, dim, bias_init):
        out_shape = list(input.size())
        out_shape[dim] += 1
        result = input.new_full(out_shape, bias_init)
        result.narrow(dim, 1, out_shape[dim] - 1).copy_(input)
        ctx.dim = dim
        return result

    @staticmethod
    def backward(ctx, grad):
        # Drop the bias slot; dim and bias_init receive no gradient.
        kept = grad.size(ctx.dim) - 1
        return grad.narrow(ctx.dim, 1, kept), None, None


def scalar_bias(input, dim, bias_init=0):
    """Functional entry point for :class:`ScalarBias`."""
    return ScalarBias.apply(input, dim, bias_init)

# ---- /fairseq/modules/sparse_transformer_sentence_encoder_layer.py ----
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.modules import TransformerSentenceEncoderLayer
from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention


class SparseTransformerSentenceEncoderLayer(TransformerSentenceEncoderLayer):
    """
    Implements a Sparse Transformer Encoder Layer (see SparseMultiheadAttention).

    Builds the parent encoder layer, then replaces ``self.self_attn`` with a
    ``SparseMultiheadAttention`` configured by ``is_bidirectional``,
    ``stride`` and ``expressivity``.
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        export: bool = False,
        is_bidirectional: bool = True,
        stride: int = 32,
        expressivity: int = 8,
    ) -> None:

        # Forwarded positionally, in the parent's parameter order.
        parent_args = (
            embedding_dim,
            ffn_embedding_dim,
            num_attention_heads,
            dropout,
            attention_dropout,
            activation_dropout,
            activation_fn,
            export,
        )
        super().__init__(*parent_args)

        # Options specific to the sparse attention pattern.
        sparse_options = dict(
            is_bidirectional=is_bidirectional,
            stride=stride,
            expressivity=expressivity,
        )
        self.self_attn = SparseMultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            add_bias_kv=False,
            add_zero_attn=False,
            self_attention=True,
            **sparse_options,
        )

# ---- /fairseq/modules/transpose_last.py ----
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Transpose the last two dimensions of the input.
"""

import torch.nn as nn


class TransposeLast(nn.Module):
    """Swap the last two dimensions of the input.

    Args:
        deconstruct_idx: if not None, the input is first indexed with it
            (``x[deconstruct_idx]``) and the selected item is transposed.
    """

    def __init__(self, deconstruct_idx=None):
        super().__init__()
        self.deconstruct_idx = deconstruct_idx

    def forward(self, x):
        idx = self.deconstruct_idx
        source = x if idx is None else x[idx]
        return source.transpose(-2, -1)

# ---- /fairseq/modules/unfold.py ----
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn.functional as F


def unfold1d(x, kernel_size, padding_l, pad_value=0):
    """Unfold a T x B x C tensor into T x B x C x K sliding windows.

    The time axis is padded with ``padding_l`` values before and
    ``kernel_size - 1 - padding_l`` after (filled with ``pad_value``);
    window ``k`` of step ``t`` is padded time index ``t + k``. For
    ``kernel_size <= 1`` a trailing singleton axis is added instead.
    """
    if kernel_size <= 1:
        return x.unsqueeze(3)
    T, B, C = x.size()
    padded = F.pad(
        x, (0, 0, 0, 0, padding_l, kernel_size - 1 - padding_l), value=pad_value
    )
    # A view, not a copy: one time step in the padded tensor spans B * C
    # elements, so moving along the window axis reuses that stride.
    step = B * C
    return padded.as_strided((T, B, C, kernel_size), (step, C, 1, step))

# ---- /fairseq/optim/__init__.py ----
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""

import importlib
import os

from fairseq import registry
from fairseq.optim.bmuf import FairseqBMUF  # noqa
from fairseq.optim.fairseq_optimizer import (  # noqa
    FairseqOptimizer,
    LegacyFairseqOptimizer,
)
from fairseq.optim.amp_optimizer import AMPOptimizer
from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer
from fairseq.optim.shard import shard_
from omegaconf import DictConfig

__all__ = [
    "AMPOptimizer",
    "FairseqOptimizer",
    "FP16Optimizer",
    "MemoryEfficientFP16Optimizer",
    "shard_",
]

(
    _build_optimizer,
    register_optimizer,
    OPTIMIZER_REGISTRY,
    OPTIMIZER_DATACLASS_REGISTRY,
) = registry.setup_registry("--optimizer", base_class=FairseqOptimizer, required=True)


def build_optimizer(cfg: DictConfig, params, *extra_args, **extra_kwargs):
    """Build the registered optimizer selected by ``cfg`` over trainable params.

    Args:
        cfg: optimizer config handed to the registry builder.
        params: iterable of parameters, or iterable of dicts whose values are
            parameters. May be a one-shot iterator.
        *extra_args, **extra_kwargs: forwarded to the registry builder.

    Returns:
        Whatever the registry builder returns for ``cfg``.
    """
    # Materialize first: the all(...) probe below would otherwise consume the
    # first element of a generator (or all of it, for a generator of dicts),
    # silently dropping parameters from the optimizer.
    params = list(params)
    if all(isinstance(p, dict) for p in params):
        params = [t for p in params for t in p.values()]
    trainable = [p for p in params if p.requires_grad]
    return _build_optimizer(cfg, trainable, *extra_args, **extra_kwargs)


# automatically import any Python files in the optim/ directory
for file in sorted(os.listdir(os.path.dirname(__file__))):
    if file.endswith(".py") and not file.startswith("_"):
        file_name = file[: file.find(".py")]
        importlib.import_module("fairseq.optim." + file_name)

# ---- /fairseq/optim/adadelta.py ----
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.optim

from . import LegacyFairseqOptimizer, register_optimizer


@register_optimizer("adadelta")
class Adadelta(LegacyFairseqOptimizer):
    def __init__(self, args, params):
        super().__init__(args)
        self._optimizer = torch.optim.Adadelta(params, **self.optimizer_config)

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--adadelta-rho', type=float, default=0.9, metavar='RHO',
                            help='coefficient used for computing a running average of squared gradients')
        parser.add_argument('--adadelta-eps', type=float, default=1e-6, metavar='EPS',
                            help='term added to the denominator to improve numerical stability')
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')
        # NOTE(review): --anneal-eps is not read by this class; presumably
        # consumed elsewhere -- verify.
        parser.add_argument('--anneal-eps', action='store_true', help='flag to anneal eps')
        # fmt: on

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        return {
            "lr": self.args.lr[0],
            "rho": self.args.adadelta_rho,
            "eps": self.args.adadelta_eps,
            "weight_decay": self.args.weight_decay,
        }

    @property
    def supports_flat_params(self):
        return True

# ---- /fairseq/optim/adagrad.py ----
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.optim

from . import LegacyFairseqOptimizer, register_optimizer


@register_optimizer("adagrad")
class Adagrad(LegacyFairseqOptimizer):
    def __init__(self, args, params):
        super().__init__(args)
        self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config)

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')
        # fmt: on

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        return {
            "lr": self.args.lr[0],
            "weight_decay": self.args.weight_decay,
        }

    @property
    def supports_flat_params(self):
        return False

# ---- /fairseq/optim/fused_lamb.py ----
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.optim import LegacyFairseqOptimizer, register_optimizer


@register_optimizer("lamb")
class FairseqLAMB(LegacyFairseqOptimizer):
    """LAMB optimizer, backed by apex's ``FusedLAMB``.

    Raises:
        ImportError: if ``apex.optimizers`` cannot be imported.
    """

    def __init__(self, args, params):
        super().__init__(args)
        # Only the import is guarded: errors raised while constructing
        # FusedLAMB (apex present but misconfigured) surface unchanged
        # instead of being reported as "apex is missing".
        try:
            from apex.optimizers import FusedLAMB
        except ImportError as err:
            raise ImportError("Please install apex to use LAMB optimizer") from err
        self._optimizer = FusedLAMB(params, **self.optimizer_config)

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--lamb-betas', default='(0.9, 0.999)', metavar='B',
                            help='betas for LAMB optimizer')
        parser.add_argument('--lamb-eps', type=float, default=1e-8, metavar='D',
                            help='epsilon for LAMB optimizer')
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')
        # fmt: on

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        return {
            "lr": self.args.lr[0],
            # NOTE(review): eval() of a command-line/config string. If configs
            # can come from untrusted sources, prefer ast.literal_eval.
            "betas": eval(self.args.lamb_betas),
            "eps": self.args.lamb_eps,
            "weight_decay": self.args.weight_decay,
        }

    @property
    def supports_flat_params(self):
        return False

# ---- /fairseq/optim/lr_scheduler/__init__.py ----
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""

import importlib
import os

from fairseq import registry
from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import (  # noqa
    FairseqLRScheduler,
    LegacyFairseqLRScheduler,
)
from omegaconf import DictConfig


# Registry for LR schedulers; "fixed" is used when none is selected.
(
    build_lr_scheduler_,
    register_lr_scheduler,
    LR_SCHEDULER_REGISTRY,
    LR_SCHEDULER_DATACLASS_REGISTRY,
) = registry.setup_registry(
    "--lr-scheduler", base_class=FairseqLRScheduler, default="fixed"
)


def build_lr_scheduler(cfg: DictConfig, optimizer):
    """Build the registered LR scheduler chosen by ``cfg`` for ``optimizer``."""
    scheduler = build_lr_scheduler_(cfg, optimizer)
    return scheduler


# Import every public .py module in this directory so that its
# @register_lr_scheduler decorators run and populate the registry.
for file in sorted(os.listdir(os.path.dirname(__file__))):
    if file.startswith("_") or not file.endswith(".py"):
        continue
    file_name = file[: file.find(".py")]
    importlib.import_module("fairseq.optim.lr_scheduler." + file_name)

# ---- /fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py ----
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from argparse import Namespace

from fairseq.dataclass.utils import gen_parser_from_dataclass
from fairseq.optim import FairseqOptimizer


class FairseqLRScheduler:
    """Base LR scheduler that tracks the best (lowest) validation loss.

    Args:
        cfg: scheduler config, stored as ``self.cfg``.
        optimizer: a ``FairseqOptimizer`` or ``None``.

    Raises:
        ValueError: if ``optimizer`` is neither None nor a FairseqOptimizer.
    """

    def __init__(self, cfg, optimizer):
        super().__init__()
        if not (optimizer is None or isinstance(optimizer, FairseqOptimizer)):
            raise ValueError("optimizer must be an instance of FairseqOptimizer")
        self.cfg = cfg
        self.optimizer = optimizer
        self.best = None

    @classmethod
    def add_args(cls, parser):
        """Add arguments to the parser for this LR scheduler."""
        config_cls = getattr(cls, "__dataclass", None)
        if config_cls is None:
            return
        gen_parser_from_dataclass(parser, config_cls())

    def state_dict(self):
        """Return the LR scheduler state dict."""
        return {"best": self.best}

    def load_state_dict(self, state_dict):
        """Load an LR scheduler state dict."""
        self.best = state_dict["best"]

    def step_begin_epoch(self, epoch):
        """Update the learning rate at the beginning of the given epoch."""
        pass

    def step(self, epoch, val_loss=None):
        """Record ``val_loss`` as the new best if it is the lowest seen."""
        if val_loss is None:
            return
        self.best = val_loss if self.best is None else min(self.best, val_loss)

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        return self.optimizer.get_lr()


class LegacyFairseqLRScheduler(FairseqLRScheduler):
    """Scheduler base for argparse ``Namespace`` configs.

    Does not call the parent ``__init__``, so it sets ``self.args`` rather
    than ``self.cfg``, and ``optimizer`` may not be None.
    """

    def __init__(self, args: Namespace, optimizer):
        if not isinstance(optimizer, FairseqOptimizer):
            raise ValueError("optimizer must be an instance of FairseqOptimizer")
        self.args = args
        self.optimizer = optimizer
        self.best = None

# ---- /fairseq/optim/lr_scheduler/pass_through.py
# ----
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

from fairseq.dataclass import FairseqDataclass
from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


@dataclass
class PassThroughScheduleConfig(FairseqDataclass):
    # No options: scheduling is delegated entirely to the optimizer.
    pass


@register_lr_scheduler("pass_through", dataclass=PassThroughScheduleConfig)
class PassThroughScheduleSchedule(FairseqLRScheduler):
    """Delegate lr scheduling to the optimizer's own ``lr_scheduler``."""

    def __init__(self, cfg: PassThroughScheduleConfig, optimizer):
        super().__init__(cfg, optimizer)
        assert (
            getattr(optimizer, "lr_scheduler", None) is not None
        ), "Pass-through schedule can only be used with optimizers with their own schedulers"

    def state_dict(self):
        inner = self.optimizer.lr_scheduler
        return inner.state_dict()

    def load_state_dict(self, state_dict):
        inner = self.optimizer.lr_scheduler
        inner.load_state_dict(state_dict)

    def step_begin_epoch(self, epoch):
        """Update the learning rate at the beginning of the given epoch."""
        inner = self.optimizer.lr_scheduler
        return inner.step_begin_epoch(epoch)

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        inner = self.optimizer.lr_scheduler
        return inner.step_update(num_updates)

# ---- /fairseq/optim/sgd.py ----
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.optim

from . import LegacyFairseqOptimizer, register_optimizer


@register_optimizer("sgd")
class SGD(LegacyFairseqOptimizer):
    """Fairseq wrapper around ``torch.optim.SGD``."""

    def __init__(self, args, params):
        super().__init__(args)
        self._optimizer = torch.optim.SGD(params, **self.optimizer_config)

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--momentum', default=0.0, type=float, metavar='M',
                            help='momentum factor')
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')
        # fmt: on

    @property
    def optimizer_config(self):
        """Keyword arguments for ``torch.optim.SGD``, built from ``self.args``.

        They override optimizer args stored in checkpoints, so training can
        resume with different settings (e.g. a new learning rate).
        """
        args = self.args
        return dict(
            lr=args.lr[0],
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )

    @property
    def supports_flat_params(self):
        return True

# ---- /fairseq/optim/shard.py ----
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
5 | 6 | from typing import Any, Dict 7 | 8 | from fairseq.distributed import utils 9 | 10 | 11 | try: 12 | from fairscale.optim import OSS 13 | 14 | _has_fairscale = True 15 | except ImportError: 16 | _has_fairscale = False 17 | 18 | 19 | def shard_(optimizer, group): 20 | if not _has_fairscale: 21 | raise ImportError( 22 | "\n\nPlease install the fairscale package:" "\n\n pip install fairscale" 23 | ) 24 | 25 | class FairseqOSS(OSS): 26 | @property 27 | def disable_mem_eff_fp16_loading_hack(self): 28 | return True 29 | 30 | def __getattr__(self, name): 31 | if name.startswith("supports") and hasattr(self.optim, name): 32 | return getattr(self.optim, name) 33 | raise AttributeError( 34 | "'FairseqOSS' object has no attribute {0!r}".format(name) 35 | ) 36 | 37 | def broadcast_global_state_dict( 38 | self, state_dict: Dict[str, Any] 39 | ) -> Dict[str, Any]: 40 | """ 41 | Broadcasts the entire state_dict to all other ranks 42 | each rank is responsible to load their own partition of data 43 | """ 44 | return utils.broadcast_object( 45 | state_dict, 46 | src_rank=0, 47 | group=self.group, 48 | ) 49 | 50 | torch_optimizer = optimizer.optimizer 51 | optim_cls = type(torch_optimizer) 52 | 53 | optimizer.optimizer = FairseqOSS( 54 | torch_optimizer.param_groups, 55 | optim_cls, 56 | group=group, 57 | **optimizer.optimizer_config 58 | ) 59 | -------------------------------------------------------------------------------- /fairseq/pdb.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import multiprocessing 7 | import os 8 | import pdb 9 | import sys 10 | 11 | 12 | __all__ = ["set_trace"] 13 | 14 | 15 | _stdin = [None] 16 | _stdin_lock = multiprocessing.Lock() 17 | try: 18 | _stdin_fd = sys.stdin.fileno() 19 | except Exception: 20 | _stdin_fd = None 21 | 22 | 23 | class MultiprocessingPdb(pdb.Pdb): 24 | """A Pdb wrapper that works in a multiprocessing environment. 25 | 26 | Usage: `from fairseq import pdb; pdb.set_trace()` 27 | """ 28 | 29 | def __init__(self): 30 | pdb.Pdb.__init__(self, nosigint=True) 31 | 32 | def _cmdloop(self): 33 | stdin_bak = sys.stdin 34 | with _stdin_lock: 35 | try: 36 | if _stdin_fd is not None: 37 | if not _stdin[0]: 38 | _stdin[0] = os.fdopen(_stdin_fd) 39 | sys.stdin = _stdin[0] 40 | self.cmdloop() 41 | finally: 42 | sys.stdin = stdin_bak 43 | 44 | 45 | def set_trace(): 46 | pdb = MultiprocessingPdb() 47 | pdb.set_trace(sys._getframe().f_back) 48 | -------------------------------------------------------------------------------- /fairseq/scoring/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | import importlib 8 | import os 9 | from abc import ABC, abstractmethod 10 | 11 | from fairseq import registry 12 | from omegaconf import DictConfig 13 | 14 | 15 | class BaseScorer(ABC): 16 | def __init__(self, cfg): 17 | self.cfg = cfg 18 | self.ref = [] 19 | self.pred = [] 20 | 21 | def add_string(self, ref, pred): 22 | self.ref.append(ref) 23 | self.pred.append(pred) 24 | 25 | @abstractmethod 26 | def score(self) -> float: 27 | pass 28 | 29 | @abstractmethod 30 | def result_string(self) -> str: 31 | pass 32 | 33 | 34 | _build_scorer, register_scorer, SCORER_REGISTRY, _ = registry.setup_registry( 35 | "--scoring", default="bleu" 36 | ) 37 | 38 | 39 | def build_scorer(choice, tgt_dict): 40 | _choice = choice._name if isinstance(choice, DictConfig) else choice 41 | 42 | if _choice == "bleu": 43 | from fairseq.scoring import bleu 44 | 45 | return bleu.Scorer( 46 | bleu.BleuConfig(pad=tgt_dict.pad(), eos=tgt_dict.eos(), unk=tgt_dict.unk()) 47 | ) 48 | return _build_scorer(choice) 49 | 50 | 51 | # automatically import any Python files in the current directory 52 | for file in sorted(os.listdir(os.path.dirname(__file__))): 53 | if file.endswith(".py") and not file.startswith("_"): 54 | module = file[: file.find(".py")] 55 | importlib.import_module("fairseq.scoring." + module) 56 | -------------------------------------------------------------------------------- /fairseq/scoring/bertscore.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from dataclasses import dataclass, field 7 | 8 | import numpy as np 9 | 10 | from fairseq.dataclass import FairseqDataclass 11 | from fairseq.scoring import BaseScorer, register_scorer 12 | 13 | 14 | @dataclass 15 | class BertScoreScorerConfig(FairseqDataclass): 16 | bert_score_lang: str = field(default="en", metadata={"help": "BERTScore language"}) 17 | 18 | 19 | @register_scorer("bert_score", dataclass=BertScoreScorerConfig) 20 | class BertScoreScorer(BaseScorer): 21 | def __init__(self, cfg): 22 | super(BertScoreScorer, self).__init__(cfg) 23 | try: 24 | import bert_score as _bert_score 25 | except ImportError: 26 | raise ImportError("Please install BERTScore: pip install bert-score") 27 | 28 | self.cfg = cfg 29 | self._bert_score = _bert_score 30 | self.scores = None 31 | 32 | def add_string(self, ref, pred): 33 | self.ref.append(ref) 34 | self.pred.append(pred) 35 | 36 | def score(self, order=4): 37 | _, _, self.scores = self._bert_score.score( 38 | self.pred, self.ref, lang=self.cfg.bert_score_lang 39 | ) 40 | self.scores = self.scores.numpy() 41 | return np.mean(self.scores) 42 | 43 | def result_string(self, order=4): 44 | return f"BERTScore: {self.score():.4f}" 45 | -------------------------------------------------------------------------------- /fairseq/scoring/chrf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | from dataclasses import dataclass 8 | 9 | from fairseq.dataclass import FairseqDataclass 10 | from fairseq.scoring import BaseScorer, register_scorer 11 | 12 | 13 | @dataclass 14 | class ChrFScorerConfig(FairseqDataclass): 15 | pass 16 | 17 | 18 | @register_scorer("chrf", dataclass=ChrFScorerConfig) 19 | class ChrFScorer(BaseScorer): 20 | def __init__(self, args): 21 | super(ChrFScorer, self).__init__(args) 22 | import sacrebleu 23 | 24 | self.sacrebleu = sacrebleu 25 | 26 | def add_string(self, ref, pred): 27 | self.ref.append(ref) 28 | self.pred.append(pred) 29 | 30 | def score(self, order=4): 31 | return self.result_string(order).score 32 | 33 | def result_string(self, order=4): 34 | if order != 4: 35 | raise NotImplementedError 36 | return self.sacrebleu.corpus_chrf(self.pred, [self.ref]).format() 37 | -------------------------------------------------------------------------------- /fairseq/scoring/meteor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import numpy as np 7 | from dataclasses import dataclass 8 | 9 | from fairseq.dataclass import FairseqDataclass 10 | from fairseq.scoring import BaseScorer, register_scorer 11 | 12 | 13 | @dataclass 14 | class MeteorScorerConfig(FairseqDataclass): 15 | pass 16 | 17 | 18 | @register_scorer("meteor", dataclass=MeteorScorerConfig) 19 | class MeteorScorer(BaseScorer): 20 | def __init__(self, args): 21 | super(MeteorScorer, self).__init__(args) 22 | try: 23 | import nltk 24 | except ImportError: 25 | raise ImportError("Please install nltk to use METEOR scorer") 26 | 27 | self.nltk = nltk 28 | self.scores = [] 29 | 30 | def add_string(self, ref, pred): 31 | self.ref.append(ref) 32 | self.pred.append(pred) 33 | 34 | def score(self, order=4): 35 | self.scores = [ 36 | self.nltk.translate.meteor_score.single_meteor_score(r, p) 37 | for r, p in zip(self.ref, self.pred) 38 | ] 39 | return np.mean(self.scores) 40 | 41 | def result_string(self, order=4): 42 | return f"METEOR: {self.score():.4f}" 43 | -------------------------------------------------------------------------------- /fairseq/scoring/wer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from dataclasses import dataclass, field 7 | 8 | from fairseq.dataclass import FairseqDataclass 9 | from fairseq.scoring import BaseScorer, register_scorer 10 | from fairseq.scoring.tokenizer import EvaluationTokenizer 11 | 12 | 13 | @dataclass 14 | class WerScorerConfig(FairseqDataclass): 15 | wer_tokenizer: EvaluationTokenizer.ALL_TOKENIZER_TYPES = field( 16 | default="none", metadata={"help": "sacreBLEU tokenizer to use for evaluation"} 17 | ) 18 | wer_remove_punct: bool = field( 19 | default=False, metadata={"help": "remove punctuation"} 20 | ) 21 | wer_char_level: bool = field( 22 | default=False, metadata={"help": "evaluate at character level"} 23 | ) 24 | wer_lowercase: bool = field(default=False, metadata={"help": "lowercasing"}) 25 | 26 | 27 | @register_scorer("wer", dataclass=WerScorerConfig) 28 | class WerScorer(BaseScorer): 29 | def __init__(self, cfg): 30 | super().__init__(cfg) 31 | self.reset() 32 | try: 33 | import editdistance as ed 34 | except ImportError: 35 | raise ImportError("Please install editdistance to use WER scorer") 36 | self.ed = ed 37 | self.tokenizer = EvaluationTokenizer( 38 | tokenizer_type=self.cfg.wer_tokenizer, 39 | lowercase=self.cfg.wer_lowercase, 40 | punctuation_removal=self.cfg.wer_remove_punct, 41 | character_tokenization=self.cfg.wer_char_level, 42 | ) 43 | 44 | def reset(self): 45 | self.distance = 0 46 | self.ref_length = 0 47 | 48 | def add_string(self, ref, pred): 49 | ref_items = self.tokenizer.tokenize(ref).split() 50 | pred_items = self.tokenizer.tokenize(pred).split() 51 | self.distance += self.ed.eval(ref_items, pred_items) 52 | self.ref_length += len(ref_items) 53 | 54 | def result_string(self): 55 | return f"WER: {self.score():.2f}" 56 | 57 | def score(self): 58 | return 100.0 * self.distance / self.ref_length if self.ref_length > 0 else 0 59 | -------------------------------------------------------------------------------- /fairseq/tasks/frm_text_to_speech.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import logging 7 | 8 | from fairseq.data.audio.frm_text_to_speech_dataset import FrmTextToSpeechDatasetCreator 9 | from fairseq.tasks import register_task 10 | from fairseq.tasks.text_to_speech import TextToSpeechTask 11 | 12 | 13 | logging.basicConfig( 14 | format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 15 | datefmt="%Y-%m-%d %H:%M:%S", 16 | level=logging.INFO, 17 | ) 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | @register_task("frm_text_to_speech") 22 | class FrmTextToSpeechTask(TextToSpeechTask): 23 | @staticmethod 24 | def add_args(parser): 25 | TextToSpeechTask.add_args(parser) 26 | parser.add_argument("--do_chunk", action="store_true", help="train on chunks") 27 | parser.add_argument("--chunk_bound", default=-1, type=int) 28 | parser.add_argument("--chunk_init", default=50, type=int) 29 | parser.add_argument("--chunk_incr", default=5, type=int) 30 | parser.add_argument("--add_eos", action="store_true") 31 | parser.add_argument("--dedup", action="store_true") 32 | parser.add_argument("--ref_fpu", default=-1, type=float) 33 | 34 | def load_dataset(self, split, **unused_kwargs): 35 | is_train_split = split.startswith("train") 36 | pre_tokenizer = self.build_tokenizer(self.args) 37 | bpe_tokenizer = self.build_bpe(self.args) 38 | self.datasets[split] = FrmTextToSpeechDatasetCreator.from_tsv( 39 | self.args.data, 40 | self.data_cfg, 41 | split, 42 | self.src_dict, 43 | pre_tokenizer, 44 | bpe_tokenizer, 45 | is_train_split=is_train_split, 46 | n_frames_per_step=self.args.n_frames_per_step, 47 | speaker_to_id=self.speaker_to_id, 48 | do_chunk=self.args.do_chunk, 49 | chunk_bound=self.args.chunk_bound, 50 | chunk_init=self.args.chunk_init, 51 | 
chunk_incr=self.args.chunk_incr, 52 | add_eos=self.args.add_eos, 53 | dedup=self.args.dedup, 54 | ref_fpu=self.args.ref_fpu, 55 | ) 56 | -------------------------------------------------------------------------------- /fairseq/tasks/simultaneous_translation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import logging 7 | from fairseq.tasks import register_task 8 | from fairseq.tasks.speech_to_text import SpeechToTextTask 9 | from fairseq.tasks.translation import TranslationTask, TranslationConfig 10 | 11 | try: 12 | import examples.simultaneous_translation # noqa 13 | 14 | import_successful = True 15 | except BaseException: 16 | import_successful = False 17 | 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | def check_import(flag): 23 | if not flag: 24 | raise ImportError( 25 | "'examples.simultaneous_translation' is not correctly imported. " 26 | "Please considering `pip install -e $FAIRSEQ_DIR`." 27 | ) 28 | 29 | 30 | @register_task("simul_speech_to_text") 31 | class SimulSpeechToTextTask(SpeechToTextTask): 32 | def __init__(self, args, tgt_dict): 33 | check_import(import_successful) 34 | super().__init__(args, tgt_dict) 35 | 36 | 37 | @register_task("simul_text_to_text", dataclass=TranslationConfig) 38 | class SimulTextToTextTask(TranslationTask): 39 | def __init__(self, cfg, src_dict, tgt_dict): 40 | check_import(import_successful) 41 | super().__init__(cfg, src_dict, tgt_dict) 42 | -------------------------------------------------------------------------------- /fairseq/tasks/translation_from_pretrained_xlm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from dataclasses import dataclass 7 | from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary 8 | from fairseq.tasks.translation import TranslationConfig, TranslationTask 9 | 10 | from . import register_task 11 | 12 | 13 | @dataclass 14 | class TranslationFromPretrainedXLMConfig(TranslationConfig): 15 | pass 16 | 17 | 18 | @register_task( 19 | "translation_from_pretrained_xlm", dataclass=TranslationFromPretrainedXLMConfig 20 | ) 21 | class TranslationFromPretrainedXLMTask(TranslationTask): 22 | """ 23 | Same as TranslationTask except use the MaskedLMDictionary class so that 24 | we can load data that was binarized with the MaskedLMDictionary class. 25 | 26 | This task should be used for the entire training pipeline when we want to 27 | train an NMT model from a pretrained XLM checkpoint: binarizing NMT data, 28 | training NMT with the pretrained XLM checkpoint, and subsequent evaluation 29 | of that trained model. 30 | """ 31 | 32 | @classmethod 33 | def load_dictionary(cls, filename): 34 | """Load the masked LM dictionary from the filename 35 | 36 | Args: 37 | filename (str): the filename 38 | """ 39 | return MaskedLMDictionary.load(filename) 40 | -------------------------------------------------------------------------------- /fairseq/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import re 7 | 8 | 9 | SPACE_NORMALIZER = re.compile(r"\s+") 10 | 11 | 12 | def tokenize_line(line): 13 | line = SPACE_NORMALIZER.sub(" ", line) 14 | line = line.strip() 15 | return line.split() 16 | -------------------------------------------------------------------------------- /fairseq/version.txt: -------------------------------------------------------------------------------- 1 | 1.0.0a0 2 | -------------------------------------------------------------------------------- /fairseq_cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/FA-DAT/3599a3023c89ef16d2b33bf7c3a2b6f85aedaa4b/fairseq_cli/__init__.py -------------------------------------------------------------------------------- /fs_plugins/__init__.py: -------------------------------------------------------------------------------- 1 | from .criterions import * 2 | from .models import * 3 | from .tasks import * 4 | from .optimizer import * 5 | 6 | print("fairseq plugins loaded...") -------------------------------------------------------------------------------- /fs_plugins/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | 4 | # automatically import any Python files in the criterions/ directory 5 | for file in os.listdir(os.path.dirname(__file__)): 6 | if file.endswith(".py") and not file.startswith("_"): 7 | file_name = file[: file.find(".py")] 8 | importlib.import_module("fs_plugins.criterions." 
+ file_name) 9 | -------------------------------------------------------------------------------- /fs_plugins/criterions/utilities.py: -------------------------------------------------------------------------------- 1 | ########################################################################## 2 | # Copyright (C) 2022 COAI @ Tsinghua University 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | ########################################################################## 16 | 17 | def parse_anneal_argument(anneal_str): 18 | def parse_value_pos(value_str): 19 | if "@" in value_str: 20 | value, pos = value_str.split("@") 21 | else: 22 | value = value_str 23 | pos = "0" 24 | return float(value), float(pos.replace("k", "000")) 25 | 26 | res = [] 27 | for value_str in anneal_str.split(":"): 28 | res.append(parse_value_pos(value_str)) 29 | return res 30 | 31 | def get_anneal_value(anneal_params, update_num): 32 | last_value, last_pos = anneal_params[0][0], 0 33 | for value, pos in anneal_params: 34 | if update_num < pos: 35 | return last_value + (value - last_value) * (update_num - last_pos) / (pos - last_pos + 1) 36 | last_value, last_pos = value, pos 37 | return anneal_params[-1][0] 38 | -------------------------------------------------------------------------------- /fs_plugins/custom_ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .dag_loss import dag_loss, dag_best_alignment, 
dag_logsoftmax_gather_inplace, torch_dag_loss, torch_dag_best_alignment, torch_dag_logsoftmax_gather_inplace -------------------------------------------------------------------------------- /fs_plugins/custom_ops/dag_loss.cpp: -------------------------------------------------------------------------------- 1 | // ########################################################################## 2 | // Copyright (C) 2022 COAI @ Tsinghua University 3 | 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // ########################################################################### 16 | 17 | #include 18 | 19 | std::tuple dag_loss(const torch::Tensor &match_all, const torch::Tensor &links, const torch::Tensor &output_length, const torch::Tensor &target_length, bool require_gradient, int config); 20 | std::tuple dag_loss_backward(const torch::Tensor &grad_output, const torch::Tensor &alpha, const torch::Tensor &beta, const torch::Tensor &match_all, const torch::Tensor &links, const torch::Tensor &output_length, const torch::Tensor &target_length, int config1, int config2); 21 | std::tuple dag_best_alignment(const torch::Tensor &match_all, const torch::Tensor &links, const torch::Tensor &output_length, const torch::Tensor &target_length, int config); 22 | torch::Tensor logsoftmax_gather(torch::Tensor word_ins_out, const torch::Tensor &select_idx, bool require_gradient); 23 | 24 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 25 | m.def("dag_loss", &dag_loss, "DAG Loss"); 26 | m.def("dag_loss_backward", &dag_loss_backward, "DAG Loss Backward"); 27 | m.def("dag_best_alignment", &dag_best_alignment, "DAG Best Alignment"); 28 | m.def("logsoftmax_gather", &logsoftmax_gather, "logsoftmax + gather"); 29 | } 30 | -------------------------------------------------------------------------------- /fs_plugins/custom_ops/utilities.h: -------------------------------------------------------------------------------- 1 | #define GCC_VERSION (__GNUC__ * 10000 \ 2 | + __GNUC_MINOR__ * 100 \ 3 | + __GNUC_PATCHLEVEL__) 4 | 5 | #if GCC_VERSION >= 70000 6 | #define if_constexpr(expression) if constexpr (expression) 7 | #else 8 | #define if_constexpr(expression) if(expression) 9 | #endif 10 | -------------------------------------------------------------------------------- /fs_plugins/models/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | 4 | # automatically import any Python files in the criterions/ 
directory 5 | for file in os.listdir(os.path.dirname(__file__)): 6 | if file.endswith(".py") and not file.startswith("_"): 7 | file_name = file[: file.find(".py")] 8 | importlib.import_module("fs_plugins.models." + file_name) 9 | -------------------------------------------------------------------------------- /fs_plugins/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | 4 | # automatically import any Python files in the criterions/ directory 5 | for file in os.listdir(os.path.dirname(__file__)): 6 | if file.endswith(".py") and not file.startswith("_"): 7 | file_name = file[: file.find(".py")] 8 | importlib.import_module("fs_plugins.optimizer." + file_name) 9 | -------------------------------------------------------------------------------- /fs_plugins/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | 4 | # automatically import any Python files in the criterions/ directory 5 | for file in os.listdir(os.path.dirname(__file__)): 6 | if file.endswith(".py") and not file.startswith("_"): 7 | file_name = file[: file.find(".py")] 8 | importlib.import_module("fs_plugins.tasks." + file_name) 9 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | """isort:skip_file""" 6 | 7 | import functools 8 | import importlib 9 | 10 | 11 | dependencies = [ 12 | "dataclasses", 13 | "hydra", 14 | "numpy", 15 | "omegaconf", 16 | "regex", 17 | "requests", 18 | "torch", 19 | ] 20 | 21 | 22 | # Check for required dependencies and raise a RuntimeError if any are missing. 
23 | missing_deps = [] 24 | for dep in dependencies: 25 | try: 26 | importlib.import_module(dep) 27 | except ImportError: 28 | # Hack: the hydra package is provided under the "hydra-core" name in 29 | # pypi. We don't want the user mistakenly calling `pip install hydra` 30 | # since that will install an unrelated package. 31 | if dep == "hydra": 32 | dep = "hydra-core" 33 | missing_deps.append(dep) 34 | if len(missing_deps) > 0: 35 | raise RuntimeError("Missing dependencies: {}".format(", ".join(missing_deps))) 36 | 37 | 38 | # only do fairseq imports after checking for dependencies 39 | from fairseq.hub_utils import ( # noqa; noqa 40 | BPEHubInterface as bpe, 41 | TokenizerHubInterface as tokenizer, 42 | ) 43 | from fairseq.models import MODEL_REGISTRY # noqa 44 | 45 | 46 | # torch.hub doesn't build Cython components, so if they are not found then try 47 | # to build them here 48 | try: 49 | import fairseq.data.token_block_utils_fast # noqa 50 | except ImportError: 51 | try: 52 | import cython # noqa 53 | import os 54 | from setuptools import sandbox 55 | 56 | sandbox.run_setup( 57 | os.path.join(os.path.dirname(__file__), "setup.py"), 58 | ["build_ext", "--inplace"], 59 | ) 60 | except ImportError: 61 | print( 62 | "Unable to build Cython components. Please make sure Cython is " 63 | "installed if the torch.hub model you are loading depends on it." 
64 | ) 65 | 66 | 67 | # automatically expose models defined in FairseqModel::hub_models 68 | for _model_type, _cls in MODEL_REGISTRY.items(): 69 | for model_name in _cls.hub_models().keys(): 70 | globals()[model_name] = functools.partial( 71 | _cls.from_pretrained, 72 | model_name, 73 | ) 74 | -------------------------------------------------------------------------------- /model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/FA-DAT/3599a3023c89ef16d2b33bf7c3a2b6f85aedaa4b/model.png -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/FA-DAT/3599a3023c89ef16d2b33bf7c3a2b6f85aedaa4b/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/compare_namespaces.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Helper script to compare two argparse.Namespace objects.""" 3 | 4 | from argparse import Namespace # noqa 5 | 6 | 7 | def main(): 8 | 9 | ns1 = eval(input("Namespace 1: ")) 10 | ns2 = eval(input("Namespace 2: ")) 11 | 12 | def keys(ns): 13 | ks = set() 14 | for k in dir(ns): 15 | if not k.startswith("_"): 16 | ks.add(k) 17 | return ks 18 | 19 | k1 = keys(ns1) 20 | k2 = keys(ns2) 21 | 22 | def print_keys(ks, ns1, ns2=None): 23 | for k in ks: 24 | if ns2 is None: 25 | print("{}\t{}".format(k, getattr(ns1, k, None))) 26 | else: 27 | print( 28 | "{}\t{}\t{}".format(k, getattr(ns1, k, None), getattr(ns2, k, None)) 29 | ) 30 | 31 | print("Keys unique to namespace 1:") 32 | print_keys(k1 - k2, ns1) 33 | print() 34 | 35 | print("Keys unique to namespace 2:") 36 | print_keys(k2 - k1, ns2) 37 | print() 38 | 39 | print("Overlapping keys with different values:") 40 | ks = [k for k in k1 & k2 if getattr(ns1, k, 
"None") != getattr(ns2, k, "None")] 41 | print_keys(ks, ns1, ns2) 42 | print() 43 | 44 | 45 | if __name__ == "__main__": 46 | main() 47 | -------------------------------------------------------------------------------- /scripts/compound_split_bleu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ $# -ne 1 ]; then 4 | echo "usage: $0 GENERATE_PY_OUTPUT" 5 | exit 1 6 | fi 7 | 8 | GEN=$1 9 | 10 | SYS=$GEN.sys 11 | REF=$GEN.ref 12 | 13 | if [ $(tail -n 1 $GEN | grep BLEU | wc -l) -ne 1 ]; then 14 | echo "not done generating" 15 | exit 16 | fi 17 | 18 | grep ^H $GEN | awk -F '\t' '{print $NF}' | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $SYS 19 | grep ^T $GEN | cut -f2- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $REF 20 | fairseq-score --sys $SYS --ref $REF 21 | -------------------------------------------------------------------------------- /scripts/constraints/validate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import sys 9 | 10 | 11 | """Reads in a fairseq output file, and verifies that the constraints 12 | (C- lines) are present in the output (the first H- line). Assumes that 13 | constraints are listed prior to the first hypothesis. 
14 | """ 15 | 16 | constraints = [] 17 | found = 0 18 | total = 0 19 | for line in sys.stdin: 20 | if line.startswith("C-"): 21 | constraints.append(line.rstrip().split("\t")[1]) 22 | elif line.startswith("H-"): 23 | text = line.split("\t")[2] 24 | 25 | for constraint in constraints: 26 | total += 1 27 | if constraint in text: 28 | found += 1 29 | else: 30 | print(f"No {constraint} in {text}", file=sys.stderr) 31 | 32 | constraints = [] 33 | 34 | print(f"Found {found} / {total} = {100 * found / total:.1f}%") 35 | -------------------------------------------------------------------------------- /scripts/convert_dictionary.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) Facebook, Inc. and its affiliates. 2 | -- 3 | -- This source code is licensed under the MIT license found in the 4 | -- LICENSE file in the root directory of this source tree. 5 | -- 6 | -- Usage: convert_dictionary.lua 7 | require 'fairseq' 8 | require 'torch' 9 | require 'paths' 10 | 11 | if #arg < 1 then 12 | print('usage: convert_dictionary.lua ') 13 | os.exit(1) 14 | end 15 | if not paths.filep(arg[1]) then 16 | print('error: file does not exit: ' .. arg[1]) 17 | os.exit(1) 18 | end 19 | 20 | dict = torch.load(arg[1]) 21 | dst = paths.basename(arg[1]):gsub('.th7', '.txt') 22 | assert(dst:match('.txt$')) 23 | 24 | f = io.open(dst, 'w') 25 | for idx, symbol in ipairs(dict.index_to_symbol) do 26 | if idx > dict.cutoff then 27 | break 28 | end 29 | f:write(symbol) 30 | f:write(' ') 31 | f:write(dict.index_to_freq[idx]) 32 | f:write('\n') 33 | end 34 | f:close() 35 | -------------------------------------------------------------------------------- /scripts/count_docs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Count the number of documents and average number of lines and tokens per 8 | document in a large file. Documents should be separated by a single empty line. 9 | """ 10 | 11 | import argparse 12 | import gzip 13 | import sys 14 | 15 | import numpy as np 16 | 17 | 18 | def main(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("input") 21 | parser.add_argument("--gzip", action="store_true") 22 | args = parser.parse_args() 23 | 24 | def gopen(): 25 | if args.gzip: 26 | return gzip.open(args.input, "r") 27 | else: 28 | return open(args.input, "r", encoding="utf-8") 29 | 30 | num_lines = [] 31 | num_toks = [] 32 | with gopen() as h: 33 | num_docs = 1 34 | num_lines_in_doc = 0 35 | num_toks_in_doc = 0 36 | for i, line in enumerate(h): 37 | if len(line.strip()) == 0: # empty line indicates new document 38 | num_docs += 1 39 | num_lines.append(num_lines_in_doc) 40 | num_toks.append(num_toks_in_doc) 41 | num_lines_in_doc = 0 42 | num_toks_in_doc = 0 43 | else: 44 | num_lines_in_doc += 1 45 | num_toks_in_doc += len(line.rstrip().split()) 46 | if i % 1000000 == 0: 47 | print(i, file=sys.stderr, end="", flush=True) 48 | elif i % 100000 == 0: 49 | print(".", file=sys.stderr, end="", flush=True) 50 | print(file=sys.stderr, flush=True) 51 | 52 | print("found {} docs".format(num_docs)) 53 | print("average num lines per doc: {}".format(np.mean(num_lines))) 54 | print("average num toks per doc: {}".format(np.mean(num_toks))) 55 | 56 | 57 | if __name__ == "__main__": 58 | main() 59 | -------------------------------------------------------------------------------- /scripts/read_binarized.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | 9 | from fairseq.data import Dictionary, data_utils, indexed_dataset 10 | 11 | 12 | def get_parser(): 13 | parser = argparse.ArgumentParser( 14 | description="writes text from binarized file to stdout" 15 | ) 16 | # fmt: off 17 | parser.add_argument('--dataset-impl', help='dataset implementation', 18 | choices=indexed_dataset.get_available_dataset_impl()) 19 | parser.add_argument('--dict', metavar='FP', help='dictionary containing known words', default=None) 20 | parser.add_argument('--input', metavar='FP', required=True, help='binarized file to read') 21 | # fmt: on 22 | 23 | return parser 24 | 25 | 26 | def main(): 27 | parser = get_parser() 28 | args = parser.parse_args() 29 | 30 | dictionary = Dictionary.load(args.dict) if args.dict is not None else None 31 | dataset = data_utils.load_indexed_dataset( 32 | args.input, 33 | dictionary, 34 | dataset_impl=args.dataset_impl, 35 | default="lazy", 36 | ) 37 | 38 | for tensor_line in dataset: 39 | if dictionary is None: 40 | line = " ".join([str(int(x)) for x in tensor_line]) 41 | else: 42 | line = dictionary.string(tensor_line) 43 | 44 | print(line) 45 | 46 | 47 | if __name__ == "__main__": 48 | main() 49 | -------------------------------------------------------------------------------- /scripts/sacrebleu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ $# -ne 4 ]; then 4 | echo "usage: $0 TESTSET SRCLANG TGTLANG GEN" 5 | exit 1 6 | fi 7 | 8 | TESTSET=$1 9 | SRCLANG=$2 10 | TGTLANG=$3 11 | 12 | GEN=$4 13 | 14 | if !
command -v sacremoses &> /dev/null 15 | then 16 | echo "sacremoses could not be found, please install with: pip install sacremoses" >&2 17 | exit 1 18 | fi 19 | 20 | grep ^H "$GEN" \ 21 | | sed 's/^H\-//' \ 22 | | sort -n -k 1 \ 23 | | cut -f 3 \ 24 | | sacremoses detokenize \ 25 | > "$GEN.sorted.detok" 26 | 27 | sacrebleu --test-set "$TESTSET" --language-pair "${SRCLANG}-${TGTLANG}" < "$GEN.sorted.detok" 28 | -------------------------------------------------------------------------------- /scripts/shard_docs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Split a large file into shards while respecting document boundaries. Documents 8 | should be separated by a single empty line. 9 | """ 10 | 11 | import argparse 12 | import contextlib 13 | 14 | 15 | def main(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("input") 18 | parser.add_argument("--num-shards", type=int) 19 | args = parser.parse_args() 20 | 21 | assert args.num_shards is not None and args.num_shards > 1 22 | 23 | with open(args.input, "r", encoding="utf-8") as h: 24 | with contextlib.ExitStack() as stack: 25 | outputs = [ 26 | stack.enter_context( 27 | open(args.input + ".shard" + str(i), "w", encoding="utf-8") 28 | ) 29 | for i in range(args.num_shards) 30 | ] 31 | 32 | doc = [] 33 | first_doc = [True] * args.num_shards 34 | 35 | def output_doc(i): 36 | if not first_doc[i]: 37 | outputs[i].write("\n") 38 | first_doc[i] = False 39 | for line in doc: 40 | outputs[i].write(line) 41 | doc.clear() 42 | 43 | num_docs = 0 44 | for line in h: 45 | if line.strip() == "": # empty line indicates new document 46 | output_doc(num_docs % args.num_shards) 47 | num_docs += 1 48 | else: 49 | doc.append(line) 50 | if doc: # skip the empty document a trailing empty line would leave behind 51 | output_doc(num_docs %
args.num_shards) 52 | 53 | 54 | if __name__ == "__main__": 55 | main() 56 | -------------------------------------------------------------------------------- /scripts/spm_decode.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | from __future__ import absolute_import, division, print_function, unicode_literals 9 | 10 | import argparse 11 | 12 | import sentencepiece as spm 13 | 14 | 15 | def main(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument( 18 | "--model", required=True, help="sentencepiece model to use for decoding" 19 | ) 20 | parser.add_argument("--input", required=True, help="input file to decode") 21 | parser.add_argument("--input_format", choices=["piece", "id"], default="piece") 22 | args = parser.parse_args() 23 | 24 | sp = spm.SentencePieceProcessor() 25 | sp.Load(args.model) 26 | 27 | if args.input_format == "piece": 28 | 29 | def decode(input): 30 | return "".join(sp.DecodePieces(input)) 31 | 32 | elif args.input_format == "id": 33 | 34 | def decode(input): 35 | return "".join(sp.DecodeIds(input)) 36 | 37 | else: 38 | raise NotImplementedError 39 | 40 | def tok2int(tok): 41 | # remap reference-side <unk> (represented as <<unk>>) to 0 42 | return int(tok) if tok != "<<unk>>" else 0 43 | 44 | with open(args.input, "r", encoding="utf-8") as h: 45 | for line in h: 46 | if args.input_format == "id": 47 | print(decode(list(map(tok2int, line.rstrip().split())))) 48 | elif args.input_format == "piece": 49 | print(decode(line.rstrip().split())) 50 | 51 | 52 | if __name__ == "__main__": 53 | main() 54 | -------------------------------------------------------------------------------- /scripts/spm_train.py:
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | from __future__ import absolute_import, division, print_function, unicode_literals 9 | 10 | import sys 11 | 12 | import sentencepiece as spm 13 | 14 | 15 | if __name__ == "__main__": 16 | spm.SentencePieceTrainer.Train(" ".join(sys.argv[1:]))  # forward all CLI args to SentencePiece as one space-joined flag string (args containing spaces get split) 17 | -------------------------------------------------------------------------------- /scripts/test_fsdp.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | rm -rf fsdp_dummy  # start from an empty save dir (used as --save-dir by both runs below) 3 | mkdir -p fsdp_dummy 4 | CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train /private/home/sshleifer/data-bin/stories_mmap \ 5 | --ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \ 6 | --cpu-offload --checkpoint-activations \ 7 | --task language_modeling --tokens-per-sample 256 --batch-size 8 \ 8 | --arch transformer_lm_gpt2_tiny \ 9 | --optimizer cpu_adam --adam-betas "(0.9,0.98)" \ 10 | --lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \ 11 | --max-update 5 --log-format json --log-interval 1 \ 12 | --save-interval-updates 5 --save-dir fsdp_dummy --disable-validation \ 13 | --restore-file x.pt "$@"  # extra script args go to this first run only; NOTE(review): x.pt presumably does not exist, so training starts from scratch — confirm 14 | 15 | # Now we try to load the checkpoint written to fsdp_dummy by the run above, this time on 2 GPUs instead of 4 (assumes fairseq-train resumes from the save dir's latest checkpoint by default) 16 | CUDA_VISIBLE_DEVICES=0,1 fairseq-train /private/home/sshleifer/data-bin/stories_mmap \ 17 | --ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \ 18 | --cpu-offload --checkpoint-activations \ 19 | --task language_modeling --tokens-per-sample 256 --batch-size 8 \ 20 | --arch transformer_lm_gpt2_tiny \ 21 | --optimizer cpu_adam --adam-betas "(0.9,0.98)" \ 22 | --lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \ 23 | --max-update 2 --log-format json
--log-interval 1 \ 24 | --save-interval-updates 2 --save-dir fsdp_dummy 25 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 127 3 | extend-ignore = E203, W503 4 | extend-exclude = fairseq/model_parallel/megatron 5 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Legacy entry point. Use fairseq_cli/train.py or fairseq-train instead. 8 | """ 9 | 10 | from fairseq_cli.train import cli_main 11 | 12 | 13 | if __name__ == "__main__": 14 | cli_main() 15 | --------------------------------------------------------------------------------