├── .DS_Store
├── README.md
├── arch.png
├── convert_fairseq_to_huggingface.py
├── docs
│   ├── Makefile
│   ├── _static
│   │   └── theme_overrides.css
│   ├── command_line_tools.rst
│   ├── conf.py
│   ├── criterions.rst
│   ├── data.rst
│   ├── docutils.conf
│   ├── getting_started.rst
│   ├── index.rst
│   ├── lr_scheduler.rst
│   ├── make.bat
│   ├── models.rst
│   ├── modules.rst
│   ├── optim.rst
│   ├── overview.rst
│   ├── requirements.txt
│   ├── tasks.rst
│   ├── tutorial_classifying_names.rst
│   └── tutorial_simple_lstm.rst
├── eval_lm.py
├── examples
│   ├── .DS_Store
│   ├── .gitignore
│   ├── __init__.py
│   └── roberta
│       ├── .DS_Store
│       ├── commonsense_qa
│       │   ├── README.md
│       │   ├── __init__.py
│       │   ├── commonsense_qa_task.py
│       │   └── download_cqa_data.sh
│       ├── multiprocessing_bpe_encoder.py
│       ├── train_base_to_base_plus.sh
│       └── wsc
│           ├── README.md
│           ├── __init__.py
│           ├── wsc_criterion.py
│           ├── wsc_task.py
│           └── wsc_utils.py
├── fairseq.egg-info
│   ├── PKG-INFO
│   ├── SOURCES.txt
│   ├── dependency_links.txt
│   ├── entry_points.txt
│   ├── not-zip-safe
│   ├── requires.txt
│   └── top_level.txt
├── fairseq
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── binarizer.cpython-37.pyc
│   │   ├── checkpoint_utils.cpython-37.pyc
│   │   ├── distributed_utils.cpython-37.pyc
│   │   ├── file_utils.cpython-37.pyc
│   │   ├── hub_utils.cpython-37.pyc
│   │   ├── iterative_refinement_generator.cpython-37.pyc
│   │   ├── legacy_distributed_data_parallel.cpython-37.pyc
│   │   ├── meters.cpython-37.pyc
│   │   ├── options.cpython-37.pyc
│   │   ├── pdb.cpython-37.pyc
│   │   ├── progress_bar.cpython-37.pyc
│   │   ├── registry.cpython-37.pyc
│   │   ├── search.cpython-37.pyc
│   │   ├── sequence_generator.cpython-37.pyc
│   │   ├── tokenizer.cpython-37.pyc
│   │   ├── trainer.cpython-37.pyc
│   │   └── utils.cpython-37.pyc
│   ├── binarizer.py
│   ├── bleu.py
│   ├── checkpoint_utils.py
│   ├── clib
│   │   ├── libbleu
│   │   │   ├── libbleu.cpp
│   │   │   └── module.cpp
│   │   └── libnat
│   │       └── edit_dist.cpp
│   ├── criterions
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── adaptive_loss.cpython-37.pyc
│   │   │   ├── binary_cross_entropy.cpython-37.pyc
│   │   │   ├── composite_loss.cpython-37.pyc
│   │   │   ├── cross_entropy.cpython-37.pyc
│   │   │   ├── fairseq_criterion.cpython-37.pyc
│   │   │   ├── label_smoothed_cross_entropy.cpython-37.pyc
│   │   │   ├── label_smoothed_cross_entropy_with_alignment.cpython-37.pyc
│   │   │   ├── legacy_masked_lm.cpython-37.pyc
│   │   │   ├── masked_lm.cpython-37.pyc
│   │   │   ├── masked_lm_distil.cpython-37.pyc
│   │   │   ├── masked_lm_distil_H_half.cpython-37.pyc
│   │   │   ├── nat_loss.cpython-37.pyc
│   │   │   ├── sentence_prediction.cpython-37.pyc
│   │   │   └── sentence_ranking.cpython-37.pyc
│   │   ├── adaptive_loss.py
│   │   ├── binary_cross_entropy.py
│   │   ├── composite_loss.py
│   │   ├── cross_entropy.py
│   │   ├── fairseq_criterion.py
│   │   ├── label_smoothed_cross_entropy.py
│   │   ├── label_smoothed_cross_entropy_with_alignment.py
│   │   ├── legacy_masked_lm.py
│   │   ├── masked_lm.py
│   │   ├── masked_lm_distil.py
│   │   ├── masked_lm_distil_H_half.py
│   │   ├── nat_loss.py
│   │   ├── sentence_prediction.py
│   │   └── sentence_ranking.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── append_token_dataset.cpython-37.pyc
│   │   │   ├── backtranslation_dataset.cpython-37.pyc
│   │   │   ├── base_wrapper_dataset.cpython-37.pyc
│   │   │   ├── colorize_dataset.cpython-37.pyc
│   │   │   ├── concat_dataset.cpython-37.pyc
│   │   │   ├── concat_sentences_dataset.cpython-37.pyc
│   │   │   ├── data_utils.cpython-37.pyc
│   │   │   ├── denoising_dataset.cpython-37.pyc
│   │   │   ├── dictionary.cpython-37.pyc
│   │   │   ├── fairseq_dataset.cpython-37.pyc
│   │   │   ├── id_dataset.cpython-37.pyc
│   │   │   ├── indexed_dataset.cpython-37.pyc
│   │   │   ├── iterators.cpython-37.pyc
│   │   │   ├── language_pair_dataset.cpython-37.pyc
│   │   │   ├── list_dataset.cpython-37.pyc
│   │   │   ├── lm_context_window_dataset.cpython-37.pyc
│   │   │   ├── lru_cache_dataset.cpython-37.pyc
│   │   │   ├── mask_tokens_dataset.cpython-37.pyc
│   │   │   ├── monolingual_dataset.cpython-37.pyc
│   │   │   ├── multi_corpus_sampled_dataset.cpython-37.pyc
│   │   │   ├── nested_dictionary_dataset.cpython-37.pyc
│   │   │   ├── noising.cpython-37.pyc
│   │   │   ├── num_samples_dataset.cpython-37.pyc
│   │   │   ├── numel_dataset.cpython-37.pyc
│   │   │   ├── offset_tokens_dataset.cpython-37.pyc
│   │   │   ├── pad_dataset.cpython-37.pyc
│   │   │   ├── plasma_utils.cpython-37.pyc
│   │   │   ├── prepend_dataset.cpython-37.pyc
│   │   │   ├── prepend_token_dataset.cpython-37.pyc
│   │   │   ├── raw_label_dataset.cpython-37.pyc
│   │   │   ├── replace_dataset.cpython-37.pyc
│   │   │   ├── resampling_dataset.cpython-37.pyc
│   │   │   ├── roll_dataset.cpython-37.pyc
│   │   │   ├── round_robin_zip_datasets.cpython-37.pyc
│   │   │   ├── sharded_dataset.cpython-37.pyc
│   │   │   ├── sort_dataset.cpython-37.pyc
│   │   │   ├── strip_token_dataset.cpython-37.pyc
│   │   │   ├── subsample_dataset.cpython-37.pyc
│   │   │   ├── token_block_dataset.cpython-37.pyc
│   │   │   ├── transform_eos_dataset.cpython-37.pyc
│   │   │   ├── transform_eos_lang_pair_dataset.cpython-37.pyc
│   │   │   └── truncate_dataset.cpython-37.pyc
│   │   ├── append_token_dataset.py
│   │   ├── audio
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   └── raw_audio_dataset.cpython-37.pyc
│   │   │   └── raw_audio_dataset.py
│   │   ├── backtranslation_dataset.py
│   │   ├── base_wrapper_dataset.py
│   │   ├── colorize_dataset.py
│   │   ├── concat_dataset.py
│   │   ├── concat_sentences_dataset.py
│   │   ├── data_utils.py
│   │   ├── data_utils_fast.cpp
│   │   ├── data_utils_fast.cpython-37m-x86_64-linux-gnu.so
│   │   ├── data_utils_fast.pyx
│   │   ├── denoising_dataset.py
│   │   ├── dictionary.py
│   │   ├── encoders
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   ├── fastbpe.cpython-37.pyc
│   │   │   │   ├── gpt2_bpe.cpython-37.pyc
│   │   │   │   ├── gpt2_bpe_utils.cpython-37.pyc
│   │   │   │   ├── hf_bert_bpe.cpython-37.pyc
│   │   │   │   ├── moses_tokenizer.cpython-37.pyc
│   │   │   │   ├── nltk_tokenizer.cpython-37.pyc
│   │   │   │   ├── sentencepiece_bpe.cpython-37.pyc
│   │   │   │   ├── space_tokenizer.cpython-37.pyc
│   │   │   │   ├── subword_nmt_bpe.cpython-37.pyc
│   │   │   │   └── utils.cpython-37.pyc
│   │   │   ├── fastbpe.py
│   │   │   ├── gpt2_bpe.py
│   │   │   ├── gpt2_bpe_utils.py
│   │   │   ├── hf_bert_bpe.py
│   │   │   ├── moses_tokenizer.py
│   │   │   ├── nltk_tokenizer.py
│   │   │   ├── sentencepiece_bpe.py
│   │   │   ├── space_tokenizer.py
│   │   │   ├── subword_nmt_bpe.py
│   │   │   └── utils.py
│   │   ├── fairseq_dataset.py
│   │   ├── id_dataset.py
│   │   ├── indexed_dataset.py
│   │   ├── iterators.py
│   │   ├── language_pair_dataset.py
│   │   ├── legacy
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   ├── block_pair_dataset.cpython-37.pyc
│   │   │   │   ├── masked_lm_dataset.cpython-37.pyc
│   │   │   │   └── masked_lm_dictionary.cpython-37.pyc
│   │   │   ├── block_pair_dataset.py
│   │   │   ├── masked_lm_dataset.py
│   │   │   └── masked_lm_dictionary.py
│   │   ├── list_dataset.py
│   │   ├── lm_context_window_dataset.py
│   │   ├── lru_cache_dataset.py
│   │   ├── mask_tokens_dataset.py
│   │   ├── monolingual_dataset.py
│   │   ├── multi_corpus_sampled_dataset.py
│   │   ├── nested_dictionary_dataset.py
│   │   ├── noising.py
│   │   ├── num_samples_dataset.py
│   │   ├── numel_dataset.py
│   │   ├── offset_tokens_dataset.py
│   │   ├── pad_dataset.py
│   │   ├── plasma_utils.py
│   │   ├── prepend_dataset.py
│   │   ├── prepend_token_dataset.py
│   │   ├── raw_label_dataset.py
│   │   ├── replace_dataset.py
│   │   ├── resampling_dataset.py
│   │   ├── roll_dataset.py
│   │   ├── round_robin_zip_datasets.py
│   │   ├── sharded_dataset.py
│   │   ├── sort_dataset.py
│   │   ├── strip_token_dataset.py
│   │   ├── subsample_dataset.py
│   │   ├── token_block_dataset.py
│   │   ├── token_block_utils_fast.cpp
│   │   ├── token_block_utils_fast.cpython-37m-x86_64-linux-gnu.so
│   │   ├── token_block_utils_fast.pyx
│   │   ├── transform_eos_dataset.py
│   │   ├── transform_eos_lang_pair_dataset.py
│   │   └── truncate_dataset.py
│   ├── distributed_utils.py
│   ├── file_utils.py
│   ├── hub_utils.py
│   ├── iterative_refinement_generator.py
│   ├── legacy_distributed_data_parallel.py
│   ├── libbleu.cpython-37m-x86_64-linux-gnu.so
│   ├── libnat.cpython-37m-x86_64-linux-gnu.so
│   ├── meters.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── cmlm_transformer.cpython-37.pyc
│   │   │   ├── composite_encoder.cpython-37.pyc
│   │   │   ├── distributed_fairseq_model.cpython-37.pyc
│   │   │   ├── fairseq_decoder.cpython-37.pyc
│   │   │   ├── fairseq_encoder.cpython-37.pyc
│   │   │   ├── fairseq_incremental_decoder.cpython-37.pyc
│   │   │   ├── fairseq_model.cpython-37.pyc
│   │   │   ├── fconv.cpython-37.pyc
│   │   │   ├── fconv_lm.cpython-37.pyc
│   │   │   ├── fconv_self_att.cpython-37.pyc
│   │   │   ├── insertion_transformer.cpython-37.pyc
│   │   │   ├── iterative_nonautoregressive_transformer.cpython-37.pyc
│   │   │   ├── levenshtein_transformer.cpython-37.pyc
│   │   │   ├── lightconv.cpython-37.pyc
│   │   │   ├── lightconv_lm.cpython-37.pyc
│   │   │   ├── lstm.cpython-37.pyc
│   │   │   ├── masked_lm.cpython-37.pyc
│   │   │   ├── model_utils.cpython-37.pyc
│   │   │   ├── multilingual_transformer.cpython-37.pyc
│   │   │   ├── nonautoregressive_ensembles.cpython-37.pyc
│   │   │   ├── nonautoregressive_transformer.cpython-37.pyc
│   │   │   ├── transformer.cpython-37.pyc
│   │   │   ├── transformer_from_pretrained_xlm.cpython-37.pyc
│   │   │   ├── transformer_lm.cpython-37.pyc
│   │   │   └── wav2vec.cpython-37.pyc
│   │   ├── bart
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   ├── hub_interface.cpython-37.pyc
│   │   │   │   └── model.cpython-37.pyc
│   │   │   ├── hub_interface.py
│   │   │   └── model.py
│   │   ├── cmlm_transformer.py
│   │   ├── composite_encoder.py
│   │   ├── distributed_fairseq_model.py
│   │   ├── fairseq_decoder.py
│   │   ├── fairseq_encoder.py
│   │   ├── fairseq_incremental_decoder.py
│   │   ├── fairseq_model.py
│   │   ├── fconv.py
│   │   ├── fconv_lm.py
│   │   ├── fconv_self_att.py
│   │   ├── insertion_transformer.py
│   │   ├── iterative_nonautoregressive_transformer.py
│   │   ├── levenshtein_transformer.py
│   │   ├── lightconv.py
│   │   ├── lightconv_lm.py
│   │   ├── lstm.py
│   │   ├── masked_lm.py
│   │   ├── model_utils.py
│   │   ├── multilingual_transformer.py
│   │   ├── nonautoregressive_ensembles.py
│   │   ├── nonautoregressive_transformer.py
│   │   ├── roberta
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   ├── hub_interface.cpython-37.pyc
│   │   │   │   └── model.cpython-37.pyc
│   │   │   ├── alignment_utils.py
│   │   │   ├── hub_interface.py
│   │   │   └── model.py
│   │   ├── transformer.py
│   │   ├── transformer_from_pretrained_xlm.py
│   │   ├── transformer_lm.py
│   │   └── wav2vec.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── adaptive_input.cpython-37.pyc
│   │   │   ├── adaptive_softmax.cpython-37.pyc
│   │   │   ├── beamable_mm.cpython-37.pyc
│   │   │   ├── character_token_embedder.cpython-37.pyc
│   │   │   ├── conv_tbc.cpython-37.pyc
│   │   │   ├── downsampled_multihead_attention.cpython-37.pyc
│   │   │   ├── dynamic_convolution.cpython-37.pyc
│   │   │   ├── gelu.cpython-37.pyc
│   │   │   ├── grad_multiply.cpython-37.pyc
│   │   │   ├── highway.cpython-37.pyc
│   │   │   ├── layer_norm.cpython-37.pyc
│   │   │   ├── learned_positional_embedding.cpython-37.pyc
│   │   │   ├── lightweight_convolution.cpython-37.pyc
│   │   │   ├── linearized_convolution.cpython-37.pyc
│   │   │   ├── logsumexp_moe.cpython-37.pyc
│   │   │   ├── mean_pool_gating_network.cpython-37.pyc
│   │   │   ├── multihead_attention.cpython-37.pyc
│   │   │   ├── positional_embedding.cpython-37.pyc
│   │   │   ├── scalar_bias.cpython-37.pyc
│   │   │   ├── sinusoidal_positional_embedding.cpython-37.pyc
│   │   │   ├── transformer_layer.cpython-37.pyc
│   │   │   ├── transformer_sentence_encoder.cpython-37.pyc
│   │   │   ├── transformer_sentence_encoder_layer.cpython-37.pyc
│   │   │   ├── unfold.cpython-37.pyc
│   │   │   └── vggblock.cpython-37.pyc
│   │   ├── adaptive_input.py
│   │   ├── adaptive_softmax.py
│   │   ├── beamable_mm.py
│   │   ├── character_token_embedder.py
│   │   ├── conv_tbc.py
│   │   ├── cuda_utils.cu
│   │   ├── downsampled_multihead_attention.py
│   │   ├── dynamic_convolution.py
│   │   ├── dynamicconv_layer
│   │   │   ├── __init__.py
│   │   │   ├── cuda_function_gen.py
│   │   │   ├── dynamicconv_cuda.cpp
│   │   │   ├── dynamicconv_cuda.cuh
│   │   │   ├── dynamicconv_cuda_kernel.cu
│   │   │   ├── dynamicconv_layer.py
│   │   │   ├── dynamiconv_cpu.cpp
│   │   │   └── setup.py
│   │   ├── gelu.py
│   │   ├── grad_multiply.py
│   │   ├── highway.py
│   │   ├── layer_norm.py
│   │   ├── learned_positional_embedding.py
│   │   ├── lightconv_layer
│   │   │   ├── __init__.py
│   │   │   ├── cuda_function_gen.py
│   │   │   ├── lightconv_cuda.cpp
│   │   │   ├── lightconv_cuda.cuh
│   │   │   ├── lightconv_cuda_kernel.cu
│   │   │   ├── lightconv_layer.py
│   │   │   └── setup.py
│   │   ├── lightweight_convolution.py
│   │   ├── linearized_convolution.py
│   │   ├── logsumexp_moe.py
│   │   ├── mean_pool_gating_network.py
│   │   ├── multihead_attention.py
│   │   ├── positional_embedding.py
│   │   ├── scalar_bias.py
│   │   ├── sinusoidal_positional_embedding.py
│   │   ├── sparse_multihead_attention.py
│   │   ├── sparse_transformer_sentence_encoder.py
│   │   ├── sparse_transformer_sentence_encoder_layer.py
│   │   ├── transformer_layer.py
│   │   ├── transformer_sentence_encoder.py
│   │   ├── transformer_sentence_encoder_layer.py
│   │   ├── unfold.py
│   │   └── vggblock.py
│   ├── optim
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── adadelta.cpython-37.pyc
│   │   │   ├── adafactor.cpython-37.pyc
│   │   │   ├── adagrad.cpython-37.pyc
│   │   │   ├── adam.cpython-37.pyc
│   │   │   ├── adamax.cpython-37.pyc
│   │   │   ├── bmuf.cpython-37.pyc
│   │   │   ├── fairseq_optimizer.cpython-37.pyc
│   │   │   ├── fp16_optimizer.cpython-37.pyc
│   │   │   ├── nag.cpython-37.pyc
│   │   │   └── sgd.cpython-37.pyc
│   │   ├── adadelta.py
│   │   ├── adafactor.py
│   │   ├── adagrad.py
│   │   ├── adam.py
│   │   ├── adamax.py
│   │   ├── bmuf.py
│   │   ├── fairseq_optimizer.py
│   │   ├── fp16_optimizer.py
│   │   ├── lr_scheduler
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   ├── cosine_lr_scheduler.cpython-37.pyc
│   │   │   │   ├── fairseq_lr_scheduler.cpython-37.pyc
│   │   │   │   ├── fixed_schedule.cpython-37.pyc
│   │   │   │   ├── inverse_square_root_schedule.cpython-37.pyc
│   │   │   │   ├── polynomial_decay_schedule.cpython-37.pyc
│   │   │   │   ├── reduce_lr_on_plateau.cpython-37.pyc
│   │   │   │   ├── tri_stage_lr_scheduler.cpython-37.pyc
│   │   │   │   └── triangular_lr_scheduler.cpython-37.pyc
│   │   │   ├── cosine_lr_scheduler.py
│   │   │   ├── fairseq_lr_scheduler.py
│   │   │   ├── fixed_schedule.py
│   │   │   ├── inverse_square_root_schedule.py
│   │   │   ├── polynomial_decay_schedule.py
│   │   │   ├── reduce_lr_on_plateau.py
│   │   │   ├── tri_stage_lr_scheduler.py
│   │   │   └── triangular_lr_scheduler.py
│   │   ├── nag.py
│   │   └── sgd.py
│   ├── options.py
│   ├── pdb.py
│   ├── progress_bar.py
│   ├── registry.py
│   ├── search.py
│   ├── sequence_generator.py
│   ├── sequence_scorer.py
│   ├── tasks
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── audio_pretraining.cpython-37.pyc
│   │   │   ├── back_distil.cpython-37.pyc
│   │   │   ├── cross_lingual_lm.cpython-37.pyc
│   │   │   ├── denoising.cpython-37.pyc
│   │   │   ├── fairseq_task.cpython-37.pyc
│   │   │   ├── language_modeling.cpython-37.pyc
│   │   │   ├── legacy_masked_lm.cpython-37.pyc
│   │   │   ├── masked_lm.cpython-37.pyc
│   │   │   ├── multilingual_masked_lm.cpython-37.pyc
│   │   │   ├── multilingual_translation.cpython-37.pyc
│   │   │   ├── semisupervised_translation.cpython-37.pyc
│   │   │   ├── sentence_prediction.cpython-37.pyc
│   │   │   ├── sentence_ranking.cpython-37.pyc
│   │   │   ├── translation.cpython-37.pyc
│   │   │   ├── translation_from_pretrained_xlm.cpython-37.pyc
│   │   │   ├── translation_lev.cpython-37.pyc
│   │   │   └── translation_moe.cpython-37.pyc
│   │   ├── audio_pretraining.py
│   │   ├── back_distil.py
│   │   ├── cross_lingual_lm.py
│   │   ├── denoising.py
│   │   ├── fairseq_task.py
│   │   ├── language_modeling.py
│   │   ├── legacy_masked_lm.py
│   │   ├── masked_lm.py
│   │   ├── multilingual_masked_lm.py
│   │   ├── multilingual_translation.py
│   │   ├── semisupervised_translation.py
│   │   ├── sentence_prediction.py
│   │   ├── sentence_ranking.py
│   │   ├── translation.py
│   │   ├── translation_from_pretrained_xlm.py
│   │   ├── translation_lev.py
│   │   └── translation_moe.py
│   ├── tokenizer.py
│   ├── trainer.py
│   └── utils.py
├── fairseq_cli
│   ├── __init__.py
│   ├── eval_lm.py
│   ├── generate.py
│   ├── interactive.py
│   ├── preprocess.py
│   ├── score.py
│   ├── setup.py
│   └── train.py
├── generate.py
├── hubconf.py
├── interactive.py
├── preprocess.py
├── score.py
├── scripts
│   ├── __init__.py
│   ├── average_checkpoints.py
│   ├── build_sym_alignment.py
│   ├── compare_namespaces.py
│   ├── compound_split_bleu.sh
│   ├── convert_dictionary.lua
│   ├── convert_model.lua
│   ├── count_docs.py
│   ├── read_binarized.py
│   ├── rm_pt.py
│   ├── sacrebleu_pregen.sh
│   ├── shard_docs.py
│   ├── split_train_valid_docs.py
│   ├── spm_decode.py
│   ├── spm_encode.py
│   ├── spm_train.py
│   ├── wav2vec_featurize.py
│   └── wav2vec_manifest.py
├── setup.py
├── tests
│   ├── __init__.py
│   ├── speech_recognition
│   │   ├── __init__.py
│   │   ├── asr_test_base.py
│   │   ├── test_collaters.py
│   │   ├── test_cross_entropy.py
│   │   └── test_vggtransformer.py
│   ├── test_average_checkpoints.py
│   ├── test_backtranslation_dataset.py
│   ├── test_binaries.py
│   ├── test_bmuf.py
│   ├── test_character_token_embedder.py
│   ├── test_concat_dataset.py
│   ├── test_convtbc.py
│   ├── test_dictionary.py
│   ├── test_iterators.py
│   ├── test_label_smoothing.py
│   ├── test_memory_efficient_fp16.py
│   ├── test_multi_corpus_sampled_dataset.py
│   ├── test_multihead_attention.py
│   ├── test_noising.py
│   ├── test_reproducibility.py
│   ├── test_resampling_dataset.py
│   ├── test_sequence_generator.py
│   ├── test_sequence_scorer.py
│   ├── test_sparse_multihead_attention.py
│   ├── test_token_block_dataset.py
│   ├── test_train.py
│   ├── test_utils.py
│   └── utils.py
├── train.py
└── validate.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/.DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Knowledge-Inheritance
2 |
3 | Source code for our NAACL 2022 paper: Knowledge Inheritance for Pre-trained Language Models.
4 |
5 | The trained model parameters (in [Fairseq](https://github.com/pytorch/fairseq) format) can be downloaded from [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/d/aab1777a161545038c01/). Please follow [ELLE](https://github.com/thunlp/ELLE) to convert the trained checkpoint from Fairseq format into Huggingface [transformers](https://github.com/huggingface/transformers) format.
6 |
 7 | We also provide the pre-training data we use (already processed in fairseq format) on [Google Drive](https://drive.google.com/drive/folders/1l1cuN9JQUqZTM_1NFNtetfiXMKWqGTUo?usp=sharing), covering five pre-training domains (WB, News, Reviews, BIO and CS). We sample around 3,400M tokens for each domain.
8 |
 9 | For downstream performance evaluation, we refer to the implementation of [Fairseq](https://github.com/pytorch/fairseq) (GLUE tasks) and [Don't Stop Pre-training](https://github.com/allenai/dont-stop-pretraining) (ACL-ARC / CHEMPROT). For ACL-ARC / CHEMPROT, please refer to [ELLE](https://github.com/thunlp/ELLE) for an easy implementation.
10 |
11 | If you have any questions, feel free to contact me by email (yujiaqin16@gmail.com).
12 |
13 | ## Installation
14 |
15 | ``` bash
16 | git clone https://github.com/pytorch/fairseq
17 | cd fairseq
18 | pip install --editable ./
19 |
20 | git clone https://github.com/NVIDIA/apex
21 | cd apex
22 | pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \
23 | --global-option="--deprecated_fused_adam" --global-option="--xentropy" \
24 | --global-option="--fast_multihead_attn" ./
25 | ```
26 |
27 | ## Pre-training under KI
28 |
29 | ``` bash
30 | cd examples/roberta
31 | bash train_base_to_base_plus.sh
32 | ```
33 |
34 | ## Downstream evaluation
35 |
36 | For downstream evaluation: (1) GLUE: we refer to the implementation of [Fairseq](https://github.com/pytorch/fairseq); (2) ACL-ARC & CHEMPROT: first use `convert_fairseq_to_huggingface.py` to convert the checkpoint from Fairseq format into Huggingface's [transformers](https://github.com/huggingface/transformers) format, then evaluate it using the implementation of [Don't Stop Pre-training](https://github.com/allenai/dont-stop-pretraining).
37 |
--------------------------------------------------------------------------------
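
The released checkpoints can be sanity-checked directly in fairseq before any conversion. A minimal sketch, assuming a downloaded checkpoint directory and the binarized data directory used by the training script below (both paths are placeholders):

``` python
# Load a released fairseq-format checkpoint with fairseq's RoBERTa hub
# interface; this repository must be on PYTHONPATH so its custom
# architectures (e.g. roberta_base_plus) are registered.
from fairseq.models.roberta import RobertaModel

roberta = RobertaModel.from_pretrained(
    'checkpoint_base_to_base_plus',           # placeholder checkpoint directory
    checkpoint_file='checkpoint_last.pt',
    data_name_or_path='data-bin/corpus_all',  # binarized data dir with dict.txt
)
roberta.eval()

tokens = roberta.encode('Hello world!')
features = roberta.extract_features(tokens)   # (1, seq_len, hidden_size)
print(features.shape)
```
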
/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/arch.png
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = python -msphinx
7 | SPHINXPROJ = fairseq
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/docs/_static/theme_overrides.css:
--------------------------------------------------------------------------------
1 | .wy-table-responsive table td kbd {
2 | white-space: nowrap;
3 | }
4 | .wy-table-responsive table td {
5 | white-space: normal !important;
6 | }
7 | .wy-table-responsive {
8 | overflow: visible !important;
9 | }
10 |
--------------------------------------------------------------------------------
/docs/command_line_tools.rst:
--------------------------------------------------------------------------------
1 | .. _Command-line Tools:
2 |
3 | Command-line Tools
4 | ==================
5 |
6 | Fairseq provides several command-line tools for training and evaluating models:
7 |
8 | - :ref:`fairseq-preprocess`: Data pre-processing: build vocabularies and binarize training data
9 | - :ref:`fairseq-train`: Train a new model on one or multiple GPUs
10 | - :ref:`fairseq-generate`: Translate pre-processed data with a trained model
11 | - :ref:`fairseq-interactive`: Translate raw text with a trained model
12 | - :ref:`fairseq-score`: BLEU scoring of generated translations against reference translations
13 | - :ref:`fairseq-eval-lm`: Language model evaluation
14 |
15 |
16 | .. _fairseq-preprocess:
17 |
18 | fairseq-preprocess
19 | ~~~~~~~~~~~~~~~~~~
20 | .. automodule:: preprocess
21 |
22 | .. argparse::
23 | :module: fairseq.options
24 | :func: get_preprocessing_parser
25 | :prog: fairseq-preprocess
26 |
27 |
28 | .. _fairseq-train:
29 |
30 | fairseq-train
31 | ~~~~~~~~~~~~~
32 | .. automodule:: train
33 |
34 | .. argparse::
35 | :module: fairseq.options
36 | :func: get_training_parser
37 | :prog: fairseq-train
38 |
39 |
40 | .. _fairseq-generate:
41 |
42 | fairseq-generate
43 | ~~~~~~~~~~~~~~~~
44 | .. automodule:: generate
45 |
46 | .. argparse::
47 | :module: fairseq.options
48 | :func: get_generation_parser
49 | :prog: fairseq-generate
50 |
51 |
52 | .. _fairseq-interactive:
53 |
54 | fairseq-interactive
55 | ~~~~~~~~~~~~~~~~~~~
56 | .. automodule:: interactive
57 |
58 | .. argparse::
59 | :module: fairseq.options
60 | :func: get_interactive_generation_parser
61 | :prog: fairseq-interactive
62 |
63 |
64 | .. _fairseq-score:
65 |
66 | fairseq-score
67 | ~~~~~~~~~~~~~
68 | .. automodule:: score
69 |
70 | .. argparse::
71 | :module: fairseq_cli.score
72 | :func: get_parser
73 | :prog: fairseq-score
74 |
75 |
76 | .. _fairseq-eval-lm:
77 |
78 | fairseq-eval-lm
79 | ~~~~~~~~~~~~~~~
80 | .. automodule:: eval_lm
81 |
82 | .. argparse::
83 | :module: fairseq.options
84 | :func: get_eval_lm_parser
85 | :prog: fairseq-eval-lm
86 |
--------------------------------------------------------------------------------
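
Taken together, these tools form the standard fairseq pipeline. A hedged end-to-end sketch (the language pair, paths, and hyper-parameters are placeholders, not recommendations):

``` bash
# Binarize a parallel corpus, train a Transformer, then decode the test set.
fairseq-preprocess --source-lang de --target-lang en \
    --trainpref data/train --validpref data/valid --testpref data/test \
    --destdir data-bin/de-en

fairseq-train data-bin/de-en \
    --arch transformer --optimizer adam --lr 0.0005 \
    --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 4096 --save-dir checkpoints/de-en

fairseq-generate data-bin/de-en \
    --path checkpoints/de-en/checkpoint_best.pt --beam 5 --remove-bpe
```
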
/docs/criterions.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | .. _Criterions:
5 |
6 | Criterions
7 | ==========
8 |
9 | Criterions compute the loss function given the model and batch, roughly::
10 |
11 | loss = criterion(model, batch)
12 |
13 | .. automodule:: fairseq.criterions
14 | :members:
15 |
16 | .. autoclass:: fairseq.criterions.FairseqCriterion
17 | :members:
18 | :undoc-members:
19 |
20 | .. autoclass:: fairseq.criterions.adaptive_loss.AdaptiveLoss
21 | :members:
22 | :undoc-members:
23 | .. autoclass:: fairseq.criterions.composite_loss.CompositeLoss
24 | :members:
25 | :undoc-members:
26 | .. autoclass:: fairseq.criterions.cross_entropy.CrossEntropyCriterion
27 | :members:
28 | :undoc-members:
29 | .. autoclass:: fairseq.criterions.label_smoothed_cross_entropy.LabelSmoothedCrossEntropyCriterion
30 | :members:
31 | :undoc-members:
32 |
--------------------------------------------------------------------------------
/docs/data.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | .. module:: fairseq.data
5 |
6 | Data Loading and Utilities
7 | ==========================
8 |
9 | .. _datasets:
10 |
11 | Datasets
12 | --------
13 |
14 | **Datasets** define the data format and provide helpers for creating
15 | mini-batches.
16 |
17 | .. autoclass:: fairseq.data.FairseqDataset
18 | :members:
19 | .. autoclass:: fairseq.data.LanguagePairDataset
20 | :members:
21 | .. autoclass:: fairseq.data.MonolingualDataset
22 | :members:
23 |
24 | **Helper Datasets**
25 |
26 | These datasets wrap other :class:`fairseq.data.FairseqDataset` instances and
27 | provide additional functionality:
28 |
29 | .. autoclass:: fairseq.data.BacktranslationDataset
30 | :members:
31 | .. autoclass:: fairseq.data.ConcatDataset
32 | :members:
33 | .. autoclass:: fairseq.data.ResamplingDataset
34 | :members:
35 | .. autoclass:: fairseq.data.RoundRobinZipDatasets
36 | :members:
37 | .. autoclass:: fairseq.data.TransformEosDataset
38 | :members:
39 |
40 |
41 | Dictionary
42 | ----------
43 |
44 | .. autoclass:: fairseq.data.Dictionary
45 | :members:
46 |
47 |
48 | Iterators
49 | ---------
50 |
51 | .. autoclass:: fairseq.data.CountingIterator
52 | :members:
53 | .. autoclass:: fairseq.data.EpochBatchIterator
54 | :members:
55 | .. autoclass:: fairseq.data.GroupedIterator
56 | :members:
57 | .. autoclass:: fairseq.data.ShardedIterator
58 | :members:
59 |
--------------------------------------------------------------------------------
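
A minimal sketch of the `Dictionary` workflow documented above: build a vocabulary, encode a line to a tensor of indices, and decode it back (the toy corpus is a placeholder):

``` python
from fairseq.data import Dictionary

d = Dictionary()
for word in 'the quick brown fox'.split():
    d.add_symbol(word)
d.finalize()  # sort symbols by frequency and fix their indices

ids = d.encode_line('the brown fox', add_if_not_exist=False)
print(ids)            # tensor of symbol indices, with </s> appended
print(d.string(ids))  # 'the brown fox'
```
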
/docs/docutils.conf:
--------------------------------------------------------------------------------
1 | [writers]
2 | option-limit=0
3 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. fairseq documentation master file, created by
2 | sphinx-quickstart on Fri Aug 17 21:45:30 2018.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | :github_url: https://github.com/pytorch/fairseq
7 |
8 |
9 | fairseq documentation
10 | =====================
11 |
12 | Fairseq is a sequence modeling toolkit written in `PyTorch
13 | <https://pytorch.org>`_ that allows researchers and developers to
14 | train custom models for translation, summarization, language modeling and other
15 | text generation tasks.
16 |
17 | .. toctree::
18 | :maxdepth: 1
19 | :caption: Getting Started
20 |
21 | getting_started
22 | command_line_tools
23 |
24 | .. toctree::
25 | :maxdepth: 1
26 | :caption: Extending Fairseq
27 |
28 | overview
29 | tutorial_simple_lstm
30 | tutorial_classifying_names
31 |
32 | .. toctree::
33 | :maxdepth: 2
34 | :caption: Library Reference
35 |
36 | tasks
37 | models
38 | criterions
39 | optim
40 | lr_scheduler
41 | data
42 | modules
43 |
44 |
45 | Indices and tables
46 | ==================
47 |
48 | * :ref:`genindex`
49 | * :ref:`search`
50 |
--------------------------------------------------------------------------------
/docs/lr_scheduler.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | .. _Learning Rate Schedulers:
5 |
6 | Learning Rate Schedulers
7 | ========================
8 |
9 | Learning Rate Schedulers update the learning rate over the course of training.
10 | Learning rates can be updated after each update via :func:`step_update` or at
11 | epoch boundaries via :func:`step`.
12 |
13 | .. automodule:: fairseq.optim.lr_scheduler
14 | :members:
15 |
16 | .. autoclass:: fairseq.optim.lr_scheduler.FairseqLRScheduler
17 | :members:
18 | :undoc-members:
19 |
20 | .. autoclass:: fairseq.optim.lr_scheduler.cosine_lr_scheduler.CosineSchedule
21 | :members:
22 | :undoc-members:
23 | .. autoclass:: fairseq.optim.lr_scheduler.fixed_schedule.FixedSchedule
24 | :members:
25 | :undoc-members:
26 | .. autoclass:: fairseq.optim.lr_scheduler.inverse_square_root_schedule.InverseSquareRootSchedule
27 | :members:
28 | :undoc-members:
29 | .. autoclass:: fairseq.optim.lr_scheduler.reduce_lr_on_plateau.ReduceLROnPlateau
30 | :members:
31 | :undoc-members:
32 | .. autoclass:: fairseq.optim.lr_scheduler.triangular_lr_scheduler.TriangularSchedule
33 | :members:
34 | :undoc-members:
35 |
--------------------------------------------------------------------------------
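
A minimal sketch of a custom schedule against the interface described above (`step_update()` after every optimizer update, `step()` at epoch boundaries); the scheduler name and decay rule are hypothetical:

``` python
from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


@register_lr_scheduler('halving')  # hypothetical scheduler name
class HalvingSchedule(FairseqLRScheduler):
    """Halve the learning rate every 10k updates (illustrative only)."""

    def __init__(self, args, optimizer):
        super().__init__(args, optimizer)
        self.peak_lr = args.lr[0]  # --lr is parsed as a list

    def step_update(self, num_updates):
        # called after every optimizer update
        self.optimizer.set_lr(self.peak_lr * 0.5 ** (num_updates // 10000))
        return self.optimizer.get_lr()
```
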
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=python -msphinx
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 | set SPHINXPROJ=fairseq
13 |
14 | if "%1" == "" goto help
15 |
16 | %SPHINXBUILD% >NUL 2>NUL
17 | if errorlevel 9009 (
18 | echo.
19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed,
20 | echo.then set the SPHINXBUILD environment variable to point to the full
21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the
22 | echo.Sphinx directory to PATH.
23 | echo.
24 | echo.If you don't have Sphinx installed, grab it from
25 | echo.http://sphinx-doc.org/
26 | exit /b 1
27 | )
28 |
29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
30 | goto end
31 |
32 | :help
33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
34 |
35 | :end
36 | popd
37 |
--------------------------------------------------------------------------------
/docs/modules.rst:
--------------------------------------------------------------------------------
1 | Modules
2 | =======
3 |
4 | Fairseq provides several stand-alone :class:`torch.nn.Module` classes that may
5 | be helpful when implementing a new :class:`~fairseq.models.BaseFairseqModel`.
6 |
7 | .. automodule:: fairseq.modules
8 | :members:
9 | :undoc-members:
10 |
--------------------------------------------------------------------------------
/docs/optim.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | .. _optimizers:
5 |
6 | Optimizers
7 | ==========
8 |
9 | Optimizers update the Model parameters based on the gradients.
10 |
11 | .. automodule:: fairseq.optim
12 | :members:
13 |
14 | .. autoclass:: fairseq.optim.FairseqOptimizer
15 | :members:
16 | :undoc-members:
17 |
18 | .. autoclass:: fairseq.optim.adadelta.Adadelta
19 | :members:
20 | :undoc-members:
21 | .. autoclass:: fairseq.optim.adagrad.Adagrad
22 | :members:
23 | :undoc-members:
24 | .. autoclass:: fairseq.optim.adafactor.FairseqAdafactor
25 | :members:
26 | :undoc-members:
27 | .. autoclass:: fairseq.optim.adam.FairseqAdam
28 | :members:
29 | :undoc-members:
30 | .. autoclass:: fairseq.optim.fp16_optimizer.FP16Optimizer
31 | :members:
32 | :undoc-members:
33 | .. autoclass:: fairseq.optim.nag.FairseqNAG
34 | :members:
35 | :undoc-members:
36 | .. autoclass:: fairseq.optim.sgd.SGD
37 | :members:
38 | :undoc-members:
39 |
--------------------------------------------------------------------------------
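
New optimizers follow the same wrapping pattern as the built-ins above: subclass `FairseqOptimizer`, hold a `torch.optim` optimizer in `_optimizer`, and expose its constructor kwargs via `optimizer_config`. A sketch mirroring the pattern of `fairseq.optim.sgd` (the registered name is hypothetical):

``` python
import torch.optim

from fairseq.optim import FairseqOptimizer, register_optimizer


@register_optimizer('my_sgd')  # hypothetical optimizer name
class MySGD(FairseqOptimizer):

    def __init__(self, args, params):
        super().__init__(args)
        self._optimizer = torch.optim.SGD(params, **self.optimizer_config)

    @property
    def optimizer_config(self):
        # kwargs forwarded to torch.optim.SGD; --lr is parsed as a list
        return {
            'lr': self.args.lr[0],
            'momentum': getattr(self.args, 'momentum', 0.0),
            'weight_decay': getattr(self.args, 'weight_decay', 0.0),
        }
```
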
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx<2.0
2 | sphinx-argparse
3 |
--------------------------------------------------------------------------------
/docs/tasks.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | .. module:: fairseq.tasks
5 |
6 | .. _Tasks:
7 |
8 | Tasks
9 | =====
10 |
11 | Tasks store dictionaries and provide helpers for loading/iterating over
12 | Datasets, initializing the Model/Criterion and calculating the loss.
13 |
14 | Tasks can be selected via the ``--task`` command-line argument. Once selected, a
15 | task may expose additional command-line arguments for further configuration.
16 |
17 | Example usage::
18 |
19 | # setup the task (e.g., load dictionaries)
20 | task = fairseq.tasks.setup_task(args)
21 |
22 | # build model and criterion
23 | model = task.build_model(args)
24 | criterion = task.build_criterion(args)
25 |
26 | # load datasets
27 | task.load_dataset('train')
28 | task.load_dataset('valid')
29 |
30 | # iterate over mini-batches of data
31 | batch_itr = task.get_batch_iterator(
32 | task.dataset('train'), max_tokens=4096,
33 | )
34 | for batch in batch_itr:
35 | # compute the loss
36 | loss, sample_size, logging_output = task.get_loss(
37 | model, criterion, batch,
38 | )
39 | loss.backward()
40 |
41 |
42 | Translation
43 | -----------
44 |
45 | .. autoclass:: fairseq.tasks.translation.TranslationTask
46 |
47 | .. _language modeling:
48 |
49 | Language Modeling
50 | -----------------
51 |
52 | .. autoclass:: fairseq.tasks.language_modeling.LanguageModelingTask
53 |
54 |
55 | Adding new tasks
56 | ----------------
57 |
58 | .. autofunction:: fairseq.tasks.register_task
59 | .. autoclass:: fairseq.tasks.FairseqTask
60 | :members:
61 | :undoc-members:
62 |
--------------------------------------------------------------------------------
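
To complement the reference above, a minimal sketch of registering a new task; the task name is hypothetical and the dictionary/dataset loading is elided:

``` python
from fairseq.tasks import FairseqTask, register_task


@register_task('toy_task')  # hypothetical task name
class ToyTask(FairseqTask):

    @staticmethod
    def add_args(parser):
        # task-specific command-line arguments
        parser.add_argument('data', help='path to the data directory')

    @classmethod
    def setup_task(cls, args, **kwargs):
        # typically: load dictionaries from args.data before instantiating
        return cls(args)

    def load_dataset(self, split, **kwargs):
        # typically: populate self.datasets[split] with a FairseqDataset
        raise NotImplementedError
```
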
/examples/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/examples/.DS_Store
--------------------------------------------------------------------------------
/examples/.gitignore:
--------------------------------------------------------------------------------
1 | !*/*.sh
2 | !*/*.md
3 |
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | __version__ = '0.9.0'
7 |
8 | import examples.noisychannel # noqa
9 |
--------------------------------------------------------------------------------
/examples/roberta/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/examples/roberta/.DS_Store
--------------------------------------------------------------------------------
/examples/roberta/commonsense_qa/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from . import commonsense_qa_task # noqa
7 |
--------------------------------------------------------------------------------
/examples/roberta/commonsense_qa/download_cqa_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | OUTDIR=data/CommonsenseQA
8 |
9 | mkdir -p $OUTDIR
10 |
11 | wget -O $OUTDIR/train.jsonl https://s3.amazonaws.com/commensenseqa/train_rand_split.jsonl
12 | wget -O $OUTDIR/valid.jsonl https://s3.amazonaws.com/commensenseqa/dev_rand_split.jsonl
13 | wget -O $OUTDIR/test.jsonl https://s3.amazonaws.com/commensenseqa/test_rand_split_no_answers.jsonl
14 | wget -O $OUTDIR/dict.txt https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt
15 |
--------------------------------------------------------------------------------
/examples/roberta/train_base_to_base_plus.sh:
--------------------------------------------------------------------------------
1 | TOTAL_UPDATES=125000 # Total number of training steps
2 | TOTAL_UPDATES_DISTIL=55000 # Total number of distillation training steps
3 | WARMUP_UPDATES=11000 # Warmup the learning rate over this many updates
4 | PEAK_LR=0.00035 # Peak learning rate, adjust as needed
5 | TOKENS_PER_SAMPLE=512 # Max sequence length
6 | MAX_POSITIONS=512 # Num. positional embeddings (usually same as above)
7 | MAX_SENTENCES=16 # Number of sequences per batch (batch size)
8 | UPDATE_FREQ=16 # Increase the batch size 16x
9 | arch=roberta_base_plus
10 | arch_distil_from=roberta_base
11 | restore_file_distil_from=***your-teacher-model-path***
12 | restore_file_checkpoint_distil_from=checkpoint_last.pt
13 | logdir=log_base_to_base_plus
14 | save_dir=checkpoint_base_to_base_plus
15 | DATA_DIR=data-bin/corpus_all
16 |
17 | python ../../fairseq_cli/train.py --fp16 $DATA_DIR \
18 | --task back_distil --criterion masked_lm_distil \
19 | --arch $arch --sample-break-mode complete --tokens-per-sample $TOKENS_PER_SAMPLE \
20 | --optimizer adam --adam-betas '(0.9,0.98)' --adam-eps 1e-6 --clip-norm 0.0 \
21 | --lr-scheduler polynomial_decay --lr $PEAK_LR --warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_UPDATES \
22 | --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
23 | --batch-size $MAX_SENTENCES --update-freq $UPDATE_FREQ \
24 | --max-update $TOTAL_UPDATES --log-format json --log-interval 100 \
25 | --max-update-distil $TOTAL_UPDATES_DISTIL \
26 | --tensorboard-logdir $logdir \
27 | --skip-invalid-size-inputs-valid-test \
28 | --save-dir $save_dir \
29 | --fixed-validation-seed 0 \
30 | --ddp-backend no_c10d \
31 | --arch_distil_from $arch_distil_from \
32 | --restore-file-distil-from $restore_file_distil_from \
33 | --restore-file-checkpoint-distil-from $restore_file_checkpoint_distil_from \
34 | --temperature_distil 2 \
35 | --restrict_ce_to_mask \
36 | --save-interval-updates 2500
37 |
--------------------------------------------------------------------------------
/examples/roberta/wsc/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from . import wsc_criterion # noqa
7 | from . import wsc_task # noqa
8 |
--------------------------------------------------------------------------------
/fairseq.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/fairseq.egg-info/entry_points.txt:
--------------------------------------------------------------------------------
1 | [console_scripts]
2 | fairseq-eval-lm = fairseq_cli.eval_lm:cli_main
3 | fairseq-generate = fairseq_cli.generate:cli_main
4 | fairseq-interactive = fairseq_cli.interactive:cli_main
5 | fairseq-preprocess = fairseq_cli.preprocess:cli_main
6 | fairseq-score = fairseq_cli.score:main
7 | fairseq-train = fairseq_cli.train:cli_main
8 | fairseq-validate = fairseq_cli.validate:cli_main
9 |
10 |
--------------------------------------------------------------------------------
/fairseq.egg-info/not-zip-safe:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/fairseq.egg-info/requires.txt:
--------------------------------------------------------------------------------
1 | cffi
2 | cython
3 | numpy
4 | regex
5 | sacrebleu
6 | torch
7 | tqdm
8 |
--------------------------------------------------------------------------------
/fairseq.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | examples
2 | fairseq
3 | fairseq_cli
4 | tests
5 |
--------------------------------------------------------------------------------
/fairseq/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | __all__ = ['pdb']
7 | __version__ = '0.9.0'
8 |
9 | import fairseq.criterions # noqa
10 | import fairseq.models # noqa
11 | import fairseq.modules # noqa
12 | import fairseq.optim # noqa
13 | import fairseq.optim.lr_scheduler # noqa
14 | import fairseq.pdb # noqa
15 | import fairseq.tasks # noqa
16 |
--------------------------------------------------------------------------------
/fairseq/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/__pycache__/binarizer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/__pycache__/binarizer.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/__pycache__/checkpoint_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/__pycache__/checkpoint_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/__pycache__/distributed_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/__pycache__/distributed_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/__pycache__/file_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/__pycache__/file_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/__pycache__/hub_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/__pycache__/hub_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/__pycache__/iterative_refinement_generator.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/__pycache__/iterative_refinement_generator.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/__pycache__/legacy_distributed_data_parallel.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/__pycache__/legacy_distributed_data_parallel.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/__pycache__/meters.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/__pycache__/meters.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/__pycache__/options.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/__pycache__/options.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/__pycache__/pdb.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/__pycache__/pdb.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/__pycache__/progress_bar.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/__pycache__/progress_bar.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/__pycache__/registry.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/__pycache__/registry.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/__pycache__/search.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/__pycache__/search.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/__pycache__/sequence_generator.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/__pycache__/sequence_generator.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/__pycache__/tokenizer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/__pycache__/tokenizer.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/__pycache__/trainer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/__pycache__/trainer.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/clib/libbleu/module.cpp:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright 2017-present, Facebook, Inc.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
 9 | #include <Python.h>
10 |
11 |
12 | static PyMethodDef method_def[] = {
13 | {NULL, NULL, 0, NULL}
14 | };
15 |
16 | static struct PyModuleDef module_def = {
17 | PyModuleDef_HEAD_INIT,
18 | "libbleu", /* name of module */
19 | NULL, /* module documentation, may be NULL */
20 | -1, /* size of per-interpreter state of the module,
21 | or -1 if the module keeps state in global variables. */
22 | method_def
23 | };
24 |
25 |
26 | #if PY_MAJOR_VERSION == 2
27 | PyMODINIT_FUNC init_libbleu()
28 | #else
29 | PyMODINIT_FUNC PyInit_libbleu()
30 | #endif
31 | {
32 | PyObject *m = PyModule_Create(&module_def);
33 | if (!m) {
34 | return NULL;
35 | }
36 | return m;
37 | }
38 |
--------------------------------------------------------------------------------
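
For context, this module is compiled into the `fairseq.libbleu` extension by the top-level setup.py. A simplified sketch of the relevant setuptools configuration (the authoritative build lives in the repository's setup.py, which is not reproduced in this dump):

``` python
from setuptools import Extension, setup

setup(
    name='libbleu-demo',  # hypothetical name for this standalone sketch
    ext_modules=[
        Extension(
            'fairseq.libbleu',
            sources=[
                'fairseq/clib/libbleu/libbleu.cpp',
                'fairseq/clib/libbleu/module.cpp',
            ],
            extra_compile_args=['-std=c++11'],
        ),
    ],
)
```
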
/fairseq/criterions/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import importlib
7 | import os
8 |
9 | from fairseq import registry
10 | from fairseq.criterions.fairseq_criterion import FairseqCriterion
11 |
12 |
13 | build_criterion, register_criterion, CRITERION_REGISTRY = registry.setup_registry(
14 | '--criterion',
15 | base_class=FairseqCriterion,
16 | default='cross_entropy',
17 | )
18 |
19 |
20 | # automatically import any Python files in the criterions/ directory
21 | for file in os.listdir(os.path.dirname(__file__)):
22 | if file.endswith('.py') and not file.startswith('_'):
23 | module = file[:file.find('.py')]
24 | importlib.import_module('fairseq.criterions.' + module)
25 |
--------------------------------------------------------------------------------
/fairseq/criterions/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/criterions/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/criterions/__pycache__/adaptive_loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/criterions/__pycache__/adaptive_loss.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/criterions/__pycache__/binary_cross_entropy.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/criterions/__pycache__/binary_cross_entropy.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/criterions/__pycache__/composite_loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/criterions/__pycache__/composite_loss.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/criterions/__pycache__/cross_entropy.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/criterions/__pycache__/cross_entropy.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/criterions/__pycache__/fairseq_criterion.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/criterions/__pycache__/fairseq_criterion.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/criterions/__pycache__/label_smoothed_cross_entropy.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/criterions/__pycache__/label_smoothed_cross_entropy.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_alignment.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_alignment.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/criterions/__pycache__/legacy_masked_lm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/criterions/__pycache__/legacy_masked_lm.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/criterions/__pycache__/masked_lm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/criterions/__pycache__/masked_lm.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/criterions/__pycache__/masked_lm_distil.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/criterions/__pycache__/masked_lm_distil.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/criterions/__pycache__/masked_lm_distil_H_half.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/criterions/__pycache__/masked_lm_distil_H_half.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/criterions/__pycache__/nat_loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/criterions/__pycache__/nat_loss.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/criterions/__pycache__/sentence_prediction.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/criterions/__pycache__/sentence_prediction.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/criterions/__pycache__/sentence_ranking.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/criterions/__pycache__/sentence_ranking.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/criterions/fairseq_criterion.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from torch.nn.modules.loss import _Loss
7 |
8 |
9 | class FairseqCriterion(_Loss):
10 |
11 | def __init__(self, args, task):
12 | super().__init__()
13 | self.args = args
14 | self.task = task
15 | self.padding_idx = task.target_dictionary.pad() if task.target_dictionary is not None else -100
16 |
17 | @staticmethod
18 | def add_args(parser):
19 | """Add criterion-specific arguments to the parser."""
20 | pass
21 |
22 | @classmethod
23 | def build_criterion(cls, args, task):
24 | return cls(args, task)
25 |
26 | def forward(self, model, sample, reduce=True):
27 | """Compute the loss for the given sample.
28 |
29 | Returns a tuple with three elements:
30 | 1) the loss
31 | 2) the sample size, which is used as the denominator for the gradient
32 | 3) logging outputs to display while training
33 | """
34 | raise NotImplementedError
35 |
36 | @staticmethod
37 | def aggregate_logging_outputs(logging_outputs):
38 | """Aggregate logging outputs from data parallel training."""
39 | raise NotImplementedError
40 |
41 | @staticmethod
42 | def grad_denom(sample_sizes):
43 | """Compute the gradient denominator for a set of sample sizes."""
44 | return sum(sample_sizes)
45 |
--------------------------------------------------------------------------------
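
A minimal sketch of a concrete criterion built on this base class, following the same pattern as `fairseq.criterions.cross_entropy`; the registered name is hypothetical:

``` python
import math

import torch.nn.functional as F

from fairseq.criterions import FairseqCriterion, register_criterion


@register_criterion('simple_cross_entropy')  # hypothetical criterion name
class SimpleCrossEntropyCriterion(FairseqCriterion):

    def forward(self, model, sample, reduce=True):
        """Return the (loss, sample_size, logging_output) tuple."""
        net_output = model(**sample['net_input'])
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        lprobs = lprobs.view(-1, lprobs.size(-1))
        target = model.get_targets(sample, net_output).view(-1)
        loss = F.nll_loss(
            lprobs, target,
            ignore_index=self.padding_idx,
            reduction='sum' if reduce else 'none',
        )
        sample_size = sample['ntokens']
        logging_output = {
            'loss': loss.item() if reduce else loss,
            'ntokens': sample['ntokens'],
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
        return {
            'loss': loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.,
            'sample_size': sample_size,
        }
```
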
/fairseq/data/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/append_token_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/append_token_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/backtranslation_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/backtranslation_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/base_wrapper_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/base_wrapper_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/colorize_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/colorize_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/concat_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/concat_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/concat_sentences_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/concat_sentences_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/data_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/data_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/denoising_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/denoising_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/dictionary.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/dictionary.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/fairseq_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/fairseq_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/id_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/id_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/indexed_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/indexed_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/iterators.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/iterators.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/language_pair_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/language_pair_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/list_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/list_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/lm_context_window_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/lm_context_window_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/lru_cache_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/lru_cache_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/mask_tokens_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/mask_tokens_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/monolingual_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/monolingual_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/multi_corpus_sampled_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/multi_corpus_sampled_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/nested_dictionary_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/nested_dictionary_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/noising.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/noising.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/num_samples_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/num_samples_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/numel_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/numel_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/offset_tokens_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/offset_tokens_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/pad_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/pad_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/plasma_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/plasma_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/prepend_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/prepend_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/prepend_token_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/prepend_token_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/raw_label_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/raw_label_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/replace_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/replace_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/resampling_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/resampling_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/roll_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/roll_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/round_robin_zip_datasets.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/round_robin_zip_datasets.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/sharded_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/sharded_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/sort_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/sort_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/strip_token_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/strip_token_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/subsample_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/subsample_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/token_block_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/token_block_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/transform_eos_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/transform_eos_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/transform_eos_lang_pair_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/transform_eos_lang_pair_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/__pycache__/truncate_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/__pycache__/truncate_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/append_token_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy as np
7 | import torch
8 |
9 | from . import BaseWrapperDataset
10 |
11 |
12 | class AppendTokenDataset(BaseWrapperDataset):
13 |
14 | def __init__(self, dataset, token=None):
15 | super().__init__(dataset)
16 | self.token = token
17 | if token is not None:
18 | self._sizes = np.array(dataset.sizes) + 1
19 | else:
20 | self._sizes = dataset.sizes
21 |
22 | def __getitem__(self, idx):
23 | item = self.dataset[idx]
24 | if self.token is not None:
25 | item = torch.cat([item, item.new([self.token])])
26 | return item
27 |
28 | @property
29 | def sizes(self):
30 | return self._sizes
31 |
32 | def num_tokens(self, index):
33 | n = self.dataset.num_tokens(index)
34 | if self.token is not None:
35 | n += 1
36 | return n
37 |
38 | def size(self, index):
39 | n = self.dataset.size(index)
40 | if self.token is not None:
41 | n += 1
42 | return n
43 |
--------------------------------------------------------------------------------
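A small usage sketch for the wrapper above (the toy data and EOS index are
assumptions; any dataset exposing sizes/num_tokens/size works the same way,
and both classes are exported from fairseq.data in upstream fairseq):

import numpy as np
import torch

from fairseq.data import AppendTokenDataset, ListDataset

base = ListDataset(
    [torch.tensor([5, 6, 7]), torch.tensor([8, 9])],
    sizes=np.array([3, 2]),
)
wrapped = AppendTokenDataset(base, token=2)  # 2 = assumed EOS index

print(wrapped[0])             # tensor([5, 6, 7, 2])
print(wrapped.sizes)          # [4 3] -- every size grows by one
print(wrapped.num_tokens(1))  # 3

--------------------------------------------------------------------------------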
/fairseq/data/audio/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/audio/__init__.py
--------------------------------------------------------------------------------
/fairseq/data/audio/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/audio/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/audio/__pycache__/raw_audio_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/audio/__pycache__/raw_audio_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/base_wrapper_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from torch.utils.data.dataloader import default_collate
7 |
8 | from . import FairseqDataset
9 |
10 |
11 | class BaseWrapperDataset(FairseqDataset):
12 |
13 | def __init__(self, dataset):
14 | super().__init__()
15 | self.dataset = dataset
16 |
17 | def __getitem__(self, index):
18 | return self.dataset[index]
19 |
20 | def __len__(self):
21 | return len(self.dataset)
22 |
23 | def collater(self, samples):
24 | if hasattr(self.dataset, 'collater'):
25 | return self.dataset.collater(samples)
26 | else:
27 | return default_collate(samples)
28 |
29 | @property
30 | def sizes(self):
31 | return self.dataset.sizes
32 |
33 | def num_tokens(self, index):
34 | return self.dataset.num_tokens(index)
35 |
36 | def size(self, index):
37 | return self.dataset.size(index)
38 |
39 | def ordered_indices(self):
40 | return self.dataset.ordered_indices()
41 |
42 | @property
43 | def supports_prefetch(self):
44 | return getattr(self.dataset, 'supports_prefetch', False)
45 |
46 | def prefetch(self, indices):
47 | self.dataset.prefetch(indices)
48 |
49 | def set_epoch(self, epoch):
50 | super().set_epoch(epoch)
51 | if hasattr(self.dataset, 'set_epoch'):
52 | self.dataset.set_epoch(epoch)
53 |
--------------------------------------------------------------------------------
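Because every hook simply forwards to the inner dataset, a subclass only has to
override what it changes. A hypothetical sketch (tensor items assumed; a full
version would also clamp the inherited sizes property):

from fairseq.data import BaseWrapperDataset


class ClampLengthDataset(BaseWrapperDataset):
    """Hypothetical wrapper: truncate every item to at most max_len tokens."""

    def __init__(self, dataset, max_len):
        super().__init__(dataset)
        self.max_len = max_len

    def __getitem__(self, index):
        # collater, ordered_indices, prefetch, set_epoch, ... are inherited
        return self.dataset[index][:self.max_len]

    def num_tokens(self, index):
        return min(self.dataset.num_tokens(index), self.max_len)

    def size(self, index):
        return min(self.dataset.size(index), self.max_len)

--------------------------------------------------------------------------------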
/fairseq/data/colorize_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 |
8 | from . import BaseWrapperDataset
9 |
10 |
11 | class ColorizeDataset(BaseWrapperDataset):
12 | """ Adds 'colors' property to net input that is obtained from the provided color getter for use by models """
13 | def __init__(self, dataset, color_getter):
14 | super().__init__(dataset)
15 | self.color_getter = color_getter
16 |
17 | def collater(self, samples):
18 | base_collate = super().collater(samples)
19 | if len(base_collate) > 0:
20 | base_collate["net_input"]["colors"] = torch.tensor(
21 | list(self.color_getter(self.dataset, s["id"]) for s in samples),
22 | dtype=torch.long,
23 | )
24 | return base_collate
25 |
--------------------------------------------------------------------------------
/fairseq/data/concat_sentences_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 |
8 | from . import FairseqDataset
9 |
10 |
11 | class ConcatSentencesDataset(FairseqDataset):
12 |
13 | def __init__(self, *datasets):
14 | super().__init__()
15 | self.datasets = datasets
16 | assert all(len(ds) == len(datasets[0]) for ds in datasets), \
17 | 'datasets must have the same length'
18 |
19 | def __getitem__(self, index):
20 | return torch.cat([ds[index] for ds in self.datasets])
21 |
22 | def __len__(self):
23 | return len(self.datasets[0])
24 |
25 | def collater(self, samples):
26 | return self.datasets[0].collater(samples)
27 |
28 | @property
29 | def sizes(self):
30 | return sum(ds.sizes for ds in self.datasets)
31 |
32 | def num_tokens(self, index):
33 | return sum(ds.num_tokens(index) for ds in self.datasets)
34 |
35 | def size(self, index):
36 | return sum(ds.size(index) for ds in self.datasets)
37 |
38 | def ordered_indices(self):
39 | return self.datasets[0].ordered_indices()
40 |
41 | @property
42 | def supports_prefetch(self):
43 | return any(
44 | getattr(ds, 'supports_prefetch', False) for ds in self.datasets
45 | )
46 |
47 | def prefetch(self, indices):
48 | for ds in self.datasets:
49 | if getattr(ds, 'supports_prefetch', False):
50 | ds.prefetch(indices)
51 |
52 | def set_epoch(self, epoch):
53 | super().set_epoch(epoch)
54 | for ds in self.datasets:
55 | if hasattr(ds, 'set_epoch'):
56 | ds.set_epoch(epoch)
57 |
--------------------------------------------------------------------------------
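A usage sketch with two parallel toy datasets (hypothetical data; ListDataset
is used here only to provide the sizes bookkeeping):

import numpy as np
import torch

from fairseq.data import ConcatSentencesDataset, ListDataset

first = ListDataset([torch.tensor([5, 6]), torch.tensor([7])], sizes=np.array([2, 1]))
second = ListDataset([torch.tensor([8]), torch.tensor([9, 10])], sizes=np.array([1, 2]))

pair = ConcatSentencesDataset(first, second)
print(pair[0])       # tensor([5, 6, 8]) -- per-index concatenation
print(pair.size(1))  # 3 -- sizes are summed across the parts

--------------------------------------------------------------------------------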
/fairseq/data/data_utils_fast.cpython-37m-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/data_utils_fast.cpython-37m-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/fairseq/data/data_utils_fast.pyx:
--------------------------------------------------------------------------------
1 | # cython: language_level=3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 |
9 | cimport cython
10 | cimport numpy as np
11 |
12 | DTYPE = np.int64
13 | ctypedef np.int64_t DTYPE_t
14 |
15 |
16 | cdef _is_batch_full(list batch, long num_tokens, long max_tokens, long max_sentences):
17 | if len(batch) == 0:
18 | return 0
19 | if max_sentences > 0 and len(batch) == max_sentences:
20 | return 1
21 | if max_tokens > 0 and num_tokens > max_tokens:
22 | return 1
23 | return 0
24 |
25 |
26 | @cython.cdivision(True)
27 | cpdef list batch_by_size_fast(
28 | np.ndarray[DTYPE_t, ndim=1] indices,
29 | num_tokens_fn,
30 | long max_tokens,
31 | long max_sentences,
32 | int bsz_mult,
33 | ):
34 | cdef long sample_len = 0
35 | cdef list sample_lens = []
36 | cdef list batch = []
37 | cdef list batches = []
38 | cdef long mod_len
39 | cdef long i
40 | cdef long idx
41 | cdef long num_tokens
42 | cdef DTYPE_t[:] indices_view = indices
43 |
44 | for i in range(len(indices_view)):
45 | idx = indices_view[i]
46 | num_tokens = num_tokens_fn(idx)
47 | sample_lens.append(num_tokens)
48 | sample_len = max(sample_len, num_tokens)
49 |
50 | assert max_tokens <= 0 or sample_len <= max_tokens, (
51 | "sentence at index {} of size {} exceeds max_tokens "
52 | "limit of {}!".format(idx, sample_len, max_tokens)
53 | )
54 | num_tokens = (len(batch) + 1) * sample_len
55 |
56 | if _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
57 | mod_len = max(
58 | bsz_mult * (len(batch) // bsz_mult),
59 | len(batch) % bsz_mult,
60 | )
61 | batches.append(batch[:mod_len])
62 | batch = batch[mod_len:]
63 | sample_lens = sample_lens[mod_len:]
64 | sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
65 | batch.append(idx)
66 | if len(batch) > 0:
67 | batches.append(batch)
68 | return batches
69 |
--------------------------------------------------------------------------------
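The Cython above is a hot loop; its logic is easier to follow in plain Python.
The sketch below mirrors the same algorithm (a token budget, a sentence cap,
and batch lengths snapped to a multiple of bsz_mult); it is equivalent in
behavior but much slower, and is included only as a readable reference:

def batch_by_size_py(indices, num_tokens_fn, max_tokens=0, max_sentences=0, bsz_mult=1):
    """Pure-Python mirror of batch_by_size_fast, for illustration only."""
    batches, batch, sample_lens, sample_len = [], [], [], 0
    for idx in indices:
        n = num_tokens_fn(idx)
        sample_lens.append(n)
        sample_len = max(sample_len, n)  # padded length of the would-be batch
        assert max_tokens <= 0 or sample_len <= max_tokens, (
            'sentence at index {} of size {} exceeds max_tokens '
            'limit of {}!'.format(idx, sample_len, max_tokens)
        )
        num_tokens = (len(batch) + 1) * sample_len  # cost if idx joined the batch
        full = (max_sentences > 0 and len(batch) == max_sentences) or \
               (max_tokens > 0 and num_tokens > max_tokens)
        if batch and full:
            # emit a prefix whose length is a multiple of bsz_mult when possible
            mod_len = max(bsz_mult * (len(batch) // bsz_mult), len(batch) % bsz_mult)
            batches.append(batch[:mod_len])
            batch = batch[mod_len:]
            sample_lens = sample_lens[mod_len:]
            sample_len = max(sample_lens) if sample_lens else 0
        batch.append(idx)
    if batch:
        batches.append(batch)
    return batches

--------------------------------------------------------------------------------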
/fairseq/data/encoders/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 |
7 | import importlib
8 | import os
9 |
10 | from fairseq import registry
11 |
12 |
13 | build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY = registry.setup_registry(
14 | '--tokenizer',
15 | default=None,
16 | )
17 |
18 |
19 | build_bpe, register_bpe, BPE_REGISTRY = registry.setup_registry(
20 | '--bpe',
21 | default=None,
22 | )
23 |
24 |
25 | # automatically import any Python files in the encoders/ directory
26 | for file in os.listdir(os.path.dirname(__file__)):
27 | if file.endswith('.py') and not file.startswith('_'):
28 | module = file[:file.find('.py')]
29 | importlib.import_module('fairseq.data.encoders.' + module)
30 |
--------------------------------------------------------------------------------
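The two setup_registry calls above mean a new tokenizer or BPE scheme becomes
selectable from the command line just by defining a decorated class in this
package. A hypothetical sketch (the 'identity' name is invented for
illustration; it would be chosen with --tokenizer=identity):

from fairseq.data.encoders import register_tokenizer


@register_tokenizer('identity')
class IdentityTokenizer(object):

    def __init__(self, args):
        pass  # a real tokenizer would read its options from args

    def encode(self, x: str) -> str:
        return x

    def decode(self, x: str) -> str:
        return x

--------------------------------------------------------------------------------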
/fairseq/data/encoders/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/encoders/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/encoders/__pycache__/fastbpe.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/encoders/__pycache__/fastbpe.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/encoders/__pycache__/gpt2_bpe.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/encoders/__pycache__/gpt2_bpe.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/encoders/__pycache__/gpt2_bpe_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/encoders/__pycache__/gpt2_bpe_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/encoders/__pycache__/hf_bert_bpe.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/encoders/__pycache__/hf_bert_bpe.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/encoders/__pycache__/moses_tokenizer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/encoders/__pycache__/moses_tokenizer.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/encoders/__pycache__/nltk_tokenizer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/encoders/__pycache__/nltk_tokenizer.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/encoders/__pycache__/sentencepiece_bpe.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/encoders/__pycache__/sentencepiece_bpe.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/encoders/__pycache__/space_tokenizer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/encoders/__pycache__/space_tokenizer.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/encoders/__pycache__/subword_nmt_bpe.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/encoders/__pycache__/subword_nmt_bpe.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/encoders/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/encoders/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/encoders/fastbpe.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from fairseq import file_utils
7 | from fairseq.data.encoders import register_bpe
8 |
9 |
10 | @register_bpe('fastbpe')
11 | class fastBPE(object):
12 |
13 | @staticmethod
14 | def add_args(parser):
15 | # fmt: off
16 | parser.add_argument('--bpe-codes', type=str,
17 | help='path to fastBPE BPE')
18 | # fmt: on
19 |
20 | def __init__(self, args):
21 | if args.bpe_codes is None:
22 |             raise ValueError('--bpe-codes is required for --bpe=fastbpe')
23 | codes = file_utils.cached_path(args.bpe_codes)
24 | try:
25 | import fastBPE
26 | self.bpe = fastBPE.fastBPE(codes)
27 | self.bpe_symbol = "@@ "
28 | except ImportError:
29 | raise ImportError('Please install fastBPE with: pip install fastBPE')
30 |
31 | def encode(self, x: str) -> str:
32 | return self.bpe.apply([x])[0]
33 |
34 | def decode(self, x: str) -> str:
35 | return (x + ' ').replace(self.bpe_symbol, '').rstrip()
36 |
--------------------------------------------------------------------------------
/fairseq/data/encoders/gpt2_bpe.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from fairseq import file_utils
7 | from fairseq.data.encoders import register_bpe
8 |
9 | from .gpt2_bpe_utils import get_encoder
10 |
11 |
12 | DEFAULT_ENCODER_JSON = 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
13 | DEFAULT_VOCAB_BPE = 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
14 |
15 |
16 | @register_bpe('gpt2')
17 | class GPT2BPE(object):
18 |
19 | @staticmethod
20 | def add_args(parser):
21 | # fmt: off
22 | parser.add_argument('--gpt2-encoder-json', type=str,
23 | default=DEFAULT_ENCODER_JSON,
24 | help='path to encoder.json')
25 | parser.add_argument('--gpt2-vocab-bpe', type=str,
26 | default=DEFAULT_VOCAB_BPE,
27 | help='path to vocab.bpe')
28 | # fmt: on
29 |
30 | def __init__(self, args):
31 | encoder_json = file_utils.cached_path(
32 | getattr(args, 'gpt2_encoder_json', DEFAULT_ENCODER_JSON)
33 | )
34 | vocab_bpe = file_utils.cached_path(
35 | getattr(args, 'gpt2_vocab_bpe', DEFAULT_VOCAB_BPE)
36 | )
37 | self.bpe = get_encoder(encoder_json, vocab_bpe)
38 |
39 | def encode(self, x: str) -> str:
40 | return ' '.join(map(str, self.bpe.encode(x)))
41 |
42 | def decode(self, x: str) -> str:
43 | return self.bpe.decode(map(int, x.split()))
44 |
45 | def is_beginning_of_word(self, x: str) -> bool:
46 | return self.decode(x).startswith(' ')
47 |
--------------------------------------------------------------------------------
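A usage sketch: because __init__ falls back to the default URLs via getattr, a
bare Namespace is enough (encoder.json and vocab.bpe are downloaded and cached
on first use). The printed ids are indicative only:

from argparse import Namespace

from fairseq.data.encoders.gpt2_bpe import GPT2BPE

bpe = GPT2BPE(Namespace())
ids = bpe.encode('Hello world!')        # space-separated GPT-2 ids as a string
print(ids)                              # e.g. '15496 995 0'
print(bpe.decode(ids))                  # 'Hello world!'
print(bpe.is_beginning_of_word('995'))  # True -- ' world' decodes with a leading space

--------------------------------------------------------------------------------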
/fairseq/data/encoders/hf_bert_bpe.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from fairseq.data.encoders import register_bpe
7 |
8 |
9 | @register_bpe('bert')
10 | class BertBPE(object):
11 |
12 | @staticmethod
13 | def add_args(parser):
14 | # fmt: off
15 | parser.add_argument('--bpe-cased', action='store_true',
16 | help='set for cased BPE',
17 | default=False)
18 | parser.add_argument('--bpe-vocab-file', type=str,
19 | help='bpe vocab file.')
20 | # fmt: on
21 |
22 | def __init__(self, args):
23 | try:
24 | from pytorch_transformers import BertTokenizer
25 | from pytorch_transformers.tokenization_utils import clean_up_tokenization
26 | except ImportError:
27 | raise ImportError(
28 |                 'Please install version 1.0.0 of pytorch_transformers '
29 |                 'with: pip install pytorch-transformers'
30 | )
31 |
32 | if 'bpe_vocab_file' in args:
33 | self.bert_tokenizer = BertTokenizer(
34 | args.bpe_vocab_file,
35 | do_lower_case=not args.bpe_cased
36 | )
37 | else:
38 | vocab_file_name = 'bert-base-cased' if args.bpe_cased else 'bert-base-uncased'
39 | self.bert_tokenizer = BertTokenizer.from_pretrained(vocab_file_name)
40 | self.clean_up_tokenization = clean_up_tokenization
41 |
42 | def encode(self, x: str) -> str:
43 | return ' '.join(self.bert_tokenizer.tokenize(x))
44 |
45 | def decode(self, x: str) -> str:
46 | return self.clean_up_tokenization(
47 | self.bert_tokenizer.convert_tokens_to_string(x.split(' '))
48 | )
49 |
50 | def is_beginning_of_word(self, x: str) -> bool:
51 | return not x.startswith('##')
52 |
--------------------------------------------------------------------------------
/fairseq/data/encoders/moses_tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from fairseq.data.encoders import register_tokenizer
7 |
8 |
9 | @register_tokenizer('moses')
10 | class MosesTokenizer(object):
11 |
12 | @staticmethod
13 | def add_args(parser):
14 | # fmt: off
15 | parser.add_argument('--moses-source-lang', metavar='SRC',
16 | help='source language')
17 | parser.add_argument('--moses-target-lang', metavar='TARGET',
18 | help='target language')
19 | parser.add_argument('--moses-no-dash-splits', action='store_true', default=False,
20 | help='don\'t apply dash split rules')
21 | parser.add_argument('--moses-no-escape', action='store_true', default=False,
22 |                             help='don\'t perform HTML escaping on apostrophe, quotes, etc.')
23 | # fmt: on
24 |
25 | def __init__(self, args):
26 | self.args = args
27 |
28 | if getattr(args, 'moses_source_lang', None) is None:
29 | args.moses_source_lang = getattr(args, 'source_lang', 'en')
30 | if getattr(args, 'moses_target_lang', None) is None:
31 | args.moses_target_lang = getattr(args, 'target_lang', 'en')
32 |
33 | try:
34 | from sacremoses import MosesTokenizer, MosesDetokenizer
35 | self.tok = MosesTokenizer(args.moses_source_lang)
36 | self.detok = MosesDetokenizer(args.moses_target_lang)
37 | except ImportError:
38 | raise ImportError('Please install Moses tokenizer with: pip install sacremoses')
39 |
40 | def encode(self, x: str) -> str:
41 | return self.tok.tokenize(
42 | x,
43 | aggressive_dash_splits=(not self.args.moses_no_dash_splits),
44 | return_str=True,
45 | escape=(not self.args.moses_no_escape),
46 | )
47 |
48 | def decode(self, x: str) -> str:
49 | return self.detok.detokenize(x.split())
50 |
--------------------------------------------------------------------------------
/fairseq/data/encoders/nltk_tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from fairseq.data.encoders import register_tokenizer
7 |
8 |
9 | @register_tokenizer('nltk')
10 | class NLTKTokenizer(object):
11 |
12 | def __init__(self, source_lang=None, target_lang=None):
13 | try:
14 | from nltk.tokenize import word_tokenize
15 | self.word_tokenize = word_tokenize
16 | except ImportError:
17 | raise ImportError('Please install nltk with: pip install nltk')
18 |
19 | def encode(self, x: str) -> str:
20 | return ' '.join(self.word_tokenize(x))
21 |
22 | def decode(self, x: str) -> str:
23 | return x
24 |
--------------------------------------------------------------------------------
/fairseq/data/encoders/sentencepiece_bpe.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from fairseq import file_utils
7 | from fairseq.data.encoders import register_bpe
8 |
9 |
10 | @register_bpe('sentencepiece')
11 | class SentencepieceBPE(object):
12 |
13 | @staticmethod
14 | def add_args(parser):
15 | # fmt: off
16 | parser.add_argument('--sentencepiece-vocab', type=str,
17 | help='path to sentencepiece vocab')
18 | # fmt: on
19 |
20 | def __init__(self, args):
21 | vocab = file_utils.cached_path(args.sentencepiece_vocab)
22 | try:
23 | import sentencepiece as spm
24 | self.sp = spm.SentencePieceProcessor()
25 | self.sp.Load(vocab)
26 | except ImportError:
27 | raise ImportError('Please install sentencepiece with: pip install sentencepiece')
28 |
29 | def encode(self, x: str) -> str:
30 | return ' '.join(self.sp.EncodeAsPieces(x))
31 |
32 | def decode(self, x: str) -> str:
33 | return x.replace(' ', '').replace('\u2581', ' ').strip()
34 |
35 | def is_beginning_of_word(self, x: str) -> bool:
36 |         if x in ['<unk>', '<s>', '</s>', '<pad>']:
37 | # special elements are always considered beginnings
38 | # HACK: this logic is already present in fairseq/tasks/masked_lm.py
39 | # but these special tokens are also contained in the sentencepiece
40 | # vocabulary which causes duplicate special tokens. This hack makes
41 | # sure that they are all taken into account.
42 | return True
43 | return x.startswith('\u2581')
44 |
--------------------------------------------------------------------------------
/fairseq/data/encoders/space_tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import re
7 |
8 | from fairseq.data.encoders import register_tokenizer
9 |
10 |
11 | @register_tokenizer('space')
12 | class SpaceTokenizer(object):
13 |
14 | def __init__(self, source_lang=None, target_lang=None):
15 | self.space_tok = re.compile(r"\s+")
16 |
17 | def encode(self, x: str) -> str:
18 | return self.space_tok.sub(' ', x)
19 |
20 | def decode(self, x: str) -> str:
21 | return x
22 |
--------------------------------------------------------------------------------
/fairseq/data/encoders/subword_nmt_bpe.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from fairseq import file_utils
7 | from fairseq.data.encoders import register_bpe
8 |
9 |
10 | @register_bpe('subword_nmt')
11 | class SubwordNMTBPE(object):
12 |
13 | @staticmethod
14 | def add_args(parser):
15 | # fmt: off
16 | parser.add_argument('--bpe-codes', type=str,
17 | help='path to subword NMT BPE')
18 | parser.add_argument('--bpe-separator', default='@@',
19 | help='BPE separator')
20 | # fmt: on
21 |
22 | def __init__(self, args):
23 | if args.bpe_codes is None:
24 | raise ValueError('--bpe-codes is required for --bpe=subword_nmt')
25 | codes = file_utils.cached_path(args.bpe_codes)
26 | try:
27 | from subword_nmt import apply_bpe
28 | bpe_parser = apply_bpe.create_parser()
29 | bpe_args = bpe_parser.parse_args([
30 | '--codes', codes,
31 | '--separator', args.bpe_separator,
32 | ])
33 | self.bpe = apply_bpe.BPE(
34 | bpe_args.codes,
35 | bpe_args.merges,
36 | bpe_args.separator,
37 | None,
38 | bpe_args.glossaries,
39 | )
40 | self.bpe_symbol = bpe_args.separator + ' '
41 | except ImportError:
42 | raise ImportError('Please install subword_nmt with: pip install subword-nmt')
43 |
44 | def encode(self, x: str) -> str:
45 | return self.bpe.process_line(x)
46 |
47 | def decode(self, x: str) -> str:
48 | return (x + ' ').replace(self.bpe_symbol, '').rstrip()
49 |
--------------------------------------------------------------------------------
/fairseq/data/encoders/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | from fairseq.data import encoders
8 |
9 |
10 | def get_whole_word_mask(args, dictionary):
11 | bpe = encoders.build_bpe(args)
12 | if bpe is not None:
13 | def is_beginning_of_word(i):
14 | if i < dictionary.nspecial:
15 | # special elements are always considered beginnings
16 | return True
17 | tok = dictionary[i]
18 | if tok.startswith('madeupword'):
19 | return True
20 | try:
21 | return bpe.is_beginning_of_word(tok)
22 | except ValueError:
23 | return True
24 | mask_whole_words = torch.ByteTensor(list(
25 | map(is_beginning_of_word, range(len(dictionary)))
26 | ))
27 | return mask_whole_words
28 | return None
29 |
--------------------------------------------------------------------------------
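A sketch of how this helper is typically used, e.g. to drive whole-word masking
in a masked-LM task. The dictionary path and the choice of BPE are assumptions
for illustration:

from argparse import Namespace

from fairseq.data import Dictionary
from fairseq.data.encoders.utils import get_whole_word_mask

args = Namespace(bpe='gpt2')              # any registered BPE scheme works
dictionary = Dictionary.load('dict.txt')  # hypothetical path to a model's dict
mask = get_whole_word_mask(args, dictionary)
# mask is a ByteTensor of length len(dictionary); mask[i] == 1 means symbol i
# begins a word, so token masking can snap to whole-word boundaries

--------------------------------------------------------------------------------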
/fairseq/data/id_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 |
8 | from . import FairseqDataset
9 |
10 |
11 | class IdDataset(FairseqDataset):
12 |
13 | def __getitem__(self, index):
14 | return index
15 |
16 | def __len__(self):
17 | return 0
18 |
19 | def collater(self, samples):
20 | return torch.tensor(samples)
21 |
--------------------------------------------------------------------------------
/fairseq/data/legacy/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .masked_lm_dictionary import BertDictionary, MaskedLMDictionary
7 | from .block_pair_dataset import BlockPairDataset
8 | from .masked_lm_dataset import MaskedLMDataset
9 |
10 | __all__ = [
11 | 'BertDictionary',
12 | 'BlockPairDataset',
13 | 'MaskedLMDataset',
14 | 'MaskedLMDictionary',
15 | ]
16 |
--------------------------------------------------------------------------------
/fairseq/data/legacy/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/legacy/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/legacy/__pycache__/block_pair_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/legacy/__pycache__/block_pair_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/legacy/__pycache__/masked_lm_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/legacy/__pycache__/masked_lm_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/legacy/__pycache__/masked_lm_dictionary.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/legacy/__pycache__/masked_lm_dictionary.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/data/legacy/masked_lm_dictionary.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from fairseq.data import Dictionary
7 |
8 |
9 | class MaskedLMDictionary(Dictionary):
10 | """
11 | Dictionary for Masked Language Modelling tasks. This extends Dictionary by
12 | adding the mask symbol.
13 | """
14 | def __init__(
15 | self,
16 |         pad='<pad>',
17 |         eos='</s>',
18 |         unk='<unk>',
19 |         mask='<mask>',
20 | ):
21 | super().__init__(pad, eos, unk)
22 | self.mask_word = mask
23 | self.mask_index = self.add_symbol(mask)
24 | self.nspecial = len(self.symbols)
25 |
26 | def mask(self):
27 | """Helper to get index of mask symbol"""
28 | return self.mask_index
29 |
30 |
31 | class BertDictionary(MaskedLMDictionary):
32 | """
33 | Dictionary for BERT task. This extends MaskedLMDictionary by adding support
34 | for cls and sep symbols.
35 | """
36 | def __init__(
37 | self,
38 |         pad='<pad>',
39 |         eos='</s>',
40 |         unk='<unk>',
41 |         mask='<mask>',
42 |         cls='<cls>',
43 |         sep='<sep>'
44 | ):
45 | super().__init__(pad, eos, unk, mask)
46 | self.cls_word = cls
47 | self.sep_word = sep
48 | self.cls_index = self.add_symbol(cls)
49 | self.sep_index = self.add_symbol(sep)
50 | self.nspecial = len(self.symbols)
51 |
52 | def cls(self):
53 | """Helper to get index of cls symbol"""
54 | return self.cls_index
55 |
56 | def sep(self):
57 | """Helper to get index of sep symbol"""
58 | return self.sep_index
59 |
--------------------------------------------------------------------------------
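A short sketch of the resulting index layout. The exact numbers assume
fairseq's default Dictionary, which reserves <s>, <pad>, </s>, <unk> first:

from fairseq.data.legacy import BertDictionary

d = BertDictionary()
print(d.pad(), d.eos(), d.unk())   # 1 2 3 with the default special symbols
print(d.mask(), d.cls(), d.sep())  # 4 5 6 -- appended in __init__ order
print(d.nspecial)                  # 7: all of the above count as special

--------------------------------------------------------------------------------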
/fairseq/data/list_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from . import BaseWrapperDataset
7 |
8 |
9 | class ListDataset(BaseWrapperDataset):
10 |
11 | def __init__(self, dataset, sizes=None):
12 | super().__init__(dataset)
13 | self._sizes = sizes
14 |
15 | def __iter__(self):
16 | for x in self.dataset:
17 | yield x
18 |
19 | def collater(self, samples):
20 | return samples
21 |
22 | @property
23 | def sizes(self):
24 | return self._sizes
25 |
26 | def num_tokens(self, index):
27 | return self.sizes[index]
28 |
29 | def size(self, index):
30 | return self.sizes[index]
31 |
32 | def set_epoch(self, epoch):
33 | pass
34 |
--------------------------------------------------------------------------------
/fairseq/data/lru_cache_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from functools import lru_cache
7 |
8 | from . import BaseWrapperDataset
9 |
10 |
11 | class LRUCacheDataset(BaseWrapperDataset):
12 |
13 | def __init__(self, dataset, token=None):
14 | super().__init__(dataset)
15 |
16 | @lru_cache(maxsize=8)
17 | def __getitem__(self, index):
18 | return self.dataset[index]
19 |
20 | @lru_cache(maxsize=8)
21 | def collater(self, samples):
22 | return self.dataset.collater(samples)
23 |
--------------------------------------------------------------------------------
/fairseq/data/num_samples_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from . import FairseqDataset
7 |
8 |
9 | class NumSamplesDataset(FairseqDataset):
10 |
11 | def __getitem__(self, index):
12 | return 1
13 |
14 | def __len__(self):
15 | return 0
16 |
17 | def collater(self, samples):
18 | return sum(samples)
19 |
--------------------------------------------------------------------------------
/fairseq/data/numel_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy as np
7 | import torch
8 |
9 | from . import BaseWrapperDataset
10 |
11 |
12 | class NumelDataset(BaseWrapperDataset):
13 |
14 | def __init__(self, dataset, reduce=False):
15 | super().__init__(dataset)
16 | self.reduce = reduce
17 |
18 | def __getitem__(self, index):
19 | item = self.dataset[index]
20 | if torch.is_tensor(item):
21 | return torch.numel(item)
22 | else:
23 | return np.size(item)
24 |
25 | def __len__(self):
26 | return len(self.dataset)
27 |
28 | def collater(self, samples):
29 | if self.reduce:
30 | return sum(samples)
31 | else:
32 | return torch.tensor(samples)
33 |
--------------------------------------------------------------------------------
/fairseq/data/offset_tokens_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from . import BaseWrapperDataset
7 |
8 |
9 | class OffsetTokensDataset(BaseWrapperDataset):
10 |
11 | def __init__(self, dataset, offset):
12 | super().__init__(dataset)
13 | self.offset = offset
14 |
15 | def __getitem__(self, idx):
16 | return self.dataset[idx] + self.offset
17 |
--------------------------------------------------------------------------------
/fairseq/data/pad_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from fairseq.data import data_utils
7 |
8 | from . import BaseWrapperDataset
9 |
10 |
11 | class PadDataset(BaseWrapperDataset):
12 |
13 | def __init__(self, dataset, pad_idx, left_pad):
14 | super().__init__(dataset)
15 | self.pad_idx = pad_idx
16 | self.left_pad = left_pad
17 |
18 | def collater(self, samples):
19 | return data_utils.collate_tokens(samples, self.pad_idx, left_pad=self.left_pad)
20 |
21 |
22 | class LeftPadDataset(PadDataset):
23 |
24 | def __init__(self, dataset, pad_idx):
25 | super().__init__(dataset, pad_idx, left_pad=True)
26 |
27 |
28 | class RightPadDataset(PadDataset):
29 |
30 | def __init__(self, dataset, pad_idx):
31 | super().__init__(dataset, pad_idx, left_pad=False)
32 |
--------------------------------------------------------------------------------
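What the collater produces, sketched directly with data_utils.collate_tokens
(pad index 1 matches fairseq's usual default and is an assumption here):

import torch

from fairseq.data import data_utils

samples = [torch.tensor([5, 6, 7]), torch.tensor([8])]

data_utils.collate_tokens(samples, pad_idx=1, left_pad=False)
# tensor([[5, 6, 7],
#         [8, 1, 1]])

data_utils.collate_tokens(samples, pad_idx=1, left_pad=True)
# tensor([[5, 6, 7],
#         [1, 1, 8]])

--------------------------------------------------------------------------------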
/fairseq/data/prepend_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy as np
7 | import torch
8 |
9 | from . import BaseWrapperDataset
10 |
11 |
12 | class PrependDataset(BaseWrapperDataset):
13 | def __init__(self, dataset, prepend_getter, ensure_first_token_is=None):
14 | super().__init__(dataset)
15 | self.prepend_getter = prepend_getter
16 | self.ensure_first_token = ensure_first_token_is
17 |
18 | def __getitem__(self, idx):
19 | item = self.dataset[idx]
20 | is_tuple = isinstance(item, tuple)
21 | src = item[0] if is_tuple else item
22 |
23 | assert self.ensure_first_token is None or src[0] == self.ensure_first_token
24 | prepend_idx = self.prepend_getter(self.dataset, idx)
25 | assert isinstance(prepend_idx, int)
26 | src[0] = prepend_idx
27 | item = tuple((src,) + item[1:]) if is_tuple else src
28 | return item
29 |
--------------------------------------------------------------------------------
/fairseq/data/prepend_token_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy as np
7 | import torch
8 |
9 | from . import BaseWrapperDataset
10 |
11 |
12 | class PrependTokenDataset(BaseWrapperDataset):
13 |
14 | def __init__(self, dataset, token=None):
15 | super().__init__(dataset)
16 | self.token = token
17 | if token is not None:
18 | self._sizes = np.array(dataset.sizes) + 1
19 | else:
20 | self._sizes = dataset.sizes
21 |
22 | def __getitem__(self, idx):
23 | item = self.dataset[idx]
24 | if self.token is not None:
25 | item = torch.cat([item.new([self.token]), item])
26 | return item
27 |
28 | @property
29 | def sizes(self):
30 | return self._sizes
31 |
32 | def num_tokens(self, index):
33 | n = self.dataset.num_tokens(index)
34 | if self.token is not None:
35 | n += 1
36 | return n
37 |
38 | def size(self, index):
39 | n = self.dataset.size(index)
40 | if self.token is not None:
41 | n += 1
42 | return n
43 |
--------------------------------------------------------------------------------
/fairseq/data/raw_label_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 |
8 | from . import FairseqDataset
9 |
10 |
11 | class RawLabelDataset(FairseqDataset):
12 |
13 | def __init__(self, labels):
14 | super().__init__()
15 | self.labels = labels
16 |
17 | def __getitem__(self, index):
18 | return self.labels[index]
19 |
20 | def __len__(self):
21 | return len(self.labels)
22 |
23 | def collater(self, samples):
24 | return torch.tensor(samples)
25 |
--------------------------------------------------------------------------------
/fairseq/data/replace_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from . import BaseWrapperDataset
7 |
8 |
9 | class ReplaceDataset(BaseWrapperDataset):
10 | """Replaces tokens found in the dataset by a specified replacement token
11 |
12 | Args:
13 | dataset (~torch.utils.data.Dataset): dataset to replace tokens in
14 | replace_map(Dictionary[int,int]): map of token to replace -> replacement token
15 | offsets (List[int]): do not replace tokens before (from left if pos, right if neg) this offset. should be
16 | as many as the number of objects returned by the underlying dataset __getitem__ method.
17 | """
18 |
19 | def __init__(self, dataset, replace_map, offsets):
20 | super().__init__(dataset)
21 | assert len(replace_map) > 0
22 | self.replace_map = replace_map
23 | self.offsets = offsets
24 |
25 | def __getitem__(self, index):
26 | item = self.dataset[index]
27 | is_tuple = isinstance(item, tuple)
28 | srcs = item if is_tuple else [item]
29 |
30 | for offset, src in zip(self.offsets, srcs):
31 | for k, v in self.replace_map.items():
32 | src_off = src[offset:] if offset >= 0 else src[:offset]
33 | src_off.masked_fill_(src_off == k, v)
34 |
35 | item = srcs if is_tuple else srcs[0]
36 | return item
37 |
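Note that masked_fill_ operates in place on the returned tensors. A small sketch (the toy tensor and map are assumptions):

import torch
from fairseq.data import ListDataset
from fairseq.data.replace_dataset import ReplaceDataset

items = [torch.LongTensor([0, 5, 6, 5])]
base = ListDataset(items, sizes=[4])

# remap token 5 -> 8 everywhere except position 0 (offset=1)
ds = ReplaceDataset(base, replace_map={5: 8}, offsets=[1])
print(ds[0])  # tensor([0, 8, 6, 8])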
--------------------------------------------------------------------------------
/fairseq/data/roll_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 |
8 | from . import BaseWrapperDataset
9 |
10 |
11 | class RollDataset(BaseWrapperDataset):
12 |
13 | def __init__(self, dataset, shifts):
14 | super().__init__(dataset)
15 | self.shifts = shifts
16 |
17 | def __getitem__(self, index):
18 | item = self.dataset[index]
19 | return torch.roll(item, self.shifts)
20 |
--------------------------------------------------------------------------------
/fairseq/data/sharded_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import itertools
7 | import os
8 | import random
9 |
10 | from . import BaseWrapperDataset
11 | from fairseq.data import data_utils
12 |
13 |
14 | class ShardedDataset(BaseWrapperDataset):
15 | """A :class:`~fairseq.data.FairseqDataset` wrapper that appends/prepends/strips EOS.
16 |
17 | Loads a dataset which has been sharded into multiple files. each shard is only loaded for each specific epoch
18 |
19 | """
20 |
21 | def __init__(
22 | self,
23 | dictionary,
24 | dataset_impl: str,
25 | path: str,
26 | split: str,
27 | epoch: int,
28 | name: str = None,
29 | combine: bool = False,
30 | seed: int = 0,
31 | ):
32 | self._name = name if name is not None else os.path.basename(path)
33 | num_shards = 0
34 | for i in itertools.count():
35 | if not os.path.exists(os.path.join(path, "shard" + str(i))):
36 | break
37 | num_shards += 1
38 |
39 | if num_shards > 0 and split == "train":
40 | random.seed(seed ^ epoch)
41 | shard = random.randint(0, num_shards - 1)
42 | split_path = os.path.join(path, "shard" + str(shard), split)
43 | else:
44 | split_path = os.path.join(path, split)
45 | if os.path.isdir(split_path):
46 | split_path = os.path.join(split_path, split)
47 |
48 | dataset = data_utils.load_indexed_dataset(
49 | split_path, dictionary, dataset_impl, combine=combine
50 | )
51 | if dataset is None:
52 | raise FileNotFoundError(
53 | "Dataset not found: {} ({})".format(split, split_path)
54 | )
55 |
56 | super().__init__(dataset)
57 |
58 | @property
59 | def name(self):
60 | return self._name
61 |
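A sketch of the on-disk layout this class expects and how it might be instantiated (the paths and dictionary file below are assumptions):

# data-bin/
#   shard0/train.bin, shard0/train.idx
#   shard1/train.bin, shard1/train.idx
#   valid.bin, valid.idx
from fairseq.data import Dictionary
from fairseq.data.sharded_dataset import ShardedDataset

dictionary = Dictionary.load("data-bin/dict.txt")
train = ShardedDataset(dictionary, dataset_impl="mmap", path="data-bin",
                       split="train", epoch=1, seed=3)  # picks one shard for epoch 1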
--------------------------------------------------------------------------------
/fairseq/data/sort_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy as np
7 |
8 | from . import BaseWrapperDataset
9 |
10 |
11 | class SortDataset(BaseWrapperDataset):
12 |
13 | def __init__(self, dataset, sort_order):
14 | super().__init__(dataset)
15 | if not isinstance(sort_order, (list, tuple)):
16 | sort_order = [sort_order]
17 | self.sort_order = sort_order
18 |
19 | assert all(len(so) == len(dataset) for so in sort_order)
20 |
21 | def ordered_indices(self):
22 | return np.lexsort(self.sort_order)
23 |
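np.lexsort treats the last array in sort_order as the primary key, so the usual fairseq idiom puts a random permutation first and the lengths last: batches come out length-ordered with random tie-breaking. A self-contained sketch (toy data assumed):

import numpy as np
import torch
from fairseq.data import ListDataset, SortDataset

items = [torch.LongTensor([1] * n) for n in (3, 1, 2)]
base = ListDataset(items, sizes=np.array([3, 1, 2]))

ds = SortDataset(base, sort_order=[np.random.permutation(3), base.sizes])
print(ds.ordered_indices())  # array([1, 2, 0]) -- shortest item first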
--------------------------------------------------------------------------------
/fairseq/data/strip_token_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from . import BaseWrapperDataset
7 |
8 |
9 | class StripTokenDataset(BaseWrapperDataset):
10 |
11 | def __init__(self, dataset, id_to_strip):
12 | super().__init__(dataset)
13 | self.id_to_strip = id_to_strip
14 |
15 | def __getitem__(self, index):
16 | item = self.dataset[index]
17 | return item[item.ne(self.id_to_strip)]
18 |
--------------------------------------------------------------------------------
/fairseq/data/subsample_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy as np
7 |
8 | from . import BaseWrapperDataset
9 |
10 |
11 | class SubsampleDataset(BaseWrapperDataset):
12 | """Subsamples a given dataset by a specified ratio. Subsampling is done on the number of examples
13 |
14 | Args:
15 | dataset (~torch.utils.data.Dataset): dataset to subsample
16 | size_ratio(float): the ratio to subsample to. must be between 0 and 1 (exclusive)
17 | """
18 |
19 | def __init__(self, dataset, size_ratio):
20 | super().__init__(dataset)
21 |         assert 0 < size_ratio < 1
22 | self.actual_size = np.ceil(len(dataset) * size_ratio).astype(int)
23 | self.indices = np.random.choice(
24 | list(range(len(self.dataset))), self.actual_size, replace=False
25 | )
26 | print(
27 | "subsampled dataset from {} to {} (ratio={})".format(
28 | len(self.dataset), self.actual_size, size_ratio
29 | )
30 | )
31 |
32 | def __getitem__(self, index):
33 | return self.dataset[self.indices[index]]
34 |
35 | def __len__(self):
36 | return self.actual_size
37 |
38 | def collater(self, samples):
39 | return self.dataset.collater(samples)
40 |
41 | @property
42 | def sizes(self):
43 | return self.dataset.sizes[self.indices]
44 |
45 | @property
46 | def name(self):
47 | return self.dataset.name
48 |
49 | def num_tokens(self, index):
50 | return self.dataset.num_tokens(self.indices[index])
51 |
52 | def size(self, index):
53 | return self.dataset.size(self.indices[index])
54 |
55 | def ordered_indices(self):
56 | """Return an ordered list of indices. Batches will be constructed based
57 | on this order."""
58 |         if getattr(self, "shuffle", False):  # no shuffle flag is ever set on this class; default to False
59 | order = [np.random.permutation(len(self))]
60 | else:
61 | order = [np.arange(len(self))]
62 | order.append(self.sizes)
63 | return np.lexsort(order)
64 |
65 | def prefetch(self, indices):
66 | self.dataset.prefetch(self.indices[indices])
67 |
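A small sketch (toy data assumed); the kept indices are drawn without replacement via np.random, so seed beforehand for reproducibility:

import numpy as np
import torch
from fairseq.data import ListDataset
from fairseq.data.subsample_dataset import SubsampleDataset

np.random.seed(0)
items = [torch.LongTensor([i]) for i in range(10)]
base = ListDataset(items, sizes=np.array([1] * 10))

ds = SubsampleDataset(base, size_ratio=0.5)  # keeps ceil(10 * 0.5) = 5 examples
print(len(ds))  # 5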
--------------------------------------------------------------------------------
/fairseq/data/token_block_utils_fast.cpython-37m-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/data/token_block_utils_fast.cpython-37m-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/fairseq/data/truncate_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy as np
7 |
8 | from . import BaseWrapperDataset
9 |
10 |
11 | class TruncateDataset(BaseWrapperDataset):
12 |
13 | def __init__(self, dataset, truncation_length):
14 | super().__init__(dataset)
15 | assert truncation_length is not None
16 | self.truncation_length = truncation_length
17 | self.dataset = dataset
18 |
19 | def __getitem__(self, index):
20 | item = self.dataset[index]
21 | item_len = item.size(0)
22 | if item_len > self.truncation_length:
23 | item = item[:self.truncation_length]
24 | return item
25 |
26 | @property
27 | def sizes(self):
28 | return np.minimum(self.dataset.sizes, self.truncation_length)
29 |
30 | def __len__(self):
31 | return len(self.dataset)
32 |
--------------------------------------------------------------------------------
/fairseq/libbleu.cpython-37m-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/libbleu.cpython-37m-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/fairseq/libnat.cpython-37m-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/libnat.cpython-37m-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/fairseq/meters.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import time
7 |
8 |
9 | class AverageMeter(object):
10 | """Computes and stores the average and current value"""
11 | def __init__(self):
12 | self.reset()
13 |
14 | def reset(self):
15 | self.val = 0
16 | self.avg = 0
17 | self.sum = 0
18 | self.count = 0
19 |
20 | def update(self, val, n=1):
21 | self.val = val
22 | self.sum += val * n
23 | self.count += n
24 | self.avg = self.sum / self.count
25 |
26 |
27 | class TimeMeter(object):
28 | """Computes the average occurrence of some event per second"""
29 | def __init__(self, init=0):
30 | self.reset(init)
31 |
32 | def reset(self, init=0):
33 | self.init = init
34 | self.start = time.time()
35 | self.n = 0
36 |
37 | def update(self, val=1):
38 | self.n += val
39 |
40 | @property
41 | def avg(self):
42 | return self.n / self.elapsed_time
43 |
44 | @property
45 | def elapsed_time(self):
46 | return self.init + (time.time() - self.start)
47 |
48 |
49 | class StopwatchMeter(object):
50 | """Computes the sum/avg duration of some event in seconds"""
51 | def __init__(self):
52 | self.reset()
53 |
54 | def start(self):
55 | self.start_time = time.time()
56 |
57 | def stop(self, n=1):
58 | if self.start_time is not None:
59 | delta = time.time() - self.start_time
60 | self.sum += delta
61 | self.n += n
62 | self.start_time = None
63 |
64 | def reset(self):
65 | self.sum = 0
66 | self.n = 0
67 | self.start_time = None
68 |
69 | @property
70 | def avg(self):
71 | return self.sum / self.n
72 |
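A rough usage sketch of the three meters (the numbers are placeholders): AverageMeter tracks a weighted running average, TimeMeter a rate per second, and StopwatchMeter accumulated durations.

import time
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter

loss_meter = AverageMeter()
wps_meter = TimeMeter()       # e.g. words per second
gen_timer = StopwatchMeter()  # e.g. seconds spent generating

for _ in range(3):
    gen_timer.start()
    time.sleep(0.01)               # stand-in for real work
    gen_timer.stop()
    loss_meter.update(2.5, n=16)   # batch loss, batch size
    wps_meter.update(160)          # tokens processed since the last update

print(loss_meter.avg, wps_meter.avg, gen_timer.avg)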
--------------------------------------------------------------------------------
/fairseq/models/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/models/__pycache__/cmlm_transformer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/__pycache__/cmlm_transformer.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/models/__pycache__/composite_encoder.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/__pycache__/composite_encoder.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/models/__pycache__/distributed_fairseq_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/__pycache__/distributed_fairseq_model.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/models/__pycache__/fairseq_decoder.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/__pycache__/fairseq_decoder.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/models/__pycache__/fairseq_encoder.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/__pycache__/fairseq_encoder.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/models/__pycache__/fairseq_incremental_decoder.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/__pycache__/fairseq_incremental_decoder.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/models/__pycache__/fairseq_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/__pycache__/fairseq_model.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/models/__pycache__/fconv.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/__pycache__/fconv.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/models/__pycache__/fconv_lm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/__pycache__/fconv_lm.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/models/__pycache__/fconv_self_att.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/__pycache__/fconv_self_att.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/models/__pycache__/insertion_transformer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/__pycache__/insertion_transformer.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/models/__pycache__/iterative_nonautoregressive_transformer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/__pycache__/iterative_nonautoregressive_transformer.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/models/__pycache__/levenshtein_transformer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/__pycache__/levenshtein_transformer.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/models/__pycache__/lightconv.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/__pycache__/lightconv.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/models/__pycache__/lightconv_lm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/__pycache__/lightconv_lm.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/models/__pycache__/lstm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/__pycache__/lstm.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/models/__pycache__/masked_lm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/__pycache__/masked_lm.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/models/__pycache__/model_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/__pycache__/model_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/models/__pycache__/multilingual_transformer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/__pycache__/multilingual_transformer.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/models/__pycache__/nonautoregressive_ensembles.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/__pycache__/nonautoregressive_ensembles.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/models/__pycache__/nonautoregressive_transformer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/__pycache__/nonautoregressive_transformer.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/models/__pycache__/transformer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/__pycache__/transformer.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/models/__pycache__/transformer_from_pretrained_xlm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/__pycache__/transformer_from_pretrained_xlm.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/models/__pycache__/transformer_lm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/__pycache__/transformer_lm.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/models/__pycache__/wav2vec.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/__pycache__/wav2vec.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/models/bart/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .hub_interface import * # noqa
7 | from .model import * # noqa
8 |
--------------------------------------------------------------------------------
/fairseq/models/bart/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/bart/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/models/bart/__pycache__/hub_interface.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/bart/__pycache__/hub_interface.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/models/bart/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/bart/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/models/composite_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from fairseq.models import FairseqEncoder
7 |
8 |
9 | class CompositeEncoder(FairseqEncoder):
10 | """
11 | A wrapper around a dictionary of :class:`FairseqEncoder` objects.
12 |
13 | We run forward on each encoder and return a dictionary of outputs. The first
14 | encoder's dictionary is used for initialization.
15 |
16 | Args:
17 | encoders (dict): a dictionary of :class:`FairseqEncoder` objects.
18 | """
19 |
20 | def __init__(self, encoders):
21 | super().__init__(next(iter(encoders.values())).dictionary)
22 | self.encoders = encoders
23 | for key in self.encoders:
24 | self.add_module(key, self.encoders[key])
25 |
26 | def forward(self, src_tokens, src_lengths):
27 | """
28 | Args:
29 | src_tokens (LongTensor): tokens in the source language of shape
30 | `(batch, src_len)`
31 | src_lengths (LongTensor): lengths of each source sentence of shape
32 | `(batch)`
33 |
34 | Returns:
35 | dict:
36 | the outputs from each Encoder
37 | """
38 | encoder_out = {}
39 | for key in self.encoders:
40 | encoder_out[key] = self.encoders[key](src_tokens, src_lengths)
41 | return encoder_out
42 |
43 | def reorder_encoder_out(self, encoder_out, new_order):
44 | """Reorder encoder output according to new_order."""
45 | for key in self.encoders:
46 | encoder_out[key] = self.encoders[key].reorder_encoder_out(encoder_out[key], new_order)
47 | return encoder_out
48 |
49 | def max_positions(self):
50 | return min([self.encoders[key].max_positions() for key in self.encoders])
51 |
52 | def upgrade_state_dict(self, state_dict):
53 | for key in self.encoders:
54 | self.encoders[key].upgrade_state_dict(state_dict)
55 | return state_dict
56 |
--------------------------------------------------------------------------------
/fairseq/models/fairseq_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch.nn as nn
7 |
8 |
9 | class FairseqEncoder(nn.Module):
10 | """Base class for encoders."""
11 |
12 | def __init__(self, dictionary):
13 | super().__init__()
14 | self.dictionary = dictionary
15 |
16 | def forward(self, src_tokens, src_lengths=None, **kwargs):
17 | """
18 | Args:
19 | src_tokens (LongTensor): tokens in the source language of shape
20 | `(batch, src_len)`
21 | src_lengths (LongTensor): lengths of each source sentence of shape
22 | `(batch)`
23 | """
24 | raise NotImplementedError
25 |
26 | def reorder_encoder_out(self, encoder_out, new_order):
27 | """
28 | Reorder encoder output according to `new_order`.
29 |
30 | Args:
31 | encoder_out: output from the ``forward()`` method
32 | new_order (LongTensor): desired order
33 |
34 | Returns:
35 | `encoder_out` rearranged according to `new_order`
36 | """
37 | raise NotImplementedError
38 |
39 | def max_positions(self):
40 | """Maximum input length supported by the encoder."""
41 | return 1e6 # an arbitrary large number
42 |
43 | def upgrade_state_dict(self, state_dict):
44 | """Upgrade a (possibly old) state dict for new versions of fairseq."""
45 | return state_dict
46 |
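A hypothetical minimal subclass, to show the contract: forward returns some encoder output, and reorder_encoder_out must know how to permute that output along the batch dimension (beam search relies on this).

import torch.nn as nn
from fairseq.models import FairseqEncoder

class ToyEncoder(FairseqEncoder):
    """Embeds tokens and mean-pools over time; illustration only."""

    def __init__(self, dictionary, embed_dim=8):
        super().__init__(dictionary)
        self.embed = nn.Embedding(len(dictionary), embed_dim, dictionary.pad())

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        x = self.embed(src_tokens)              # (batch, src_len, embed_dim)
        return {"encoder_out": x.mean(dim=1)}   # (batch, embed_dim)

    def reorder_encoder_out(self, encoder_out, new_order):
        return {"encoder_out": encoder_out["encoder_out"].index_select(0, new_order)}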
--------------------------------------------------------------------------------
/fairseq/models/roberta/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .hub_interface import * # noqa
7 | from .model import * # noqa
8 |
--------------------------------------------------------------------------------
/fairseq/models/roberta/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/roberta/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/models/roberta/__pycache__/hub_interface.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/roberta/__pycache__/hub_interface.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/models/roberta/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/models/roberta/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .adaptive_input import AdaptiveInput
7 | from .adaptive_softmax import AdaptiveSoftmax
8 | from .beamable_mm import BeamableMM
9 | from .character_token_embedder import CharacterTokenEmbedder
10 | from .conv_tbc import ConvTBC
11 | from .downsampled_multihead_attention import DownsampledMultiHeadAttention
12 | from .dynamic_convolution import DynamicConv, DynamicConv1dTBC
13 | from .gelu import gelu, gelu_accurate
14 | from .grad_multiply import GradMultiply
15 | from .highway import Highway
16 | from .layer_norm import LayerNorm
17 | from .learned_positional_embedding import LearnedPositionalEmbedding
18 | from .lightweight_convolution import LightweightConv, LightweightConv1dTBC
19 | from .linearized_convolution import LinearizedConvolution
20 | from .logsumexp_moe import LogSumExpMoE
21 | from .mean_pool_gating_network import MeanPoolGatingNetwork
22 | from .multihead_attention import MultiheadAttention
23 | from .positional_embedding import PositionalEmbedding
24 | from .scalar_bias import ScalarBias
25 | from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
26 | from .transformer_sentence_encoder_layer import TransformerSentenceEncoderLayer
27 | from .transformer_sentence_encoder import TransformerSentenceEncoder
28 | from .unfold import unfold1d
29 | from .transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer
30 | from .vggblock import VGGBlock
31 |
32 | __all__ = [
33 | 'AdaptiveInput',
34 | 'AdaptiveSoftmax',
35 | 'BeamableMM',
36 | 'CharacterTokenEmbedder',
37 | 'ConvTBC',
38 | 'DownsampledMultiHeadAttention',
39 | 'DynamicConv1dTBC',
40 | 'DynamicConv',
41 | 'gelu',
42 | 'gelu_accurate',
43 | 'GradMultiply',
44 | 'Highway',
45 | 'LayerNorm',
46 | 'LearnedPositionalEmbedding',
47 | 'LightweightConv1dTBC',
48 | 'LightweightConv',
49 | 'LinearizedConvolution',
50 | 'LogSumExpMoE',
51 | 'MeanPoolGatingNetwork',
52 | 'MultiheadAttention',
53 | 'PositionalEmbedding',
54 | 'ScalarBias',
55 | 'SinusoidalPositionalEmbedding',
56 | 'TransformerSentenceEncoderLayer',
57 | 'TransformerSentenceEncoder',
58 | 'TransformerDecoderLayer',
59 | 'TransformerEncoderLayer',
60 | 'VGGBlock',
61 | 'unfold1d',
62 | ]
63 |
--------------------------------------------------------------------------------
/fairseq/modules/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/modules/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/modules/__pycache__/adaptive_input.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/modules/__pycache__/adaptive_input.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/modules/__pycache__/adaptive_softmax.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/modules/__pycache__/adaptive_softmax.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/modules/__pycache__/beamable_mm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/modules/__pycache__/beamable_mm.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/modules/__pycache__/character_token_embedder.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/modules/__pycache__/character_token_embedder.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/modules/__pycache__/conv_tbc.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/modules/__pycache__/conv_tbc.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/modules/__pycache__/downsampled_multihead_attention.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/modules/__pycache__/downsampled_multihead_attention.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/modules/__pycache__/dynamic_convolution.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/modules/__pycache__/dynamic_convolution.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/modules/__pycache__/gelu.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/modules/__pycache__/gelu.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/modules/__pycache__/grad_multiply.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/modules/__pycache__/grad_multiply.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/modules/__pycache__/highway.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/modules/__pycache__/highway.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/modules/__pycache__/layer_norm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/modules/__pycache__/layer_norm.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/modules/__pycache__/learned_positional_embedding.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/modules/__pycache__/learned_positional_embedding.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/modules/__pycache__/lightweight_convolution.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/modules/__pycache__/lightweight_convolution.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/modules/__pycache__/linearized_convolution.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/modules/__pycache__/linearized_convolution.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/modules/__pycache__/logsumexp_moe.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/modules/__pycache__/logsumexp_moe.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/modules/__pycache__/mean_pool_gating_network.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/modules/__pycache__/mean_pool_gating_network.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/modules/__pycache__/multihead_attention.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/modules/__pycache__/multihead_attention.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/modules/__pycache__/positional_embedding.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/modules/__pycache__/positional_embedding.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/modules/__pycache__/scalar_bias.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/modules/__pycache__/scalar_bias.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/modules/__pycache__/sinusoidal_positional_embedding.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/modules/__pycache__/sinusoidal_positional_embedding.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/modules/__pycache__/transformer_layer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/modules/__pycache__/transformer_layer.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/modules/__pycache__/transformer_sentence_encoder.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/modules/__pycache__/transformer_sentence_encoder.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/modules/__pycache__/transformer_sentence_encoder_layer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/modules/__pycache__/transformer_sentence_encoder_layer.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/modules/__pycache__/unfold.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/modules/__pycache__/unfold.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/modules/__pycache__/vggblock.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/modules/__pycache__/vggblock.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/modules/beamable_mm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 |
10 | class BeamableMM(nn.Module):
11 | """This module provides an optimized MM for beam decoding with attention.
12 |
13 |     It leverages the fact that the source-side of the input is replicated beam
14 | times and the target-side of the input is of width one. This layer speeds up
15 | inference by replacing the inputs {(bsz x 1 x nhu), (bsz x sz2 x nhu)}
16 | with smaller inputs {(bsz/beam x beam x nhu), (bsz/beam x sz2 x nhu)}.
17 | """
18 | def __init__(self, beam_size=None):
19 | super(BeamableMM, self).__init__()
20 | self.beam_size = beam_size
21 |
22 | def forward(self, input1, input2):
23 | if (
24 | not self.training and # test mode
25 | self.beam_size is not None and # beam size is set
26 | input1.dim() == 3 and # only support batched input
27 | input1.size(1) == 1 # single time step update
28 | ):
29 | bsz, beam = input1.size(0), self.beam_size
30 |
31 | # bsz x 1 x nhu --> bsz/beam x beam x nhu
32 | input1 = input1[:, 0, :].unfold(0, beam, beam).transpose(2, 1)
33 |
34 | # bsz x sz2 x nhu --> bsz/beam x sz2 x nhu
35 | input2 = input2.unfold(0, beam, beam)[:, :, :, 0]
36 |
37 | # use non batched operation if bsz = beam
38 | if input1.size(0) == 1:
39 | output = torch.mm(input1[0, :, :], input2[0, :, :])
40 | else:
41 | output = input1.bmm(input2)
42 | return output.view(bsz, 1, -1)
43 | else:
44 | return input1.bmm(input2)
45 |
46 | def set_beam_size(self, beam_size):
47 | self.beam_size = beam_size
48 |
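A shape sketch (toy sizes assumed): the fast path only runs in eval mode, and it assumes input2 is replicated across the beam, as it is for the encoder side during beam search.

import torch
from fairseq.modules import BeamableMM

bsz, beam, nhu, sz2 = 6, 3, 5, 4               # 2 sentences * beam size 3
mm = BeamableMM(beam_size=beam).eval()

query = torch.randn(bsz, 1, nhu)               # one decoder step per hypothesis
keys = torch.randn(2, nhu, sz2).repeat_interleave(beam, dim=0)  # source replicated per beam

out = mm(query, keys)                          # same result, computed on smaller inputs
assert out.shape == (6, 1, 4)
assert torch.allclose(out, query.bmm(keys), atol=1e-6)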
--------------------------------------------------------------------------------
/fairseq/modules/conv_tbc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | from torch.nn.modules.utils import _single
8 |
9 |
10 | class ConvTBC(torch.nn.Module):
11 | """1D convolution over an input of shape (time x batch x channel)
12 |
13 | The implementation uses gemm to perform the convolution. This implementation
14 | is faster than cuDNN for small kernel sizes.
15 | """
16 | def __init__(self, in_channels, out_channels, kernel_size, padding=0):
17 | super(ConvTBC, self).__init__()
18 | self.in_channels = in_channels
19 | self.out_channels = out_channels
20 | self.kernel_size = _single(kernel_size)
21 | self.padding = _single(padding)
22 |
23 | self.weight = torch.nn.Parameter(torch.Tensor(
24 | self.kernel_size[0], in_channels, out_channels))
25 | self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
26 |
27 | def forward(self, input):
28 | return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding[0])
29 |
30 | def __repr__(self):
31 | s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
32 | ', padding={padding}')
33 | if self.bias is None:
34 | s += ', bias=False'
35 | s += ')'
36 | return s.format(name=self.__class__.__name__, **self.__dict__)
37 |
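A shape sketch: note the (time, batch, channel) layout, and that weight and bias are allocated uninitialized, so callers (e.g. the fconv models) are expected to initialize them.

import torch
from fairseq.modules import ConvTBC

conv = ConvTBC(in_channels=4, out_channels=6, kernel_size=3, padding=1)
torch.nn.init.xavier_uniform_(conv.weight)   # illustration only; fconv uses its own scheme
torch.nn.init.zeros_(conv.bias)

x = torch.randn(10, 2, 4)   # (time, batch, channel)
print(conv(x).shape)        # torch.Size([10, 2, 6]) with padding=1, kernel_size=3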
--------------------------------------------------------------------------------
/fairseq/modules/dynamicconv_layer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .dynamicconv_layer import DynamicconvLayer # noqa
7 |
--------------------------------------------------------------------------------
/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cpp:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Facebook, Inc. and its affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | #include <torch/extension.h>
9 | #include <vector>
10 | 
11 | std::vector<at::Tensor> dynamicconv_cuda_forward(
12 | at::Tensor input,
13 | at::Tensor filters,
14 | int padding_l);
15 |
16 | std::vector<at::Tensor> dynamicconv_cuda_backward(
17 | at::Tensor gradOutput,
18 | int padding_l,
19 | at::Tensor input,
20 | at::Tensor filters);
21 |
22 |
23 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
24 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
25 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
26 |
27 | std::vector<at::Tensor> dynamicconv_forward(
28 | at::Tensor input,
29 | at::Tensor filters,
30 | int padding_l) {
31 |
32 | CHECK_INPUT(input);
33 | CHECK_INPUT(filters);
34 |
35 | return dynamicconv_cuda_forward(input, filters,
36 | padding_l);
37 | }
38 |
39 | std::vector<at::Tensor> dynamicconv_backward(
40 | at::Tensor gradOutput,
41 | int padding_l,
42 | at::Tensor input,
43 | at::Tensor filters) {
44 |
45 | CHECK_INPUT(gradOutput);
46 | CHECK_INPUT(input);
47 | CHECK_INPUT(filters);
48 |
49 | return dynamicconv_cuda_backward(gradOutput, padding_l,
50 | input, filters);
51 | }
52 |
53 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
54 | m.def("forward", &dynamicconv_forward, "dynamicconv forward (CUDA)");
55 | m.def("backward", &dynamicconv_backward, "dynamicconv backward (CUDA)");
56 | }
57 |
--------------------------------------------------------------------------------
/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cuh:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Facebook, Inc. and its affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | #include <ATen/ATen.h>
9 | #include <c10/cuda/CUDAStream.h>
10 | 
11 | #include <cuda.h>
12 | #include <cuda_runtime.h>
13 | #include <cuda_fp16.h>
14 | 
15 | #include <algorithm>
16 | #include <functional>
17 | #include <iterator>
18 | #include <limits>
19 | #include <stdexcept>
20 | #include <utility>
21 | 
22 | #include <assert.h>
23 | #include <math.h>
24 | #include <stdlib.h>
25 |
26 | #define SHFL_MASK 0xffffffff
27 |
28 | template<typename scalar_t>
29 | __global__
30 | void dynamicconv_forward_kernel(const scalar_t* input,
31 | const scalar_t* weight,
32 | int minibatch,
33 | int sequenceLength,
34 | int numFeatures,
35 | int numFiltersInBlock,
36 | int numHeads,
37 | scalar_t* output);
38 |
39 | template<typename scalar_t>
40 | __global__
41 | void dynamicconv_backward_kernel(
42 | const scalar_t* gradOutput, // B * C * T
43 | const scalar_t* input, // B * C * T
44 | const scalar_t* weight,
45 | int minibatch,
46 | int sequenceLength,
47 | int numFeatures,
48 | int numFiltersInBlock,
49 | int numHeads,
50 | scalar_t* gradWeight,
51 | scalar_t* gradInput); // B * H * k * T
52 |
--------------------------------------------------------------------------------
/fairseq/modules/dynamicconv_layer/dynamiconv_cpu.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <vector>
3 | 
4 | std::vector<float*> dynamicconv_cpu_forward(
5 | float* input,
6 | float* filters,
7 | int padding_l);
8 |
9 | std::vector<float*> dynamicconv_cpu_backward(
10 | float* gradOutput,
11 | int padding_l,
12 | float* input,
13 | float* filters);
14 |
15 | std::vector<float*> dynamicconv_forward(
16 | float* input,
17 | float* filters,
18 | int padding_l) {
19 |
20 | return dynamicconv_cpu_forward(input, filters, padding_l);
21 | }
22 |
23 | std::vector<float*> dynamicconv_backward(
24 | float* gradOutput,
25 | int padding_l,
26 | float* input,
27 | float* filters) {
28 |
29 | return dynamicconv_cpu_backward(gradOutput, padding_l, input, filters);
30 | }
31 |
32 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
33 | m.def("forward", &dynamicconv_forward, "dynamicconv forward (CPU)");
34 | m.def("backward", &dynamicconv_backward, "dynamicconv backward (CPU)");
35 | }
36 |
--------------------------------------------------------------------------------
/fairseq/modules/dynamicconv_layer/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from setuptools import setup
8 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension
9 |
10 | setup(
11 | name='dynamicconv_layer',
12 | ext_modules=[
13 | CUDAExtension(
14 | name='dynamicconv_cuda',
15 | sources=[
16 | 'dynamicconv_cuda.cpp',
17 | 'dynamicconv_cuda_kernel.cu',
18 | ],
19 | ),
20 | ],
21 | cmdclass={
22 | 'build_ext': BuildExtension
23 | })
24 |
--------------------------------------------------------------------------------
/fairseq/modules/gelu.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | """
6 | See "Gaussian Error Linear Units (GELUs)" by Dan Hendrycks and Kevin Gimpel with
7 | the corresponding GitHub repo: https://github.com/hendrycks/GELUs
8 | """
9 |
10 | import math
11 |
12 | import torch
13 |
14 |
15 | def gelu_accurate(x):
16 | if not hasattr(gelu_accurate, "_a"):
17 | gelu_accurate._a = math.sqrt(2 / math.pi)
18 | return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
19 |
20 |
21 | def gelu(x: torch.Tensor) -> torch.Tensor:
22 | if hasattr(torch.nn.functional, 'gelu'):
23 | return torch.nn.functional.gelu(x.float()).type_as(x)
24 | else:
25 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
26 |
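A quick comparison of the two variants: despite its name, gelu_accurate is the tanh approximation, while gelu is the exact erf form (dispatching to F.gelu when available), so the two agree only to a small tolerance.

import torch
from fairseq.modules import gelu, gelu_accurate

x = torch.linspace(-3, 3, steps=7)
diff = (gelu(x) - gelu_accurate(x)).abs().max()
print(diff)  # small (well under 1e-2), but not exactly zero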
--------------------------------------------------------------------------------
/fairseq/modules/grad_multiply.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 |
8 |
9 | class GradMultiply(torch.autograd.Function):
10 | @staticmethod
11 | def forward(ctx, x, scale):
12 | ctx.scale = scale
13 | res = x.new(x)
14 | return res
15 |
16 | @staticmethod
17 | def backward(ctx, grad):
18 | return grad * ctx.scale, None
19 |
--------------------------------------------------------------------------------
/fairseq/modules/highway.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 |
8 | from torch import nn
9 |
10 |
11 | class Highway(torch.nn.Module):
12 | """
12 |     A `Highway layer <https://arxiv.org/abs/1505.00387>`_.
13 |     Adapted from the AllenNLP implementation.
15 | """
16 |
17 | def __init__(
18 | self,
19 | input_dim: int,
20 | num_layers: int = 1
21 | ):
22 | super(Highway, self).__init__()
23 | self.input_dim = input_dim
24 | self.layers = nn.ModuleList([nn.Linear(input_dim, input_dim * 2)
25 | for _ in range(num_layers)])
26 | self.activation = nn.ReLU()
27 |
28 | self.reset_parameters()
29 |
30 | def reset_parameters(self):
31 | for layer in self.layers:
32 | # As per comment in AllenNLP:
33 | # We should bias the highway layer to just carry its input forward. We do that by
34 | # setting the bias on `B(x)` to be positive, because that means `g` will be biased to
35 | # be high, so we will carry the input forward. The bias on `B(x)` is the second half
36 | # of the bias vector in each Linear layer.
37 | nn.init.constant_(layer.bias[self.input_dim:], 1)
38 |
39 | nn.init.constant_(layer.bias[:self.input_dim], 0)
40 | nn.init.xavier_normal_(layer.weight)
41 |
42 | def forward(
43 | self,
44 | x: torch.Tensor
45 | ):
46 | for layer in self.layers:
47 | projection = layer(x)
48 | proj_x, gate = projection.chunk(2, dim=-1)
49 | proj_x = self.activation(proj_x)
50 | gate = torch.sigmoid(gate)
51 | x = gate * x + (gate.new_tensor([1]) - gate) * proj_x
52 | return x
53 |
--------------------------------------------------------------------------------
/fairseq/modules/layer_norm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 |
8 |
9 | def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
10 | if not export and torch.cuda.is_available():
11 | try:
12 | from apex.normalization import FusedLayerNorm
13 | return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
14 | except ImportError:
15 | pass
16 | return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
17 |
--------------------------------------------------------------------------------
/fairseq/modules/learned_positional_embedding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch.nn as nn
7 |
8 | from fairseq import utils
9 |
10 |
11 | class LearnedPositionalEmbedding(nn.Embedding):
12 | """
13 | This module learns positional embeddings up to a fixed maximum size.
14 | Padding ids are ignored by either offsetting based on padding_idx
15 | or by setting padding_idx to None and ensuring that the appropriate
16 | position ids are passed to the forward function.
17 | """
18 |
19 | def __init__(
20 | self,
21 | num_embeddings: int,
22 | embedding_dim: int,
23 | padding_idx: int,
24 | ):
25 | super().__init__(num_embeddings, embedding_dim, padding_idx)
26 | self.onnx_trace = False
27 |
28 | def forward(self, input, incremental_state=None, positions=None):
29 | """Input is expected to be of size [bsz x seqlen]."""
30 | assert (
31 | (positions is None) or (self.padding_idx is None)
32 | ), "If positions is pre-computed then padding_idx should not be set."
33 |
34 | if positions is None:
35 | if incremental_state is not None:
36 | # positions is the same for every token when decoding a single step
37 | # Without the int() cast, it doesn't work in some cases when exporting to ONNX
38 | positions = input.data.new(1, 1).fill_(int(self.padding_idx + input.size(1)))
39 | else:
40 | positions = utils.make_positions(
41 | input, self.padding_idx, onnx_trace=self.onnx_trace,
42 | )
43 | return super().forward(positions)
44 |
45 | def max_positions(self):
46 | """Maximum number of supported positions."""
47 | if self.padding_idx is not None:
48 | return self.num_embeddings - self.padding_idx - 1
49 | else:
50 | return self.num_embeddings
51 |
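A sketch of the padding-aware behavior, constructing the module directly (in practice the PositionalEmbedding factory further below handles the num_embeddings offset; the `fairseq.modules` export is assumed):

    import torch
    from fairseq.modules import LearnedPositionalEmbedding

    pad = 1
    max_pos = 512
    emb = LearnedPositionalEmbedding(max_pos + pad + 1, embedding_dim=64, padding_idx=pad)
    tokens = torch.tensor([[5, 6, 7, pad]])   # bsz x seqlen, right-padded
    out = emb(tokens)                         # positions are derived from the non-pad mask
    assert out.shape == (1, 4, 64)
    assert emb.max_positions() == max_pos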
--------------------------------------------------------------------------------
/fairseq/modules/lightconv_layer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .lightconv_layer import LightconvLayer # noqa
7 |
--------------------------------------------------------------------------------
/fairseq/modules/lightconv_layer/lightconv_cuda.cpp:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Facebook, Inc. and its affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | #include <torch/extension.h>
9 | #include <vector>
10 |
11 | std::vector<at::Tensor> lightconv_cuda_forward(
12 | at::Tensor input,
13 | at::Tensor filters,
14 | int padding_l);
15 |
16 | std::vector<at::Tensor> lightconv_cuda_backward(
17 | at::Tensor gradOutput,
18 | int padding_l,
19 | at::Tensor input,
20 | at::Tensor filters);
21 |
22 |
23 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
24 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
25 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
26 |
27 | std::vector<at::Tensor> lightconv_forward(
28 | at::Tensor input,
29 | at::Tensor filters,
30 | int padding_l) {
31 |
32 | CHECK_INPUT(input);
33 | CHECK_INPUT(filters);
34 |
35 | return lightconv_cuda_forward(input, filters, padding_l);
36 | }
37 |
38 | std::vector<at::Tensor> lightconv_backward(
39 | at::Tensor gradOutput,
40 | int padding_l,
41 | at::Tensor input,
42 | at::Tensor filters) {
43 |
44 | CHECK_INPUT(gradOutput);
45 | CHECK_INPUT(input);
46 | CHECK_INPUT(filters);
47 |
48 | return lightconv_cuda_backward(gradOutput, padding_l, input, filters);
49 | }
50 |
51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
52 |   m.def("forward", &lightconv_forward, "lightconv forward (CUDA)");
53 |   m.def("backward", &lightconv_backward, "lightconv backward (CUDA)");
54 | }
55 |
--------------------------------------------------------------------------------
/fairseq/modules/lightconv_layer/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from setuptools import setup
8 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension
9 |
10 | setup(
11 | name='lightconv_layer',
12 | ext_modules=[
13 | CUDAExtension('lightconv_cuda', [
14 | 'lightconv_cuda.cpp',
15 | 'lightconv_cuda_kernel.cu',
16 | ]),
17 | ],
18 | cmdclass={
19 | 'build_ext': BuildExtension
20 | })
21 |
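The extension has to be compiled before the `LightconvLayer` wrapper can bind to it. A sketch of the build-and-import flow (only the binding surface is shown; the exact tensor layout is dictated by the CUDA kernel):

    # build from this directory first:
    #     python setup.py build_ext --inplace
    import lightconv_cuda  # the module name given to CUDAExtension above

    # the bindings mirror the C++ declarations in lightconv_cuda.cpp:
    #     lightconv_cuda.forward(input, filters, padding_l)
    #     lightconv_cuda.backward(grad_output, padding_l, input, filters)
    # both require contiguous CUDA tensors (see the CHECK_INPUT macros).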
--------------------------------------------------------------------------------
/fairseq/modules/logsumexp_moe.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 |
8 |
9 | class LogSumExpMoE(torch.autograd.Function):
10 | """Standard LogSumExp forward pass, but use *posterior* for the backward.
11 |
12 |     See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
13 |     (Shen et al., 2019) <https://arxiv.org/abs/1902.07816>`_.
14 | """
15 |
16 | @staticmethod
17 | def forward(ctx, logp, posterior, dim=-1):
18 | ctx.save_for_backward(posterior)
19 | ctx.dim = dim
20 | return torch.logsumexp(logp, dim=dim)
21 |
22 | @staticmethod
23 | def backward(ctx, grad_output):
24 | posterior, = ctx.saved_tensors
25 | grad_logp = grad_output.unsqueeze(ctx.dim) * posterior
26 | return grad_logp, None, None
27 |
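A small sanity check of the custom backward: the gradient with respect to logp is exactly the supplied posterior, not the softmax that differentiating a plain logsumexp would yield:

    import torch
    from fairseq.modules.logsumexp_moe import LogSumExpMoE

    logp = torch.randn(2, 4, requires_grad=True)          # per-expert log-probabilities
    posterior = torch.softmax(torch.randn(2, 4), dim=-1)  # expert responsibilities
    out = LogSumExpMoE.apply(logp, posterior, -1)         # forward: torch.logsumexp(logp, -1)
    out.sum().backward()
    assert torch.allclose(logp.grad, posterior)           # backward routes through the posterior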
--------------------------------------------------------------------------------
/fairseq/modules/mean_pool_gating_network.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | import torch.nn.functional as F
8 |
9 |
10 | class MeanPoolGatingNetwork(torch.nn.Module):
11 | """A simple mean-pooling gating network for selecting experts.
12 |
13 | This module applies mean pooling over an encoder's output and returns
14 | reponsibilities for each expert. The encoder format is expected to match
15 | :class:`fairseq.models.transformer.TransformerEncoder`.
16 | """
17 |
18 | def __init__(self, embed_dim, num_experts, dropout=None):
19 | super().__init__()
20 | self.embed_dim = embed_dim
21 | self.num_experts = num_experts
22 |
23 | self.fc1 = torch.nn.Linear(embed_dim, embed_dim)
24 | self.dropout = torch.nn.Dropout(dropout) if dropout is not None else None
25 | self.fc2 = torch.nn.Linear(embed_dim, num_experts)
26 |
27 | def forward(self, encoder_out):
28 | if not (
29 | hasattr(encoder_out, 'encoder_out')
30 | and hasattr(encoder_out, 'encoder_padding_mask')
31 | and encoder_out.encoder_out.size(2) == self.embed_dim
32 | ):
33 | raise ValueError('Unexpected format for encoder_out')
34 |
35 | # mean pooling over time
36 | encoder_padding_mask = encoder_out.encoder_padding_mask # B x T
37 | encoder_out = encoder_out.encoder_out.transpose(0, 1) # B x T x C
38 | if encoder_padding_mask is not None:
39 | encoder_out = encoder_out.clone() # required because of transpose above
40 | encoder_out[encoder_padding_mask] = 0
41 | ntokens = torch.sum(~encoder_padding_mask, dim=1, keepdim=True)
42 | x = torch.sum(encoder_out, dim=1) / ntokens.type_as(encoder_out)
43 | else:
44 | x = torch.mean(encoder_out, dim=1)
45 |
46 | x = torch.tanh(self.fc1(x))
47 | if self.dropout is not None:
48 | x = self.dropout(x)
49 | x = self.fc2(x)
50 | return F.log_softmax(x, dim=-1, dtype=torch.float32).type_as(x)
51 |
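A minimal sketch, using a namedtuple as a stand-in for the TransformerEncoder output object (only the two attributes checked in forward are required; the `fairseq.modules` export is assumed):

    import torch
    from collections import namedtuple
    from fairseq.modules import MeanPoolGatingNetwork

    EncoderOut = namedtuple('EncoderOut', ['encoder_out', 'encoder_padding_mask'])
    T, B, C, E = 7, 2, 512, 4
    gating = MeanPoolGatingNetwork(embed_dim=C, num_experts=E, dropout=0.1)
    enc = EncoderOut(encoder_out=torch.randn(T, B, C), encoder_padding_mask=None)
    log_probs = gating(enc)             # B x E log-responsibilities over the experts
    assert log_probs.shape == (B, E)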
--------------------------------------------------------------------------------
/fairseq/modules/positional_embedding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch.nn as nn
7 |
8 | from .learned_positional_embedding import LearnedPositionalEmbedding
9 | from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
10 |
11 |
12 | def PositionalEmbedding(
13 | num_embeddings: int,
14 | embedding_dim: int,
15 | padding_idx: int,
16 | learned: bool = False,
17 | ):
18 | if learned:
19 | # if padding_idx is specified then offset the embedding ids by
20 | # this index and adjust num_embeddings appropriately
21 | # TODO: The right place for this offset would be inside
22 | # LearnedPositionalEmbedding. Move this there for a cleaner implementation.
23 | if padding_idx is not None:
24 | num_embeddings = num_embeddings + padding_idx + 1
25 | m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)
26 | nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
27 | if padding_idx is not None:
28 | nn.init.constant_(m.weight[padding_idx], 0)
29 | else:
30 | m = SinusoidalPositionalEmbedding(
31 | embedding_dim, padding_idx, init_size=num_embeddings + padding_idx + 1,
32 | )
33 | return m
34 |
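A usage sketch of the factory; num_embeddings counts usable positions, and the padding offset is added internally for the learned case:

    import torch
    from fairseq.modules import PositionalEmbedding

    pad = 1
    pos = PositionalEmbedding(num_embeddings=1024, embedding_dim=256,
                              padding_idx=pad, learned=True)
    tokens = torch.tensor([[4, 5, 6, pad, pad]])
    out = pos(tokens)                # 1 x 5 x 256; pad tokens map to the zeroed pad row
    assert out.shape == (1, 5, 256)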
--------------------------------------------------------------------------------
/fairseq/modules/scalar_bias.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | #
6 |
7 | import torch
8 |
9 |
10 | class ScalarBias(torch.autograd.Function):
11 | """
12 |     Adds a vector of scalars, used in the self-attention mechanism to allow
13 |     the model to optionally attend to this vector instead of the past.
14 | """
15 |
16 | @staticmethod
17 | def forward(ctx, input, dim, bias_init):
18 | size = list(input.size())
19 | size[dim] += 1
20 | output = input.new(*size).fill_(bias_init)
21 | output.narrow(dim, 1, size[dim] - 1).copy_(input)
22 | ctx.dim = dim
23 | return output
24 |
25 | @staticmethod
26 | def backward(ctx, grad):
27 | return grad.narrow(ctx.dim, 1, grad.size(ctx.dim) - 1), None, None
28 |
29 |
30 | def scalar_bias(input, dim, bias_init=0):
31 | return ScalarBias.apply(input, dim, bias_init)
32 |
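Concretely, the helper prepends a single bias slot along the chosen dimension, and the backward pass simply drops that slot's gradient:

    import torch
    from fairseq.modules.scalar_bias import scalar_bias

    x = torch.ones(2, 3)
    y = scalar_bias(x, dim=1)          # default bias_init=0
    assert y.shape == (2, 4)
    assert (y[:, 0] == 0).all()        # the new leading slot holds the bias value
    assert torch.equal(y[:, 1:], x)    # original values are shifted right by one slot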
--------------------------------------------------------------------------------
/fairseq/modules/sparse_transformer_sentence_encoder_layer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from fairseq.modules import TransformerSentenceEncoderLayer
7 | from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention
8 |
9 |
10 | class SparseTransformerSentenceEncoderLayer(TransformerSentenceEncoderLayer):
11 | """
12 |     Implements a Sparse Transformer Encoder Layer (see SparseMultiheadAttention)
13 | """
14 |
15 | def __init__(
16 | self,
17 | embedding_dim: int = 768,
18 | ffn_embedding_dim: int = 3072,
19 | num_attention_heads: int = 8,
20 | dropout: float = 0.1,
21 | attention_dropout: float = 0.1,
22 | activation_dropout: float = 0.1,
23 | activation_fn: str = 'relu',
24 | add_bias_kv: bool = False,
25 | add_zero_attn: bool = False,
26 | export: bool = False,
27 | is_bidirectional: bool = True,
28 | stride: int = 32,
29 | expressivity: int = 8,
30 | ) -> None:
31 |
32 | super().__init__(
33 | embedding_dim, ffn_embedding_dim, num_attention_heads, dropout,
34 | attention_dropout, activation_dropout, activation_fn, add_bias_kv,
35 | add_zero_attn, export
36 | )
37 |
38 | self.self_attn = SparseMultiheadAttention(
39 | self.embedding_dim,
40 | num_attention_heads,
41 | dropout=attention_dropout,
42 | add_bias_kv=add_bias_kv,
43 | add_zero_attn=add_zero_attn,
44 | self_attention=True,
45 | is_bidirectional=is_bidirectional,
46 | stride=stride,
47 | expressivity=expressivity,
48 | )
49 |
--------------------------------------------------------------------------------
/fairseq/modules/unfold.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch.nn.functional as F
7 |
8 |
9 | def unfold1d(x, kernel_size, padding_l, pad_value=0):
10 | '''unfold T x B x C to T x B x C x K'''
11 | if kernel_size > 1:
12 | T, B, C = x.size()
13 | x = F.pad(x, (0, 0, 0, 0, padding_l, kernel_size - 1 - padding_l), value=pad_value)
14 | x = x.as_strided((T, B, C, kernel_size), (B*C, C, 1, B*C))
15 | else:
16 | x = x.unsqueeze(3)
17 | return x
18 |
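A sketch of the causal configuration used by the convolution modules: with padding_l = kernel_size - 1, each timestep sees itself plus the K - 1 previous steps, zero-padded at the start:

    import torch
    from fairseq.modules.unfold import unfold1d

    T, B, C, K = 5, 2, 3, 3
    x = torch.arange(T * B * C, dtype=torch.float).view(T, B, C)
    win = unfold1d(x, kernel_size=K, padding_l=K - 1)
    assert win.shape == (T, B, C, K)
    assert torch.equal(win[..., -1], x)   # the last tap of every window is the current step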
--------------------------------------------------------------------------------
/fairseq/optim/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import importlib
7 | import os
8 |
9 | from fairseq import registry
10 | from fairseq.optim.fairseq_optimizer import FairseqOptimizer
11 | from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer
12 | from fairseq.optim.bmuf import FairseqBMUF # noqa
13 |
14 |
15 | __all__ = [
16 | 'FairseqOptimizer',
17 | 'FP16Optimizer',
18 | 'MemoryEfficientFP16Optimizer',
19 | ]
20 |
21 |
22 | build_optimizer, register_optimizer, OPTIMIZER_REGISTRY = registry.setup_registry(
23 | '--optimizer',
24 | base_class=FairseqOptimizer,
25 | default='nag',
26 | )
27 |
28 |
29 | # automatically import any Python files in the optim/ directory
30 | for file in os.listdir(os.path.dirname(__file__)):
31 | if file.endswith('.py') and not file.startswith('_'):
32 | module = file[:file.find('.py')]
33 | importlib.import_module('fairseq.optim.' + module)
34 |
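The auto-import loop is what makes the registry extensible: any module dropped into optim/ that decorates a FairseqOptimizer subclass becomes selectable on the command line. A sketch with a hypothetical file name and registry key:

    # fairseq/optim/my_sgd.py (hypothetical)
    import torch.optim

    from . import FairseqOptimizer, register_optimizer


    @register_optimizer('my_sgd')          # then selectable via --optimizer my_sgd
    class MySGD(FairseqOptimizer):
        def __init__(self, args, params):
            super().__init__(args)
            self._optimizer = torch.optim.SGD(params, **self.optimizer_config)

        @property
        def optimizer_config(self):
            return {'lr': self.args.lr[0]}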
--------------------------------------------------------------------------------
/fairseq/optim/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/optim/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/optim/__pycache__/adadelta.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/optim/__pycache__/adadelta.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/optim/__pycache__/adafactor.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/optim/__pycache__/adafactor.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/optim/__pycache__/adagrad.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/optim/__pycache__/adagrad.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/optim/__pycache__/adam.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/optim/__pycache__/adam.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/optim/__pycache__/adamax.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/optim/__pycache__/adamax.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/optim/__pycache__/bmuf.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/optim/__pycache__/bmuf.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/optim/__pycache__/fairseq_optimizer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/optim/__pycache__/fairseq_optimizer.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/optim/__pycache__/fp16_optimizer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/optim/__pycache__/fp16_optimizer.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/optim/__pycache__/nag.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/optim/__pycache__/nag.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/optim/__pycache__/sgd.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/optim/__pycache__/sgd.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/optim/adadelta.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch.optim
7 |
8 | from . import FairseqOptimizer, register_optimizer
9 |
10 |
11 | @register_optimizer('adadelta')
12 | class Adadelta(FairseqOptimizer):
13 | def __init__(self, args, params):
14 | super().__init__(args)
15 | self._optimizer = torch.optim.Adadelta(params, **self.optimizer_config)
16 |
17 | @staticmethod
18 | def add_args(parser):
19 | """Add optimizer-specific arguments to the parser."""
20 | # fmt: off
21 | parser.add_argument('--adadelta-rho', type=float, default=0.9, metavar='RHO',
22 | help='coefficient used for computing a running average of squared gradients')
23 | parser.add_argument('--adadelta-eps', type=float, default=1e-6, metavar='EPS',
24 | help='term added to the denominator to improve numerical stability')
25 | parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
26 | help='weight decay')
27 | parser.add_argument('--anneal-eps', action='store_true', help='flag to anneal eps')
28 | # fmt: on
29 |
30 | @property
31 | def optimizer_config(self):
32 | """
33 | Return a kwarg dictionary that will be used to override optimizer
34 | args stored in checkpoints. This allows us to load a checkpoint and
35 | resume training using a different set of optimizer args, e.g., with a
36 | different learning rate.
37 | """
38 | return {
39 | 'lr': self.args.lr[0],
40 | 'rho': self.args.adadelta_rho,
41 | 'eps': self.args.adadelta_eps,
42 | 'weight_decay': self.args.weight_decay,
43 | }
44 |
--------------------------------------------------------------------------------
/fairseq/optim/adagrad.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch.optim
7 |
8 | from . import FairseqOptimizer, register_optimizer
9 |
10 |
11 | @register_optimizer('adagrad')
12 | class Adagrad(FairseqOptimizer):
13 | def __init__(self, args, params):
14 | super().__init__(args)
15 | self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config)
16 |
17 | @staticmethod
18 | def add_args(parser):
19 | """Add optimizer-specific arguments to the parser."""
20 | # fmt: off
21 | parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
22 | help='weight decay')
23 | # fmt: on
24 |
25 | @property
26 | def optimizer_config(self):
27 | """
28 | Return a kwarg dictionary that will be used to override optimizer
29 | args stored in checkpoints. This allows us to load a checkpoint and
30 | resume training using a different set of optimizer args, e.g., with a
31 | different learning rate.
32 | """
33 | return {
34 | 'lr': self.args.lr[0],
35 | 'weight_decay': self.args.weight_decay,
36 | }
37 |
--------------------------------------------------------------------------------
/fairseq/optim/lr_scheduler/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import importlib
7 | import os
8 |
9 | from fairseq import registry
10 | from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import FairseqLRScheduler
11 |
12 |
13 | build_lr_scheduler, register_lr_scheduler, LR_SCHEDULER_REGISTRY = registry.setup_registry(
14 | '--lr-scheduler',
15 | base_class=FairseqLRScheduler,
16 | default='fixed',
17 | )
18 |
19 | # automatically import any Python files in the optim/lr_scheduler/ directory
20 | for file in os.listdir(os.path.dirname(__file__)):
21 | if file.endswith('.py') and not file.startswith('_'):
22 | module = file[:file.find('.py')]
23 | importlib.import_module('fairseq.optim.lr_scheduler.' + module)
24 |
--------------------------------------------------------------------------------
/fairseq/optim/lr_scheduler/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/optim/lr_scheduler/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/optim/lr_scheduler/__pycache__/cosine_lr_scheduler.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/optim/lr_scheduler/__pycache__/cosine_lr_scheduler.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/optim/lr_scheduler/__pycache__/fairseq_lr_scheduler.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/optim/lr_scheduler/__pycache__/fairseq_lr_scheduler.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/optim/lr_scheduler/__pycache__/fixed_schedule.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/optim/lr_scheduler/__pycache__/fixed_schedule.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/optim/lr_scheduler/__pycache__/inverse_square_root_schedule.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/optim/lr_scheduler/__pycache__/inverse_square_root_schedule.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/optim/lr_scheduler/__pycache__/polynomial_decay_schedule.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/optim/lr_scheduler/__pycache__/polynomial_decay_schedule.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/optim/lr_scheduler/__pycache__/reduce_lr_on_plateau.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/optim/lr_scheduler/__pycache__/reduce_lr_on_plateau.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/optim/lr_scheduler/__pycache__/tri_stage_lr_scheduler.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/optim/lr_scheduler/__pycache__/tri_stage_lr_scheduler.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/optim/lr_scheduler/__pycache__/triangular_lr_scheduler.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/optim/lr_scheduler/__pycache__/triangular_lr_scheduler.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .. import FairseqOptimizer
7 |
8 |
9 | class FairseqLRScheduler(object):
10 |
11 | def __init__(self, args, optimizer):
12 | super().__init__()
13 | if not isinstance(optimizer, FairseqOptimizer):
14 | raise ValueError('optimizer must be an instance of FairseqOptimizer')
15 | self.args = args
16 | self.optimizer = optimizer
17 | self.best = None
18 |
19 | @staticmethod
20 | def add_args(parser):
21 | """Add arguments to the parser for this LR scheduler."""
22 | pass
23 |
24 | def state_dict(self):
25 | """Return the LR scheduler state dict."""
26 | return {'best': self.best}
27 |
28 | def load_state_dict(self, state_dict):
29 | """Load an LR scheduler state dict."""
30 | self.best = state_dict['best']
31 |
32 | def step(self, epoch, val_loss=None):
33 | """Update the learning rate at the end of the given epoch."""
34 | if val_loss is not None:
35 | if self.best is None:
36 | self.best = val_loss
37 | else:
38 | self.best = min(self.best, val_loss)
39 |
40 | def step_update(self, num_updates):
41 | """Update the learning rate after each update."""
42 | return self.optimizer.get_lr()
43 |
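Concrete schedulers subclass this base, register themselves, and override step/step_update. A sketch of a hypothetical constant schedule (set_lr/get_lr are the FairseqOptimizer accessors used by the built-in schedulers):

    from . import FairseqLRScheduler, register_lr_scheduler


    @register_lr_scheduler('constant')     # hypothetical; selectable via --lr-scheduler constant
    class ConstantSchedule(FairseqLRScheduler):
        def __init__(self, args, optimizer):
            super().__init__(args, optimizer)
            self.optimizer.set_lr(args.lr[0])   # fix the LR once, up front

        def step_update(self, num_updates):
            return self.optimizer.get_lr()      # nothing changes between updates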
--------------------------------------------------------------------------------
/fairseq/optim/sgd.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch.optim
7 |
8 | from . import FairseqOptimizer, register_optimizer
9 |
10 |
11 | @register_optimizer('sgd')
12 | class SGD(FairseqOptimizer):
13 | def __init__(self, args, params):
14 | super().__init__(args)
15 | self._optimizer = torch.optim.SGD(params, **self.optimizer_config)
16 |
17 | @staticmethod
18 | def add_args(parser):
19 | """Add optimizer-specific arguments to the parser."""
20 | # fmt: off
21 | parser.add_argument('--momentum', default=0.0, type=float, metavar='M',
22 | help='momentum factor')
23 | parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
24 | help='weight decay')
25 | # fmt: on
26 |
27 | @property
28 | def optimizer_config(self):
29 | """
30 | Return a kwarg dictionary that will be used to override optimizer
31 | args stored in checkpoints. This allows us to load a checkpoint and
32 | resume training using a different set of optimizer args, e.g., with a
33 | different learning rate.
34 | """
35 | return {
36 | 'lr': self.args.lr[0],
37 | 'momentum': self.args.momentum,
38 | 'weight_decay': self.args.weight_decay,
39 | }
40 |
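End to end, the registry turns these flags into a wrapped torch.optim.SGD. A self-contained sketch with a hand-built Namespace standing in for the parsed command line:

    import argparse

    import torch
    from fairseq.optim import build_optimizer

    args = argparse.Namespace(optimizer='sgd', lr=[0.1], momentum=0.9, weight_decay=0.0)
    model = torch.nn.Linear(4, 4)
    optimizer = build_optimizer(args, model.parameters())   # resolves 'sgd' in the registry
    assert optimizer.get_lr() == 0.1                        # taken from optimizer_config above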
--------------------------------------------------------------------------------
/fairseq/pdb.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import multiprocessing
7 | import os
8 | import pdb
9 | import sys
10 |
11 |
12 | __all__ = ['set_trace']
13 |
14 |
15 | _stdin = [None]
16 | _stdin_lock = multiprocessing.Lock()
17 | try:
18 | _stdin_fd = sys.stdin.fileno()
19 | except Exception:
20 | _stdin_fd = None
21 |
22 |
23 | class MultiprocessingPdb(pdb.Pdb):
24 | """A Pdb wrapper that works in a multiprocessing environment.
25 |
26 | Usage: `from fairseq import pdb; pdb.set_trace()`
27 | """
28 |
29 | def __init__(self):
30 | pdb.Pdb.__init__(self, nosigint=True)
31 |
32 | def _cmdloop(self):
33 | stdin_bak = sys.stdin
34 | with _stdin_lock:
35 | try:
36 | if _stdin_fd is not None:
37 | if not _stdin[0]:
38 | _stdin[0] = os.fdopen(_stdin_fd)
39 | sys.stdin = _stdin[0]
40 | self.cmdloop()
41 | finally:
42 | sys.stdin = stdin_bak
43 |
44 |
45 | def set_trace():
46 | pdb = MultiprocessingPdb()
47 | pdb.set_trace(sys._getframe().f_back)
48 |
--------------------------------------------------------------------------------
/fairseq/tasks/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/tasks/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/tasks/__pycache__/audio_pretraining.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/tasks/__pycache__/audio_pretraining.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/tasks/__pycache__/back_distil.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/tasks/__pycache__/back_distil.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/tasks/__pycache__/cross_lingual_lm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/tasks/__pycache__/cross_lingual_lm.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/tasks/__pycache__/denoising.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/tasks/__pycache__/denoising.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/tasks/__pycache__/fairseq_task.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/tasks/__pycache__/fairseq_task.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/tasks/__pycache__/language_modeling.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/tasks/__pycache__/language_modeling.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/tasks/__pycache__/legacy_masked_lm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/tasks/__pycache__/legacy_masked_lm.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/tasks/__pycache__/masked_lm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/tasks/__pycache__/masked_lm.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/tasks/__pycache__/multilingual_masked_lm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/tasks/__pycache__/multilingual_masked_lm.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/tasks/__pycache__/multilingual_translation.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/tasks/__pycache__/multilingual_translation.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/tasks/__pycache__/semisupervised_translation.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/tasks/__pycache__/semisupervised_translation.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/tasks/__pycache__/sentence_prediction.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/tasks/__pycache__/sentence_prediction.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/tasks/__pycache__/sentence_ranking.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/tasks/__pycache__/sentence_ranking.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/tasks/__pycache__/translation.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/tasks/__pycache__/translation.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/tasks/__pycache__/translation_from_pretrained_xlm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/tasks/__pycache__/translation_from_pretrained_xlm.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/tasks/__pycache__/translation_lev.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/tasks/__pycache__/translation_lev.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/tasks/__pycache__/translation_moe.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq/tasks/__pycache__/translation_moe.cpython-37.pyc
--------------------------------------------------------------------------------
/fairseq/tasks/audio_pretraining.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import os
7 |
8 | from fairseq.data import FileAudioDataset
9 | from . import FairseqTask, register_task
10 |
11 |
12 | @register_task('audio_pretraining')
13 | class AudioPretrainingTask(FairseqTask):
14 |     """
15 |     Task for pre-training speech models on raw, unlabelled audio (e.g., wav2vec).
16 |     """
17 |
18 | @staticmethod
19 | def add_args(parser):
20 | """Add task-specific arguments to the parser."""
21 | parser.add_argument('data', help='path to data directory')
22 | parser.add_argument('--sample-rate', default=16000, type=int,
23 | help='target sample rate. audio files will be up/down sampled to this rate')
24 | parser.add_argument('--max-sample-size', default=None, type=int,
25 | help='max sample size to crop to for batching. default = min sample length')
26 | parser.add_argument('--min-sample-size', default=None, type=int,
27 | help='min sample size to crop to for batching. default = same as --max-sample-size')
28 |
29 | def __init__(self, args):
30 | super().__init__(args)
31 |
32 | @classmethod
33 | def setup_task(cls, args, **kwargs):
34 | """Setup the task (e.g., load dictionaries).
35 |
36 | Args:
37 | args (argparse.Namespace): parsed command-line arguments
38 | """
39 | return cls(args)
40 |
41 | def load_dataset(self, split, **kwargs):
42 | """Load a given dataset split.
43 |
44 | Args:
45 | split (str): name of the split (e.g., train, valid, test)
46 | """
47 |
48 | manifest = os.path.join(self.args.data, '{}.tsv'.format(split))
49 | self.datasets[split] = FileAudioDataset(manifest,
50 | sample_rate=self.args.sample_rate,
51 | max_sample_size=self.args.max_sample_size,
52 | min_sample_size=self.args.min_sample_size)
53 |
54 | @property
55 | def target_dictionary(self):
56 | """Return the :class:`~fairseq.data.Dictionary` for the language
57 | model."""
58 | return None
59 |
--------------------------------------------------------------------------------
/fairseq/tasks/translation_from_pretrained_xlm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary
7 | from fairseq.tasks.translation import TranslationTask
8 |
9 | from . import register_task
10 |
11 |
12 | @register_task("translation_from_pretrained_xlm")
13 | class TranslationFromPretrainedXLMTask(TranslationTask):
14 | """
15 |     Same as TranslationTask, except it uses the MaskedLMDictionary class so
16 |     that we can load data that was binarized with that same dictionary class.
17 |
18 | This task should be used for the entire training pipeline when we want to
19 | train an NMT model from a pretrained XLM checkpoint: binarizing NMT data,
20 | training NMT with the pretrained XLM checkpoint, and subsequent evaluation
21 | of that trained model.
22 | """
23 |
24 | @classmethod
25 | def load_dictionary(cls, filename):
26 | """Load the masked LM dictionary from the filename
27 |
28 | Args:
29 | filename (str): the filename
30 | """
31 | return MaskedLMDictionary.load(filename)
32 |
--------------------------------------------------------------------------------
/fairseq/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import re
7 |
8 | SPACE_NORMALIZER = re.compile(r"\s+")
9 |
10 |
11 | def tokenize_line(line):
12 | line = SPACE_NORMALIZER.sub(" ", line)
13 | line = line.strip()
14 | return line.split()
15 |
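The regex collapses any whitespace run (spaces, tabs, newlines) to a single space before splitting, so tokenization is stable under irregular spacing:

    >>> from fairseq.tokenizer import tokenize_line
    >>> tokenize_line('  Hello \t world \n')
    ['Hello', 'world']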
--------------------------------------------------------------------------------
/fairseq_cli/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/fairseq_cli/__init__.py
--------------------------------------------------------------------------------
/hubconf.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import functools
7 |
8 | from fairseq.hub_utils import BPEHubInterface as bpe # noqa
9 | from fairseq.hub_utils import TokenizerHubInterface as tokenizer # noqa
10 | from fairseq.models import MODEL_REGISTRY
11 |
12 |
13 | dependencies = [
14 | 'numpy',
15 | 'regex',
16 | 'requests',
17 | 'torch',
18 | ]
19 |
20 |
21 | # torch.hub doesn't build Cython components, so if they are not found then try
22 | # to build them here
23 | try:
24 | import fairseq.data.token_block_utils_fast
25 | except (ImportError, ModuleNotFoundError):
26 | try:
27 | import cython
28 | import os
29 | from setuptools import sandbox
30 | sandbox.run_setup(
31 | os.path.join(os.path.dirname(__file__), 'setup.py'),
32 | ['build_ext', '--inplace'],
33 | )
34 | except (ImportError, ModuleNotFoundError):
35 | print(
36 | 'Unable to build Cython components. Please make sure Cython is '
37 | 'installed if the torch.hub model you are loading depends on it.'
38 | )
39 |
40 |
41 | for _model_type, _cls in MODEL_REGISTRY.items():
42 | for model_name in _cls.hub_models().keys():
43 | globals()[model_name] = functools.partial(
44 | _cls.from_pretrained,
45 | model_name,
46 | )
47 | # to simplify the interface we only expose named models
48 | # globals()[_model_type] = _cls.from_pretrained
49 |
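With this hubconf in place, every name returned by a model's hub_models() becomes loadable through torch.hub. A sketch against the upstream pytorch/fairseq entry point (assumes network access; 'roberta.base' is one of the RoBERTa hub names):

    import torch

    # downloads the checkpoint on first use and returns a ready hub interface
    roberta = torch.hub.load('pytorch/fairseq', 'roberta.base')
    roberta.eval()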
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/compare_namespaces.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Helper script to compare two argparse.Namespace objects."""
3 |
4 | from argparse import Namespace # noqa
5 |
6 |
7 | def main():
8 |
9 | ns1 = eval(input('Namespace 1: '))
10 | ns2 = eval(input('Namespace 2: '))
11 |
12 | def keys(ns):
13 | ks = set()
14 | for k in dir(ns):
15 | if not k.startswith('_'):
16 | ks.add(k)
17 | return ks
18 |
19 | k1 = keys(ns1)
20 | k2 = keys(ns2)
21 |
22 | def print_keys(ks, ns1, ns2=None):
23 | for k in ks:
24 | if ns2 is None:
25 | print('{}\t{}'.format(k, getattr(ns1, k, None)))
26 | else:
27 | print('{}\t{}\t{}'.format(k, getattr(ns1, k, None), getattr(ns2, k, None)))
28 |
29 | print('Keys unique to namespace 1:')
30 | print_keys(k1 - k2, ns1)
31 | print()
32 |
33 | print('Keys unique to namespace 2:')
34 | print_keys(k2 - k1, ns2)
35 | print()
36 |
37 | print('Overlapping keys with different values:')
38 | ks = [k for k in k1 & k2 if getattr(ns1, k, 'None') != getattr(ns2, k, 'None')]
39 | print_keys(ks, ns1, ns2)
40 | print()
41 |
42 |
43 | if __name__ == '__main__':
44 | main()
45 |
--------------------------------------------------------------------------------
/scripts/compound_split_bleu.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ $# -ne 1 ]; then
4 | echo "usage: $0 GENERATE_PY_OUTPUT"
5 | exit 1
6 | fi
7 |
8 | GEN=$1
9 |
10 | SYS=$GEN.sys
11 | REF=$GEN.ref
12 |
13 | if [ $(tail -n 1 $GEN | grep BLEU | wc -l) -ne 1 ]; then
14 | echo "not done generating"
15 | exit
16 | fi
17 |
18 | grep ^H $GEN | cut -f3- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $SYS
19 | grep ^T $GEN | cut -f2- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $REF
20 | fairseq-score --sys $SYS --ref $REF
21 |
--------------------------------------------------------------------------------
/scripts/convert_dictionary.lua:
--------------------------------------------------------------------------------
1 | -- Copyright (c) Facebook, Inc. and its affiliates.
2 | --
3 | -- This source code is licensed under the MIT license found in the
4 | -- LICENSE file in the root directory of this source tree.
5 | --
6 | -- Usage: convert_dictionary.lua <dict.th7>
7 | require 'fairseq'
8 | require 'torch'
9 | require 'paths'
10 |
11 | if #arg < 1 then
12 |   print('usage: convert_dictionary.lua <dict.th7>')
13 | os.exit(1)
14 | end
15 | if not paths.filep(arg[1]) then
16 |   print('error: file does not exist: ' .. arg[1])
17 | os.exit(1)
18 | end
19 |
20 | dict = torch.load(arg[1])
21 | dst = paths.basename(arg[1]):gsub('.th7', '.txt')
22 | assert(dst:match('.txt$'))
23 |
24 | f = io.open(dst, 'w')
25 | for idx, symbol in ipairs(dict.index_to_symbol) do
26 | if idx > dict.cutoff then
27 | break
28 | end
29 | f:write(symbol)
30 | f:write(' ')
31 | f:write(dict.index_to_freq[idx])
32 | f:write('\n')
33 | end
34 | f:close()
35 |
--------------------------------------------------------------------------------
/scripts/count_docs.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """
7 | Count the number of documents and average number of lines and tokens per
8 | document in a large file. Documents should be separated by a single empty line.
9 | """
10 |
11 | import argparse
12 | import gzip
13 | import sys
14 |
15 | import numpy as np
16 |
17 |
18 | def main():
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument('input')
21 | parser.add_argument('--gzip', action='store_true')
22 | args = parser.parse_args()
23 |
24 | def gopen():
25 | if args.gzip:
26 | return gzip.open(args.input, 'r')
27 | else:
28 | return open(args.input, 'r', encoding='utf-8')
29 |
30 | num_lines = []
31 | num_toks = []
32 | with gopen() as h:
33 | num_docs = 1
34 | num_lines_in_doc = 0
35 | num_toks_in_doc = 0
36 | for i, line in enumerate(h):
37 | if len(line.strip()) == 0: # empty line indicates new document
38 | num_docs += 1
39 | num_lines.append(num_lines_in_doc)
40 | num_toks.append(num_toks_in_doc)
41 | num_lines_in_doc = 0
42 | num_toks_in_doc = 0
43 | else:
44 | num_lines_in_doc += 1
45 | num_toks_in_doc += len(line.rstrip().split())
46 | if i % 1000000 == 0:
47 | print(i, file=sys.stderr, end="", flush=True)
48 | elif i % 100000 == 0:
49 | print(".", file=sys.stderr, end="", flush=True)
50 |         print(file=sys.stderr, flush=True)
51 |         # flush counts for the final document, which is typically not
52 |         # followed by a trailing empty line
53 |         if num_lines_in_doc > 0:
54 |             num_lines.append(num_lines_in_doc)
55 |             num_toks.append(num_toks_in_doc)
56 |
57 |     print("found {} docs".format(num_docs))
58 |     print("average num lines per doc: {}".format(np.mean(num_lines)))
59 |     print("average num toks per doc: {}".format(np.mean(num_toks)))
60 |
61 |
62 | if __name__ == '__main__':
63 |     main()
64 |
--------------------------------------------------------------------------------
/scripts/read_binarized.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 |
9 | from fairseq.data import data_utils, Dictionary, indexed_dataset
10 |
11 |
12 | def get_parser():
13 | parser = argparse.ArgumentParser(
14 | description='writes text from binarized file to stdout')
15 | # fmt: off
16 | parser.add_argument('--dataset-impl', help='dataset implementation',
17 | choices=indexed_dataset.get_available_dataset_impl())
18 | parser.add_argument('--dict', metavar='FP', help='dictionary containing known words', default=None)
19 | parser.add_argument('--input', metavar='FP', required=True, help='binarized file to read')
20 | # fmt: on
21 |
22 | return parser
23 |
24 |
25 | def main():
26 | parser = get_parser()
27 | args = parser.parse_args()
28 |
29 | dictionary = Dictionary.load(args.dict) if args.dict is not None else None
30 | dataset = data_utils.load_indexed_dataset(
31 | args.input,
32 | dictionary,
33 | dataset_impl=args.dataset_impl,
34 | default='lazy',
35 | )
36 |
37 | for tensor_line in dataset:
38 | if dictionary is None:
39 | line = ' '.join([str(int(x)) for x in tensor_line])
40 | else:
41 | line = dictionary.string(tensor_line)
42 |
43 | print(line)
44 |
45 |
46 | if __name__ == '__main__':
47 | main()
48 |
--------------------------------------------------------------------------------
/scripts/sacrebleu_pregen.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ $# -ne 4 ]; then
4 | echo "usage: $0 TESTSET SRCLANG TGTLANG GEN"
5 | exit 1
6 | fi
7 |
8 | TESTSET=$1
9 | SRCLANG=$2
10 | TGTLANG=$3
11 |
12 | GEN=$4
13 |
14 | echo 'Cloning Moses github repository (for tokenization scripts)...'
15 | git clone https://github.com/moses-smt/mosesdecoder.git
16 |
17 | SCRIPTS=mosesdecoder/scripts
18 | DETOKENIZER=$SCRIPTS/tokenizer/detokenizer.perl
19 |
20 | grep ^H $GEN \
21 | | sed 's/^H\-//' \
22 | | sort -n -k 1 \
23 | | cut -f 3 \
24 | | perl $DETOKENIZER -l $TGTLANG \
25 | | sed "s/ - /-/g" \
26 | > $GEN.sorted.detok
27 |
28 | sacrebleu --test-set $TESTSET --language-pair "${SRCLANG}-${TGTLANG}" < $GEN.sorted.detok
29 |
--------------------------------------------------------------------------------
/scripts/shard_docs.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """
7 | Split a large file into shards while respecting document boundaries. Documents
8 | should be separated by a single empty line.
9 | """
10 |
11 | import argparse
12 | import contextlib
13 |
14 |
15 | def main():
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument('input')
18 | parser.add_argument('--num-shards', type=int)
19 | args = parser.parse_args()
20 |
21 | assert args.num_shards is not None and args.num_shards > 1
22 |
23 | with open(args.input, 'r', encoding='utf-8') as h:
24 | with contextlib.ExitStack() as stack:
25 | outputs = [
26 | stack.enter_context(open(args.input + ".shard" + str(i), "w", encoding="utf-8"))
27 | for i in range(args.num_shards)
28 | ]
29 |
30 | doc = []
31 | first_doc = [True]*args.num_shards
32 | def output_doc(i):
33 | if not first_doc[i]:
34 | outputs[i].write("\n")
35 | first_doc[i] = False
36 | for line in doc:
37 | outputs[i].write(line)
38 | doc.clear()
39 |
40 | num_docs = 0
41 | for line in h:
42 | if line.strip() == "": # empty line indicates new document
43 | output_doc(num_docs % args.num_shards)
44 | num_docs += 1
45 | else:
46 | doc.append(line)
47 | output_doc(num_docs % args.num_shards)
48 |
49 |
50 | if __name__ == '__main__':
51 | main()
52 |
--------------------------------------------------------------------------------
/scripts/spm_decode.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | from __future__ import absolute_import, division, print_function, unicode_literals
9 |
10 | import argparse
11 |
12 | import sentencepiece as spm
13 |
14 |
15 | def main():
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument("--model", required=True,
18 | help="sentencepiece model to use for decoding")
19 | parser.add_argument("--input", required=True, help="input file to decode")
20 | parser.add_argument("--input_format", choices=["piece", "id"], default="piece")
21 | args = parser.parse_args()
22 |
23 | sp = spm.SentencePieceProcessor()
24 | sp.Load(args.model)
25 |
26 | if args.input_format == "piece":
27 | def decode(l):
28 | return "".join(sp.DecodePieces(l))
29 | elif args.input_format == "id":
30 | def decode(l):
31 | return "".join(sp.DecodeIds(l))
32 | else:
33 | raise NotImplementedError
34 |
35 | def tok2int(tok):
36 |         # remap reference-side <unk> (represented as <<unk>>) to 0
37 |         return int(tok) if tok != "<<unk>>" else 0
38 |
39 | with open(args.input, "r", encoding="utf-8") as h:
40 | for line in h:
41 | print(decode(list(map(tok2int, line.rstrip().split()))))
42 |
43 |
44 | if __name__ == "__main__":
45 | main()
46 |
--------------------------------------------------------------------------------
/scripts/spm_train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | from __future__ import absolute_import, division, print_function, unicode_literals
9 |
10 | import sys
11 |
12 | import sentencepiece as spm
13 |
14 |
15 | if __name__ == "__main__":
16 | spm.SentencePieceTrainer.Train(" ".join(sys.argv[1:]))
17 |
--------------------------------------------------------------------------------
/scripts/wav2vec_manifest.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """
7 | Data pre-processing: build vocabularies and binarize training data.
8 | """
9 |
10 | import argparse
11 | import glob
12 | import os
13 | import soundfile
14 | import random
15 |
16 |
17 | def get_parser():
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('root', metavar='DIR', help='root directory containing flac files to index')
20 | parser.add_argument('--valid-percent', default=0.01, type=float, metavar='D',
21 | help='percentage of data to use as validation set (between 0 and 1)')
22 | parser.add_argument('--dest', default='.', type=str, metavar='DIR', help='output directory')
23 | parser.add_argument('--ext', default='flac', type=str, metavar='EXT', help='extension to look for')
24 | parser.add_argument('--seed', default=42, type=int, metavar='N', help='random seed')
25 | parser.add_argument('--path-must-contain', default=None, type=str, metavar='FRAG',
26 | help='if set, path must contain this substring for a file to be included in the manifest')
27 | return parser
28 |
29 |
30 | def main(args):
31 | assert args.valid_percent >= 0 and args.valid_percent <= 1.
32 |
33 | dir_path = os.path.realpath(args.root)
34 | search_path = os.path.join(dir_path, '**/*.' + args.ext)
35 | rand = random.Random(args.seed)
36 |
37 | with open(os.path.join(args.dest, 'train.tsv'), 'w') as train_f, open(
38 | os.path.join(args.dest, 'valid.tsv'), 'w') as valid_f:
39 | print(dir_path, file=train_f)
40 | print(dir_path, file=valid_f)
41 |
42 | for fname in glob.iglob(search_path, recursive=True):
43 | file_path = os.path.realpath(fname)
44 |
45 | if args.path_must_contain and args.path_must_contain not in file_path:
46 | continue
47 |
48 | frames = soundfile.info(fname).frames
49 | dest = train_f if rand.random() > args.valid_percent else valid_f
50 | print('{}\t{}'.format(os.path.relpath(file_path, dir_path), frames), file=dest)
51 |
52 |
53 | if __name__ == '__main__':
54 | parser = get_parser()
55 | args = parser.parse_args()
56 | main(args)
57 |
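The emitted train.tsv/valid.tsv begin with the resolved root directory, followed by one tab-separated relative-path/frame-count row per audio file, which is the manifest format the audio_pretraining task's FileAudioDataset consumes. Illustrative contents (paths and frame counts are made up):

    /data/LibriSpeech/train-clean-100
    19/198/19-198-0001.flac	84480
    19/198/19-198-0002.flac	112640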
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/tests/__init__.py
--------------------------------------------------------------------------------
/tests/speech_recognition/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Knowledge-Inheritance/0d16ff135834ff2cace0b9769b0d3501c2dd5cbe/tests/speech_recognition/__init__.py
--------------------------------------------------------------------------------
/tests/speech_recognition/test_collaters.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import unittest
8 |
9 | import numpy as np
10 | import torch
11 | from examples.speech_recognition.data.collaters import Seq2SeqCollater
12 |
13 |
14 | class TestSeq2SeqCollator(unittest.TestCase):
15 | def test_collate(self):
16 |
17 | eos_idx = 1
18 | pad_idx = 0
19 | collater = Seq2SeqCollater(
20 | feature_index=0, label_index=1, pad_index=pad_idx, eos_index=eos_idx
21 | )
22 |
23 | # 2 frames in the first sample and 3 frames in the second one
24 | frames1 = np.array([[7, 8], [9, 10]])
25 | frames2 = np.array([[1, 2], [3, 4], [5, 6]])
26 | target1 = np.array([4, 2, 3, eos_idx])
27 | target2 = np.array([3, 2, eos_idx])
28 | sample1 = {"id": 0, "data": [frames1, target1]}
29 | sample2 = {"id": 1, "data": [frames2, target2]}
30 | batch = collater.collate([sample1, sample2])
31 |
32 | # collate() sorts inputs by descending frame length before creating the batch
33 | self.assertTensorEqual(batch["id"], torch.tensor([1, 0]))
34 | self.assertEqual(batch["ntokens"], 7)
35 | self.assertTensorEqual(
36 | batch["net_input"]["src_tokens"],
37 | torch.tensor(
38 | [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [pad_idx, pad_idx]]]
39 | ),
40 | )
41 | self.assertTensorEqual(
42 | batch["net_input"]["prev_output_tokens"],
43 | torch.tensor([[eos_idx, 3, 2, pad_idx], [eos_idx, 4, 2, 3]]),
44 | )
45 | self.assertTensorEqual(batch["net_input"]["src_lengths"], torch.tensor([3, 2]))
46 | self.assertTensorEqual(
47 | batch["target"],
48 | torch.tensor([[3, 2, eos_idx, pad_idx], [4, 2, 3, eos_idx]]),
49 | )
50 | self.assertEqual(batch["nsentences"], 2)
51 |
52 | def assertTensorEqual(self, t1, t2):
53 | self.assertEqual(t1.size(), t2.size(), "size mismatch")
54 | self.assertEqual(t1.ne(t2).long().sum(), 0)
55 |
56 |
57 | if __name__ == "__main__":
58 | unittest.main()
59 |
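
The expected prev_output_tokens above follow the usual teacher-forcing convention: the target shifted right by one position with eos moved to the front (shorter targets are then padded to the batch maximum). A minimal sketch of that shift, not the collater's actual code:

    import torch

    def shift_right(target, eos_idx):
        # e.g. [4, 2, 3, eos] -> [eos, 4, 2, 3]
        return torch.cat([target.new_tensor([eos_idx]), target[:-1]])

The id order [1, 0] and src_lengths [3, 2] likewise reflect the sort by descending frame count.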
--------------------------------------------------------------------------------
/tests/speech_recognition/test_cross_entropy.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from examples.speech_recognition.criterions.cross_entropy_acc import CrossEntropyWithAccCriterion
8 | from .asr_test_base import CrossEntropyCriterionTestBase
9 |
10 |
11 | class CrossEntropyWithAccCriterionTest(CrossEntropyCriterionTestBase):
12 | def setUp(self):
13 | self.criterion_cls = CrossEntropyWithAccCriterion
14 | super().setUp()
15 |
16 | def test_cross_entropy_all_correct(self):
17 | sample = self.get_test_sample(correct=True, soft_target=False, aggregate=False)
18 | loss, sample_size, logging_output = self.criterion(
19 | self.model, sample, "sum", log_probs=True
20 | )
21 | assert logging_output["correct"] == 20
22 | assert logging_output["total"] == 20
23 | assert logging_output["sample_size"] == 20
24 | assert logging_output["ntokens"] == 20
25 |
26 | def test_cross_entropy_all_wrong(self):
27 | sample = self.get_test_sample(correct=False, soft_target=False, aggregate=False)
28 | loss, sample_size, logging_output = self.criterion(
29 | self.model, sample, "sum", log_probs=True
30 | )
31 | assert logging_output["correct"] == 0
32 | assert logging_output["total"] == 20
33 | assert logging_output["sample_size"] == 20
34 | assert logging_output["ntokens"] == 20
35 |
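
Both tests lean on the standard fairseq criterion contract: calling the criterion returns (loss, sample_size, logging_output), and here logging_output carries an argmax-accuracy count. A self-contained sketch of that statistic (shapes assumed for illustration, not taken from the criterion's code):

    import torch

    lprobs = torch.log_softmax(torch.randn(20, 5), dim=-1)    # (n tokens, vocab)
    target = torch.randint(0, 5, (20,))
    correct = (lprobs.argmax(dim=-1) == target).sum().item()  # cf. logging_output["correct"]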
--------------------------------------------------------------------------------
/tests/test_character_token_embedder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | import unittest
8 |
9 | from fairseq.data import Dictionary
10 | from fairseq.modules import CharacterTokenEmbedder
11 |
12 |
13 | class TestCharacterTokenEmbedder(unittest.TestCase):
14 | def test_character_token_embedder(self):
15 | vocab = Dictionary()
16 | vocab.add_symbol('hello')
17 | vocab.add_symbol('there')
18 |
19 | embedder = CharacterTokenEmbedder(vocab, [(2, 16), (4, 32), (8, 64), (16, 2)], 64, 5, 2)
20 |
21 | test_sents = [['hello', 'unk', 'there'], ['there'], ['hello', 'there']]
22 | max_len = max(len(s) for s in test_sents)
23 | input = torch.LongTensor(len(test_sents), max_len + 2).fill_(vocab.pad())
24 | for i in range(len(test_sents)):
25 | input[i][0] = vocab.eos()
26 | for j in range(len(test_sents[i])):
27 | input[i][j + 1] = vocab.index(test_sents[i][j])
28 | input[i][j + 2] = vocab.eos()
29 | embs = embedder(input)
30 |
31 | assert embs.size() == (len(test_sents), max_len + 2, 5)
32 | self.assertAlmostEqual(embs[0][0], embs[1][0])
33 | self.assertAlmostEqual(embs[0][0], embs[0][-1])
34 | self.assertAlmostEqual(embs[0][1], embs[2][1])
35 | self.assertAlmostEqual(embs[0][3], embs[1][1])
36 |
37 | embs.sum().backward()
38 | assert embedder.char_embeddings.weight.grad is not None
39 |
40 | def assertAlmostEqual(self, t1, t2):
41 | self.assertEqual(t1.size(), t2.size(), "size mismatch")
42 | self.assertLess((t1 - t2).abs().max(), 1e-6)
43 |
44 |
45 | if __name__ == '__main__':
46 | unittest.main()
47 |
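
For orientation, the opaque positional call above reads as follows if the module's signature is (vocab, filters, char_embed_dim, word_embed_dim, highway_layers) -- worth verifying against fairseq.modules.character_token_embedder; the annotated restatement below is a sketch under that assumption:

    from fairseq.data import Dictionary
    from fairseq.modules import CharacterTokenEmbedder

    vocab = Dictionary()
    embedder = CharacterTokenEmbedder(
        vocab,                                 # fairseq Dictionary
        [(2, 16), (4, 32), (8, 64), (16, 2)],  # (kernel width, num filters) per char-CNN
        64,                                    # character embedding dimension
        5,                                     # output word embedding dim (the size assert checks this)
        2,                                     # number of highway layers
    )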
--------------------------------------------------------------------------------
/tests/test_concat_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import unittest
7 |
8 | import torch
9 | from fairseq.data import LanguagePairDataset, TokenBlockDataset
10 | from fairseq.data.concat_dataset import ConcatDataset
11 | from tests.test_train import mock_dict
12 |
13 |
14 | class TestConcatDataset(unittest.TestCase):
15 | def setUp(self):
16 | d = mock_dict()
17 | tokens_1 = torch.LongTensor([1]).view(1, -1)
18 | tokens_ds1 = TokenBlockDataset(
19 | tokens_1,
20 | sizes=[tokens_1.size(-1)],
21 | block_size=1,
22 | pad=0,
23 | eos=1,
24 | include_targets=False,
25 | )
26 | self.dataset_1 = LanguagePairDataset(
27 | tokens_ds1, tokens_ds1.sizes, d, shuffle=False
28 | )
29 | tokens_2 = torch.LongTensor([2]).view(1, -1)
30 | tokens_ds2 = TokenBlockDataset(
31 | tokens_2,
32 | sizes=[tokens_2.size(-1)],
33 | block_size=1,
34 | pad=0,
35 | eos=1,
36 | include_targets=False,
37 | )
38 | self.dataset_2 = LanguagePairDataset(
39 | tokens_ds2, tokens_ds2.sizes, d, shuffle=False
40 | )
41 |
42 | def test_concat_dataset_basics(self):
43 | d = ConcatDataset(
44 | [self.dataset_1, self.dataset_2]
45 | )
46 | assert len(d) == 2
47 | assert d[0]['source'][0] == 1
48 | assert d[1]['source'][0] == 2
49 |
50 | d = ConcatDataset(
51 | [self.dataset_1, self.dataset_2], sample_ratios=[1, 2]
52 | )
53 | assert len(d) == 3
54 | assert d[0]['source'][0] == 1
55 | assert d[1]['source'][0] == 2
56 | assert d[2]['source'][0] == 2
57 |
58 | d = ConcatDataset(
59 | [self.dataset_1, self.dataset_2], sample_ratios=[2, 1]
60 | )
61 | assert len(d) == 3
62 | assert d[0]['source'][0] == 1
63 | assert d[1]['source'][0] == 1
64 | assert d[2]['source'][0] == 2
65 |
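
The sample_ratios cases encode integer upsampling: len(concat) = sum(ratio_i * len(dataset_i)), and the repeated copies alias the same underlying item, which is why ratios [1, 2] over two single-item datasets yield the index sequence [ds1, ds2, ds2]. A minimal sketch of the index arithmetic under that assumption (not ConcatDataset's actual code):

    def locate(index, lengths, ratios):
        # Map a concatenated index back to (dataset, inner index).
        for ds, (n, r) in enumerate(zip(lengths, ratios)):
            if index < n * r:
                return ds, index % n
            index -= n * r
        raise IndexError(index)

    assert locate(2, lengths=[1, 1], ratios=[1, 2]) == (1, 0)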
--------------------------------------------------------------------------------
/tests/test_convtbc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | import unittest
8 | from fairseq.modules import ConvTBC
9 | import torch.nn as nn
10 |
11 |
12 | class TestConvTBC(unittest.TestCase):
13 |
14 | def test_convtbc(self):
15 | # ConvTBC weight layout: (ksz, in_channels, out_channels)
16 | conv_tbc = ConvTBC(4, 5, kernel_size=3, padding=1)
17 | # nn.Conv1d weight layout: (out_channels, in_channels, ksz), hence transpose(0, 2) below
18 | conv1d = nn.Conv1d(4, 5, kernel_size=3, padding=1)
19 |
20 | conv_tbc.weight.data.copy_(conv1d.weight.data.transpose(0, 2))
21 | conv_tbc.bias.data.copy_(conv1d.bias.data)
22 |
23 | input_tbc = torch.randn(7, 2, 4, requires_grad=True)
24 | input1d = input_tbc.data.transpose(0, 1).transpose(1, 2)
25 | input1d.requires_grad = True
26 |
27 | output_tbc = conv_tbc(input_tbc)
28 | output1d = conv1d(input1d)
29 |
30 | self.assertAlmostEqual(output_tbc.data.transpose(0, 1).transpose(1, 2), output1d.data)
31 |
32 | grad_tbc = torch.randn(output_tbc.size())
33 | grad1d = grad_tbc.transpose(0, 1).transpose(1, 2).contiguous()
34 |
35 | output_tbc.backward(grad_tbc)
36 | output1d.backward(grad1d)
37 |
38 | self.assertAlmostEqual(conv_tbc.weight.grad.data.transpose(0, 2), conv1d.weight.grad.data)
39 | self.assertAlmostEqual(conv_tbc.bias.grad.data, conv1d.bias.grad.data)
40 | self.assertAlmostEqual(input_tbc.grad.data.transpose(0, 1).transpose(1, 2), input1d.grad.data)
41 |
42 | def assertAlmostEqual(self, t1, t2):
43 | self.assertEqual(t1.size(), t2.size(), "size mismatch")
44 | self.assertLess((t1 - t2).abs().max(), 1e-4)
45 |
46 |
47 | if __name__ == '__main__':
48 | unittest.main()
49 |
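
All the transpose pairs above convert between ConvTBC's (time, batch, channel) layout and nn.Conv1d's (batch, channel, time); the same conversion in one call:

    import torch

    x_tbc = torch.randn(7, 2, 4)    # (time, batch, channel) -- ConvTBC's layout
    x_bct = x_tbc.permute(1, 2, 0)  # (batch, channel, time) -- nn.Conv1d's layout
    assert x_bct.shape == (2, 4, 7)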
--------------------------------------------------------------------------------
/tests/test_dictionary.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import tempfile
7 | import unittest
8 |
9 | import torch
10 |
11 | from fairseq.data import Dictionary
12 |
13 |
14 | class TestDictionary(unittest.TestCase):
15 |
16 | def test_finalize(self):
17 | txt = [
18 | 'A B C D',
19 | 'B C D',
20 | 'C D',
21 | 'D',
22 | ]
23 | ref_ids1 = list(map(torch.IntTensor, [
24 | [4, 5, 6, 7, 2],
25 | [5, 6, 7, 2],
26 | [6, 7, 2],
27 | [7, 2],
28 | ]))
29 | ref_ids2 = list(map(torch.IntTensor, [
30 | [7, 6, 5, 4, 2],
31 | [6, 5, 4, 2],
32 | [5, 4, 2],
33 | [4, 2],
34 | ]))
35 |
36 | # build dictionary
37 | d = Dictionary()
38 | for line in txt:
39 | d.encode_line(line, add_if_not_exist=True)
40 |
41 | def get_ids(dictionary):
42 | ids = []
43 | for line in txt:
44 | ids.append(dictionary.encode_line(line, add_if_not_exist=False))
45 | return ids
46 |
47 | def assertMatch(ids, ref_ids):
48 | for toks, ref_toks in zip(ids, ref_ids):
49 | self.assertEqual(toks.size(), ref_toks.size())
50 | self.assertEqual(0, (toks != ref_toks).sum().item())
51 |
52 | ids = get_ids(d)
53 | assertMatch(ids, ref_ids1)
54 |
55 | # check finalized dictionary
56 | d.finalize()
57 | finalized_ids = get_ids(d)
58 | assertMatch(finalized_ids, ref_ids2)
59 |
60 | # write to disk and reload
61 | with tempfile.NamedTemporaryFile(mode='w') as tmp_dict:
62 | d.save(tmp_dict.name)
63 | d = Dictionary.load(tmp_dict.name)
64 | reload_ids = get_ids(d)
65 | assertMatch(reload_ids, ref_ids2)
66 | assertMatch(finalized_ids, reload_ids)
67 |
68 |
69 | if __name__ == '__main__':
70 | unittest.main()
71 |
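
The two reference lists encode Dictionary's conventions: indices 0-3 are reserved for the special symbols (eos() == 2, hence the trailing 2s), new symbols are appended from index 4 in insertion order, and finalize() re-sorts them by descending count, so D (present in all four lines) moves from id 7 to id 4. A quick check of those counts:

    from collections import Counter

    txt = ['A B C D', 'B C D', 'C D', 'D']
    counts = Counter(tok for line in txt for tok in line.split())
    # no ties here, so most_common() gives the post-finalize order D, C, B, A
    assert [t for t, _ in counts.most_common()] == ['D', 'C', 'B', 'A']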
--------------------------------------------------------------------------------
/tests/test_iterators.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import unittest
7 |
8 | from fairseq.data import iterators
9 |
10 |
11 | class TestIterators(unittest.TestCase):
12 |
13 | def test_counting_iterator(self):
14 | x = list(range(10))
15 | itr = iterators.CountingIterator(x)
16 | self.assertTrue(itr.has_next())
17 | self.assertEqual(next(itr), 0)
18 | self.assertEqual(next(itr), 1)
19 | itr.skip(3)
20 | self.assertEqual(next(itr), 5)
21 | itr.skip(3)
22 | self.assertEqual(next(itr), 9)
23 | self.assertFalse(itr.has_next())
24 |
25 |
26 | if __name__ == '__main__':
27 | unittest.main()
28 |
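
The skips above show that CountingIterator.skip(n) discards the next n items from the underlying iterable. The classic itertools recipe for the same operation, as a point of comparison:

    import itertools

    def skip(it, n):
        # Advance the iterator by n items, discarding them.
        next(itertools.islice(it, n, n), None)
        return it

    it = iter(range(10))
    next(it); next(it)      # consume 0, 1
    skip(it, 3)             # drop 2, 3, 4
    assert next(it) == 5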
--------------------------------------------------------------------------------
/tests/test_memory_efficient_fp16.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import argparse
7 | import unittest
8 |
9 | import torch
10 |
11 | from fairseq.optim.adam import FairseqAdam
12 | from fairseq.optim.fp16_optimizer import MemoryEfficientFP16Optimizer
13 |
14 |
15 | class TestMemoryEfficientFP16(unittest.TestCase):
16 |
17 | def test_load_state_dict(self):
18 | # define simple FP16 model
19 | model = torch.nn.Linear(5, 5).cuda().half()
20 | params = list(model.parameters())
21 |
22 | # initialize memory efficient FP16 optimizer
23 | optimizer = FairseqAdam(
24 | argparse.Namespace(
25 | lr=[0.00001],
26 | adam_betas='(0.9, 0.999)',
27 | adam_eps=1e-8,
28 | weight_decay=0.0,
29 | ),
30 | params,
31 | )
32 | me_optimizer = MemoryEfficientFP16Optimizer(
33 | argparse.Namespace(
34 | fp16_init_scale=1,
35 | fp16_scale_window=1,
36 | fp16_scale_tolerance=1,
37 | threshold_loss_scale=1,
38 | min_loss_scale=1e-4,
39 | ),
40 | params,
41 | optimizer,
42 | )
43 |
44 | # optimizer state is created in the first step
45 | loss = model(torch.rand(5).cuda().half()).sum()
46 | me_optimizer.backward(loss)
47 | me_optimizer.step()
48 |
49 | # reload state
50 | state = me_optimizer.state_dict()
51 | me_optimizer.load_state_dict(state)
52 | for k, v in me_optimizer.optimizer.state.items():
53 | self.assertTrue(k.dtype == torch.float16)
54 | for v_i in v.values():
55 | if torch.is_tensor(v_i):
56 | self.assertTrue(v_i.dtype == torch.float32)
57 |
58 |
59 | if __name__ == '__main__':
60 | unittest.main()
61 |
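
The final loop is the point of the memory-efficient variant: parameters (the state keys) stay FP16 while the Adam statistics held by the wrapped optimizer are FP32, so no separate FP32 master copy of the weights is kept. The Namespace fields above are loss-scaling knobs; a minimal sketch of dynamic loss scaling in general (not fairseq's implementation):

    import torch

    def scaled_step(loss, params, optimizer, scale):
        # One step of dynamic loss scaling; returns the (possibly lowered) scale.
        (loss * scale).backward()
        grads = [p.grad for p in params if p.grad is not None]
        if any(not torch.isfinite(g).all() for g in grads):
            for p in params:
                p.grad = None       # overflow: drop gradients and skip the update
            return scale / 2
        for g in grads:
            g.div_(scale)           # unscale before the optimizer consumes them
        optimizer.step()
        return scale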
--------------------------------------------------------------------------------
/tests/test_multihead_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | import unittest
8 | from fairseq.modules.multihead_attention import MultiheadAttention
9 |
10 |
11 | class TestMultiheadAttention(unittest.TestCase):
12 | def test_append_prev_key_padding_mask(self):
13 | bsz = 1
14 | src_len = 4
15 |
16 | cases = [
17 | # no padding mask
18 | (None, None, None),
19 | # current padding mask only
20 | (
21 | torch.tensor([[1]]).bool(),
22 | None,
23 | torch.tensor([[0, 0, 0, 1]]).bool(),
24 | ),
25 | # previous padding mask only
26 | (
27 | None,
28 | torch.tensor([[0, 1, 0]]).bool(),
29 | torch.tensor([[0, 1, 0, 0]]).bool(),
30 | ),
31 | # both padding masks
32 | (
33 | torch.tensor([[1]]).bool(),
34 | torch.tensor([[0, 1, 0]]).bool(),
35 | torch.tensor([[0, 1, 0, 1]]).bool(),
36 | ),
37 | ]
38 | for c in cases:
39 | key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
40 | c[0],
41 | c[1],
42 | batch_size=bsz,
43 | src_len=src_len,
44 | static_kv=False,
45 | )
46 |
47 | if key_padding_mask is not None:
48 | self.assertTrue(
49 | torch.all(torch.eq(key_padding_mask, c[2])),
50 | f'Unexpected resultant key padding mask: {key_padding_mask}'
51 | f' given current: {c[0]} and previous: {c[1]}',
52 | )
53 | self.assertEqual(key_padding_mask.size(0), bsz)
54 | self.assertEqual(key_padding_mask.size(1), src_len)
55 | else:
56 | self.assertIsNone(c[2])
57 |
58 |
59 | if __name__ == '__main__':
60 | unittest.main()
61 |
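
The three non-None cases pin down the helper's contract: previous and current masks are concatenated along the time axis in that order, and a missing side is treated as all-False (no padding) of whatever width is needed to reach src_len. A standalone sketch of that behavior, assuming bool masks whose widths sum to src_len:

    import torch

    def append_masks(curr, prev, batch_size, src_len):
        if curr is None and prev is None:
            return None
        if prev is None:
            prev = torch.zeros(batch_size, src_len - curr.size(1), dtype=torch.bool)
        if curr is None:
            curr = torch.zeros(batch_size, src_len - prev.size(1), dtype=torch.bool)
        return torch.cat([prev, curr], dim=1)   # previous tokens come first in time

    assert append_masks(torch.tensor([[True]]), None, 1, 4).tolist() == [[False, False, False, True]]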
--------------------------------------------------------------------------------