├── .gitignore
├── .gitmodules
├── .isort.cfg
├── .pre-commit-config.yaml
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── MANIFEST.in
├── README.md
├── RELEASE.md
├── docs
├── Makefile
├── _static
│ └── theme_overrides.css
├── aqc_superb.jpg
├── command_line_tools.rst
├── conf.py
├── criterions.rst
├── data.rst
├── data2vec-aqc_final.png
├── docutils.conf
├── fairseq.gif
├── fairseq_logo.png
├── getting_started.rst
├── hydra_integration.md
├── index.rst
├── lr_scheduler.rst
├── make.bat
├── models.rst
├── modules.rst
├── optim.rst
├── overview.rst
├── requirements.txt
├── tasks.rst
├── tutorial_classifying_names.rst
└── tutorial_simple_lstm.rst
├── examples
├── .gitignore
├── __init__.py
├── data2vec
│ ├── README.md
│ ├── config
│ │ └── audio
│ │ │ └── pretraining
│ │ │ └── base_librispeech.yaml
│ └── models
│ │ └── data2vec_audio.py
├── language_model
│ ├── README.adaptive_inputs.md
│ ├── README.conv.md
│ ├── README.md
│ └── prepare-wikitext-103.sh
├── speech_recognition
│ ├── README.md
│ ├── __init__.py
│ ├── criterions
│ │ ├── ASG_loss.py
│ │ ├── __init__.py
│ │ └── cross_entropy_acc.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── asr_dataset.py
│ │ ├── collaters.py
│ │ ├── data_utils.py
│ │ └── replabels.py
│ ├── datasets
│ │ ├── asr_prep_json.py
│ │ └── prepare-librispeech.sh
│ ├── infer.py
│ ├── kaldi
│ │ ├── __init__.py
│ │ ├── add-self-loop-simple.cc
│ │ ├── config
│ │ │ └── kaldi_initializer.yaml
│ │ ├── kaldi_decoder.py
│ │ └── kaldi_initializer.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── vggtransformer.py
│ │ └── w2l_conv_glu_enc.py
│ ├── new
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── conf
│ │ │ ├── hydra
│ │ │ │ └── sweeper
│ │ │ │ │ └── ax.yaml
│ │ │ └── infer.yaml
│ │ ├── decoders
│ │ │ ├── __init__.py
│ │ │ ├── base_decoder.py
│ │ │ ├── decoder.py
│ │ │ ├── decoder_config.py
│ │ │ ├── flashlight_decoder.py
│ │ │ └── viterbi_decoder.py
│ │ └── infer.py
│ ├── tasks
│ │ ├── __init__.py
│ │ └── speech_recognition.py
│ ├── utils
│ │ └── wer_utils.py
│ └── w2l_decoder.py
└── wav2vec
│ ├── README.md
│ ├── __init__.py
│ ├── config
│ ├── finetuning
│ │ ├── base_100h.yaml
│ │ ├── base_10h.yaml
│ │ ├── base_10m.yaml
│ │ ├── base_1h.yaml
│ │ ├── base_960h.yaml
│ │ ├── vox_100h.yaml
│ │ ├── vox_10h.yaml
│ │ ├── vox_10m.yaml
│ │ ├── vox_1h.yaml
│ │ └── vox_960h.yaml
│ └── pretraining
│ │ ├── wav2vec2_base_librispeech.yaml
│ │ ├── wav2vec2_conformer_base_librispeech.yaml
│ │ ├── wav2vec2_conformer_large_librivox.yaml
│ │ ├── wav2vec2_large_librivox.yaml
│ │ ├── wav2vec2_large_librivox_tpu-pod.yaml
│ │ └── wav2vec2_large_librivox_tpu.yaml
│ ├── libri_labels.py
│ ├── scripts
│ └── binarize_manifest.sh
│ ├── unsupervised
│ ├── README.md
│ ├── __init__.py
│ ├── config
│ │ ├── finetuning
│ │ │ └── w2v_finetune.yaml
│ │ ├── gan
│ │ │ ├── w2vu.yaml
│ │ │ └── w2vu2.yaml
│ │ ├── generate
│ │ │ └── viterbi.yaml
│ │ ├── timit_matched
│ │ │ ├── test.uid
│ │ │ ├── train.uid
│ │ │ ├── train_text.uid
│ │ │ └── valid.uid
│ │ └── timit_unmatched
│ │ │ ├── test.uid
│ │ │ ├── train.uid
│ │ │ ├── train_text.uid
│ │ │ └── valid.uid
│ ├── data
│ │ ├── __init__.py
│ │ ├── extracted_features_dataset.py
│ │ └── random_input_dataset.py
│ ├── kaldi_self_train
│ │ ├── README.md
│ │ └── st
│ │ │ ├── cmd.sh
│ │ │ ├── decode_phone.sh
│ │ │ ├── decode_word_step1.sh
│ │ │ ├── decode_word_step2.sh
│ │ │ ├── local
│ │ │ ├── copy_aligned_text.py
│ │ │ ├── decode.sh
│ │ │ ├── prepare_data_from_w2v.py
│ │ │ ├── prepare_lang.sh
│ │ │ ├── prepare_lang_word.sh
│ │ │ ├── prepare_lm.sh
│ │ │ ├── score.sh
│ │ │ ├── show_wer.sh
│ │ │ ├── train_subset_lgbeam.sh
│ │ │ ├── unsup_select.py
│ │ │ ├── unsup_select_decode.sh
│ │ │ └── unsup_select_decode_word.sh
│ │ │ ├── path.sh
│ │ │ ├── steps
│ │ │ ├── steps_gan
│ │ │ ├── train_deltas.sh
│ │ │ ├── train_lda_mllt.sh
│ │ │ └── train_sat.sh
│ │ │ ├── train.sh
│ │ │ └── utils
│ ├── models
│ │ ├── __init__.py
│ │ └── wav2vec_u.py
│ ├── scripts
│ │ ├── apply_pca.py
│ │ ├── copy_labels.py
│ │ ├── filter_lexicon.py
│ │ ├── filter_tsv.py
│ │ ├── g2p_wrd_to_phn.py
│ │ ├── ltr_to_wrd.py
│ │ ├── mean_pool.py
│ │ ├── merge_clusters.py
│ │ ├── normalize_and_filter_text.py
│ │ ├── normalize_text.py
│ │ ├── pca.py
│ │ ├── phonemize_with_sil.py
│ │ ├── prepare_audio.sh
│ │ ├── prepare_audio_v2.sh
│ │ ├── prepare_text.sh
│ │ ├── prepare_timit.sh
│ │ ├── remove_silence.py
│ │ ├── vads.py
│ │ ├── wav2vec_apply_cluster_faiss.py
│ │ ├── wav2vec_cluster_faiss.py
│ │ ├── wav2vec_extract_features.py
│ │ ├── wer.py
│ │ └── wrd_to_ltr.py
│ ├── tasks
│ │ ├── __init__.py
│ │ └── unpaired_audio_text.py
│ └── w2vu_generate.py
│ ├── vq-wav2vec_featurize.py
│ ├── wav2vec_featurize.py
│ ├── wav2vec_manifest.py
│ └── xlsr
│ ├── README.md
│ └── config
│ └── finetune.yaml
├── fairseq
├── __init__.py
├── benchmark
│ ├── __init__.py
│ ├── benchmark_multihead_attention.py
│ ├── dummy_dataset.py
│ ├── dummy_lm.py
│ ├── dummy_masked_lm.py
│ ├── dummy_model.py
│ └── dummy_mt.py
├── binarizer.py
├── checkpoint_utils.py
├── clib
│ ├── cuda
│ │ ├── ngram_repeat_block_cuda.cpp
│ │ └── ngram_repeat_block_cuda_kernel.cu
│ ├── libbase
│ │ └── balanced_assignment.cpp
│ ├── libbleu
│ │ ├── libbleu.cpp
│ │ └── module.cpp
│ ├── libnat
│ │ └── edit_dist.cpp
│ └── libnat_cuda
│ │ ├── binding.cpp
│ │ ├── edit_dist.cu
│ │ └── edit_dist.h
├── config
│ ├── __init__.py
│ ├── config.yaml
│ └── model
│ │ ├── transformer_lm
│ │ ├── transformer_lm_baevski_gbw.yaml
│ │ ├── transformer_lm_baevski_wiki103.yaml
│ │ ├── transformer_lm_big.yaml
│ │ ├── transformer_lm_gbw.yaml
│ │ ├── transformer_lm_gpt.yaml
│ │ ├── transformer_lm_gpt2_big.yaml
│ │ ├── transformer_lm_gpt2_medium.yaml
│ │ ├── transformer_lm_gpt2_small.yaml
│ │ └── transformer_lm_wiki103.yaml
│ │ ├── wav2vec
│ │ └── vq_wav2vec_gumbel.yaml
│ │ └── wav2vec2
│ │ ├── wav2vec2_base.yaml
│ │ └── wav2vec2_large.yaml
├── criterions
│ ├── __init__.py
│ ├── adaptive_loss.py
│ ├── composite_loss.py
│ ├── cross_entropy.py
│ ├── ctc.py
│ ├── fairseq_criterion.py
│ ├── fastspeech2_loss.py
│ ├── hubert_criterion.py
│ ├── label_smoothed_cross_entropy.py
│ ├── label_smoothed_cross_entropy_latency_augmented.py
│ ├── label_smoothed_cross_entropy_with_alignment.py
│ ├── label_smoothed_cross_entropy_with_ctc.py
│ ├── legacy_masked_lm.py
│ ├── masked_lm.py
│ ├── model_criterion.py
│ ├── nat_loss.py
│ ├── sentence_prediction.py
│ ├── sentence_prediction_adapters.py
│ ├── sentence_ranking.py
│ ├── speech_to_speech_criterion.py
│ ├── speech_ulm_criterion.py
│ ├── tacotron2_loss.py
│ └── wav2vec_criterion.py
├── data
│ ├── __init__.py
│ ├── add_target_dataset.py
│ ├── append_token_dataset.py
│ ├── audio
│ │ ├── __init__.py
│ │ ├── audio_utils.py
│ │ ├── data_cfg.py
│ │ ├── feature_transforms
│ │ │ ├── __init__.py
│ │ │ ├── delta_deltas.py
│ │ │ ├── global_cmvn.py
│ │ │ ├── specaugment.py
│ │ │ └── utterance_cmvn.py
│ │ ├── frm_text_to_speech_dataset.py
│ │ ├── hubert_dataset.py
│ │ ├── multi_modality_dataset.py
│ │ ├── raw_audio_dataset.py
│ │ ├── speech_to_speech_dataset.py
│ │ ├── speech_to_text_dataset.py
│ │ ├── speech_to_text_joint_dataset.py
│ │ └── text_to_speech_dataset.py
│ ├── backtranslation_dataset.py
│ ├── base_wrapper_dataset.py
│ ├── bucket_pad_length_dataset.py
│ ├── codedataset.py
│ ├── colorize_dataset.py
│ ├── concat_dataset.py
│ ├── concat_sentences_dataset.py
│ ├── data_utils.py
│ ├── data_utils_fast.pyx
│ ├── denoising_dataset.py
│ ├── dictionary.py
│ ├── encoders
│ │ ├── __init__.py
│ │ ├── byte_bpe.py
│ │ ├── byte_utils.py
│ │ ├── bytes.py
│ │ ├── characters.py
│ │ ├── fastbpe.py
│ │ ├── gpt2_bpe.py
│ │ ├── gpt2_bpe_utils.py
│ │ ├── hf_bert_bpe.py
│ │ ├── hf_byte_bpe.py
│ │ ├── moses_tokenizer.py
│ │ ├── nltk_tokenizer.py
│ │ ├── sentencepiece_bpe.py
│ │ ├── space_tokenizer.py
│ │ ├── subword_nmt_bpe.py
│ │ └── utils.py
│ ├── fairseq_dataset.py
│ ├── fasta_dataset.py
│ ├── huffman
│ │ ├── __init__.py
│ │ ├── huffman_coder.py
│ │ └── huffman_mmap_indexed_dataset.py
│ ├── id_dataset.py
│ ├── indexed_dataset.py
│ ├── iterators.py
│ ├── language_pair_dataset.py
│ ├── legacy
│ │ ├── __init__.py
│ │ ├── block_pair_dataset.py
│ │ ├── masked_lm_dataset.py
│ │ └── masked_lm_dictionary.py
│ ├── list_dataset.py
│ ├── lm_context_window_dataset.py
│ ├── lru_cache_dataset.py
│ ├── mask_tokens_dataset.py
│ ├── monolingual_dataset.py
│ ├── multi_corpus_dataset.py
│ ├── multi_corpus_sampled_dataset.py
│ ├── multilingual
│ │ ├── __init__.py
│ │ ├── multilingual_data_manager.py
│ │ ├── multilingual_utils.py
│ │ ├── sampled_multi_dataset.py
│ │ ├── sampled_multi_epoch_dataset.py
│ │ └── sampling_method.py
│ ├── nested_dictionary_dataset.py
│ ├── noising.py
│ ├── num_samples_dataset.py
│ ├── numel_dataset.py
│ ├── offset_tokens_dataset.py
│ ├── pad_dataset.py
│ ├── plasma_utils.py
│ ├── prepend_dataset.py
│ ├── prepend_token_dataset.py
│ ├── raw_label_dataset.py
│ ├── replace_dataset.py
│ ├── resampling_dataset.py
│ ├── roll_dataset.py
│ ├── round_robin_zip_datasets.py
│ ├── shorten_dataset.py
│ ├── sort_dataset.py
│ ├── strip_token_dataset.py
│ ├── subsample_dataset.py
│ ├── text_compressor.py
│ ├── token_block_dataset.py
│ ├── token_block_utils_fast.pyx
│ ├── transform_eos_concat_langpair_dataset.py
│ ├── transform_eos_dataset.py
│ └── transform_eos_lang_pair_dataset.py
├── dataclass
│ ├── __init__.py
│ ├── configs.py
│ ├── constants.py
│ ├── initialize.py
│ └── utils.py
├── distributed
│ ├── __init__.py
│ ├── distributed_timeout_wrapper.py
│ ├── fully_sharded_data_parallel.py
│ ├── legacy_distributed_data_parallel.py
│ ├── module_proxy_wrapper.py
│ ├── tpu_distributed_data_parallel.py
│ └── utils.py
├── file_chunker_utils.py
├── file_io.py
├── file_utils.py
├── hub_utils.py
├── incremental_decoding_utils.py
├── iterative_refinement_generator.py
├── logging
│ ├── __init__.py
│ ├── meters.py
│ ├── metrics.py
│ └── progress_bar.py
├── model_parallel
│ ├── __init__.py
│ ├── criterions
│ │ ├── __init__.py
│ │ └── vocab_parallel_cross_entropy.py
│ ├── megatron_trainer.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── pipeline_parallel_transformer
│ │ │ ├── __init__.py
│ │ │ ├── layers.py
│ │ │ └── model.py
│ │ ├── roberta
│ │ │ ├── __init__.py
│ │ │ └── model.py
│ │ ├── transformer.py
│ │ └── transformer_lm.py
│ └── modules
│ │ ├── __init__.py
│ │ ├── multihead_attention.py
│ │ └── transformer_layer.py
├── models
│ ├── __init__.py
│ ├── bart
│ │ ├── __init__.py
│ │ ├── hub_interface.py
│ │ └── model.py
│ ├── composite_encoder.py
│ ├── distributed_fairseq_model.py
│ ├── ema
│ │ ├── __init__.py
│ │ └── ema.py
│ ├── fairseq_decoder.py
│ ├── fairseq_encoder.py
│ ├── fairseq_incremental_decoder.py
│ ├── fairseq_model.py
│ ├── fconv.py
│ ├── fconv_lm.py
│ ├── fconv_self_att.py
│ ├── hubert
│ │ ├── __init__.py
│ │ ├── hubert.py
│ │ └── hubert_asr.py
│ ├── huggingface
│ │ ├── __init__.py
│ │ └── hf_gpt2.py
│ ├── lightconv.py
│ ├── lightconv_lm.py
│ ├── lstm.py
│ ├── lstm_lm.py
│ ├── masked_lm.py
│ ├── model_utils.py
│ ├── multilingual_transformer.py
│ ├── nat
│ │ ├── __init__.py
│ │ ├── cmlm_transformer.py
│ │ ├── fairseq_nat_model.py
│ │ ├── insertion_transformer.py
│ │ ├── iterative_nonautoregressive_transformer.py
│ │ ├── levenshtein_transformer.py
│ │ ├── levenshtein_utils.py
│ │ ├── nat_crf_transformer.py
│ │ ├── nonautoregressive_ensembles.py
│ │ └── nonautoregressive_transformer.py
│ ├── roberta
│ │ ├── __init__.py
│ │ ├── alignment_utils.py
│ │ ├── enc_dec.py
│ │ ├── hub_interface.py
│ │ ├── model.py
│ │ ├── model_camembert.py
│ │ ├── model_gottbert.py
│ │ └── model_xlmr.py
│ ├── speech_to_speech
│ │ ├── __init__.py
│ │ ├── modules.py
│ │ ├── s2s_conformer.py
│ │ └── s2s_transformer.py
│ ├── speech_to_text
│ │ ├── __init__.py
│ │ ├── berard.py
│ │ ├── convtransformer.py
│ │ ├── hub_interface.py
│ │ ├── modules
│ │ │ ├── augmented_memory_attention.py
│ │ │ └── emformer.py
│ │ ├── multi_modality_model.py
│ │ ├── s2t_conformer.py
│ │ ├── s2t_transformer.py
│ │ ├── s2t_wav_transformer.py
│ │ ├── utils.py
│ │ └── xm_transformer.py
│ ├── text_to_speech
│ │ ├── __init__.py
│ │ ├── codehifigan.py
│ │ ├── fastspeech2.py
│ │ ├── hifigan.py
│ │ ├── hub_interface.py
│ │ ├── tacotron2.py
│ │ ├── tts_transformer.py
│ │ └── vocoder.py
│ ├── transformer
│ │ ├── __init__.py
│ │ ├── transformer_base.py
│ │ ├── transformer_config.py
│ │ ├── transformer_decoder.py
│ │ ├── transformer_encoder.py
│ │ └── transformer_legacy.py
│ ├── transformer_align.py
│ ├── transformer_from_pretrained_xlm.py
│ ├── transformer_lm.py
│ ├── transformer_ulm.py
│ ├── wav2vec
│ │ ├── __init__.py
│ │ ├── utils.py
│ │ ├── wav2vec.py
│ │ ├── wav2vec2.py
│ │ └── wav2vec2_asr.py
│ └── xmod
│ │ ├── __init__.py
│ │ ├── hub_interface.py
│ │ ├── model.py
│ │ └── transformer_layer_xmod.py
├── modules
│ ├── __init__.py
│ ├── adaptive_input.py
│ ├── adaptive_softmax.py
│ ├── base_layer.py
│ ├── beamable_mm.py
│ ├── character_token_embedder.py
│ ├── checkpoint_activations.py
│ ├── conformer_layer.py
│ ├── conv_tbc.py
│ ├── cross_entropy.py
│ ├── cuda_utils.cu
│ ├── downsampled_multihead_attention.py
│ ├── dynamic_convolution.py
│ ├── dynamic_crf_layer.py
│ ├── dynamicconv_layer
│ │ ├── __init__.py
│ │ ├── cuda_function_gen.py
│ │ ├── dynamicconv_cuda.cpp
│ │ ├── dynamicconv_cuda.cuh
│ │ ├── dynamicconv_cuda_kernel.cu
│ │ ├── dynamicconv_layer.py
│ │ ├── dynamiconv_cpu.cpp
│ │ └── setup.py
│ ├── ema_module.py
│ ├── espnet_multihead_attention.py
│ ├── fairseq_dropout.py
│ ├── fp32_batch_norm.py
│ ├── fp32_group_norm.py
│ ├── fp32_instance_norm.py
│ ├── gelu.py
│ ├── grad_multiply.py
│ ├── gumbel_vector_quantizer.py
│ ├── kmeans_attention.py
│ ├── kmeans_vector_quantizer.py
│ ├── layer_drop.py
│ ├── layer_norm.py
│ ├── learned_positional_embedding.py
│ ├── lightconv_layer
│ │ ├── __init__.py
│ │ ├── cuda_function_gen.py
│ │ ├── lightconv_cuda.cpp
│ │ ├── lightconv_cuda.cuh
│ │ ├── lightconv_cuda_kernel.cu
│ │ ├── lightconv_layer.py
│ │ └── setup.py
│ ├── lightweight_convolution.py
│ ├── linearized_convolution.py
│ ├── location_attention.py
│ ├── lstm_cell_with_zoneout.py
│ ├── multihead_attention.py
│ ├── positional_embedding.py
│ ├── positional_encoding.py
│ ├── quant_noise.py
│ ├── quantization
│ │ ├── __init__.py
│ │ ├── pq
│ │ │ ├── __init__.py
│ │ │ ├── em.py
│ │ │ ├── modules
│ │ │ │ ├── __init__.py
│ │ │ │ ├── qconv.py
│ │ │ │ ├── qemb.py
│ │ │ │ └── qlinear.py
│ │ │ ├── pq.py
│ │ │ └── utils.py
│ │ ├── quantization_options.py
│ │ └── scalar
│ │ │ ├── __init__.py
│ │ │ ├── modules
│ │ │ ├── __init__.py
│ │ │ ├── qact.py
│ │ │ ├── qconv.py
│ │ │ ├── qemb.py
│ │ │ └── qlinear.py
│ │ │ ├── ops.py
│ │ │ └── utils.py
│ ├── rotary_positional_embedding.py
│ ├── same_pad.py
│ ├── scalar_bias.py
│ ├── sinusoidal_positional_embedding.py
│ ├── sparse_multihead_attention.py
│ ├── sparse_transformer_sentence_encoder.py
│ ├── sparse_transformer_sentence_encoder_layer.py
│ ├── transformer_layer.py
│ ├── transformer_sentence_encoder.py
│ ├── transformer_sentence_encoder_layer.py
│ ├── transpose_last.py
│ ├── unfold.py
│ └── vggblock.py
├── nan_detector.py
├── ngram_repeat_block.py
├── optim
│ ├── __init__.py
│ ├── adadelta.py
│ ├── adafactor.py
│ ├── adagrad.py
│ ├── adam.py
│ ├── adamax.py
│ ├── amp_optimizer.py
│ ├── bmuf.py
│ ├── composite.py
│ ├── cpu_adam.py
│ ├── dynamic_loss_scaler.py
│ ├── fairseq_optimizer.py
│ ├── fp16_optimizer.py
│ ├── fused_adam.py
│ ├── fused_lamb.py
│ ├── lr_scheduler
│ │ ├── __init__.py
│ │ ├── cosine_lr_scheduler.py
│ │ ├── fairseq_lr_scheduler.py
│ │ ├── fixed_schedule.py
│ │ ├── inverse_square_root_schedule.py
│ │ ├── manual_lr_scheduler.py
│ │ ├── pass_through.py
│ │ ├── polynomial_decay_schedule.py
│ │ ├── reduce_lr_on_plateau.py
│ │ ├── step_lr_scheduler.py
│ │ ├── tri_stage_lr_scheduler.py
│ │ └── triangular_lr_scheduler.py
│ ├── nag.py
│ ├── sgd.py
│ └── shard.py
├── options.py
├── pdb.py
├── quantization_utils.py
├── registry.py
├── scoring
│ ├── __init__.py
│ ├── bertscore.py
│ ├── bleu.py
│ ├── chrf.py
│ ├── meteor.py
│ ├── tokenizer.py
│ └── wer.py
├── search.py
├── sequence_generator.py
├── sequence_scorer.py
├── speech_generator.py
├── tasks
│ ├── __init__.py
│ ├── audio_finetuning.py
│ ├── audio_pretraining.py
│ ├── cross_lingual_lm.py
│ ├── denoising.py
│ ├── fairseq_task.py
│ ├── frm_text_to_speech.py
│ ├── hubert_pretraining.py
│ ├── language_modeling.py
│ ├── legacy_masked_lm.py
│ ├── masked_lm.py
│ ├── multilingual_denoising.py
│ ├── multilingual_language_modeling.py
│ ├── multilingual_masked_lm.py
│ ├── multilingual_translation.py
│ ├── online_backtranslation.py
│ ├── semisupervised_translation.py
│ ├── sentence_prediction.py
│ ├── sentence_prediction_adapters.py
│ ├── sentence_ranking.py
│ ├── simultaneous_translation.py
│ ├── speech_to_speech.py
│ ├── speech_to_text.py
│ ├── speech_ulm_task.py
│ ├── text_to_speech.py
│ ├── translation.py
│ ├── translation_from_pretrained_bart.py
│ ├── translation_from_pretrained_xlm.py
│ ├── translation_lev.py
│ └── translation_multi_simple_epoch.py
├── token_generation_constraints.py
├── tokenizer.py
├── trainer.py
├── utils.py
└── version.txt
├── fairseq_cli
├── __init__.py
├── eval_lm.py
├── generate.py
├── hydra_train.py
├── interactive.py
├── preprocess.py
├── score.py
├── train.py
└── validate.py
├── hubconf.py
├── pyproject.toml
├── release_utils.py
├── scripts
├── __init__.py
├── average_checkpoints.py
├── better_transformer.py
├── build_sym_alignment.py
├── compare_namespaces.py
├── compound_split_bleu.sh
├── constraints
│ ├── extract.py
│ └── validate.py
├── convert_dictionary.lua
├── convert_model.lua
├── count_docs.py
├── read_binarized.py
├── rm_pt.py
├── sacrebleu.sh
├── shard_docs.py
├── split_train_valid_docs.py
├── spm_decode.py
├── spm_encode.py
├── spm_train.py
└── test_fsdp.sh
├── setup.cfg
├── setup.py
├── tests
├── __init__.py
├── distributed
│ ├── __init__.py
│ ├── test_bmuf.py
│ ├── test_distributed_timeout_wrapper.py
│ ├── test_module_proxy_wrapper.py
│ ├── test_utils.py
│ └── utils.py
├── gpu
│ ├── __init__.py
│ ├── test_binaries_gpu.py
│ ├── test_ema_gpu.py
│ └── transformer_quantization_config.yaml
├── speech
│ ├── __init__.py
│ ├── test_convtransformer_simul_trans.py
│ ├── test_dual_input_wav_transformer.py
│ ├── test_dualinput_s2t_transformer.py
│ ├── test_fastspeech2.py
│ ├── test_s2s_transformer.py
│ ├── test_s2t_conformer.py
│ ├── test_s2t_transformer.py
│ ├── test_tts_transformer.py
│ ├── test_wav2vec2.py
│ └── test_xm_transformer.py
├── speech_recognition
│ ├── __init__.py
│ ├── asr_test_base.py
│ ├── test_collaters.py
│ ├── test_cross_entropy.py
│ ├── test_data_utils.py
│ └── test_vggtransformer.py
├── tasks
│ └── test_masked_lm.py
├── test_activation_checkpointing.py
├── test_amp_optimizer.py
├── test_average_checkpoints.py
├── test_backtranslation_dataset.py
├── test_binaries.py
├── test_binarizer.py
├── test_character_token_embedder.py
├── test_checkpoint_utils.py
├── test_concat_dataset.py
├── test_constraints.py
├── test_convtbc.py
├── test_data_utils.py
├── test_dataclass_utils.py
├── test_dataset.py
├── test_dictionary.py
├── test_ema.py
├── test_espnet_multihead_attention.py
├── test_export.py
├── test_file_chunker_utils.py
├── test_file_io.py
├── test_fp16_optimizer.py
├── test_hf_hub.py
├── test_huffman.py
├── test_inference_dropout.py
├── test_iopath.py
├── test_iterators.py
├── test_label_smoothing.py
├── test_lm_context_window.py
├── test_lstm_jitable.py
├── test_memory_efficient_fp16.py
├── test_metrics.py
├── test_multi_corpus_dataset.py
├── test_multi_corpus_sampled_dataset.py
├── test_multihead_attention.py
├── test_noising.py
├── test_online_backtranslation.py
├── test_plasma_utils.py
├── test_positional_encoding.py
├── test_reproducibility.py
├── test_resampling_dataset.py
├── test_roberta.py
├── test_rotary_positional_embedding.py
├── test_sequence_generator.py
├── test_sequence_scorer.py
├── test_sparse_multihead_attention.py
├── test_token_block_dataset.py
├── test_train.py
├── test_transformer.py
├── test_utils.py
├── test_valid_subset_checks.py
└── utils.py
└── train.py
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "fairseq/model_parallel/megatron"]
2 | path = fairseq/model_parallel/megatron
3 | url = https://github.com/ngoyal2707/Megatron-LM
4 | branch = fairseq
5 |
--------------------------------------------------------------------------------
/.isort.cfg:
--------------------------------------------------------------------------------
1 | [settings]
2 | known_third_party = _cffi_backend,agg_results,aml,bitarray,boto3,botocore,dump_hubert_feature,dynamicconv_cuda,editdistance,faiss,fasttext,feature_utils,ffmpeg,g2p_en,h5py,hydra,hypothesis,indicnlp,inflect,iopath,joblib,kaldi_io,kenlm,libfb,librosa,lightconv_cuda,matplotlib,misc,mmpt,mmpt_cli,model,nltk,npy_append_array,numpy,omegaconf,pandas,pathbuilder,preprocessing,progressbar,pythainlp,random_sequence_shuffler,regex,sacrebleu,sacremoses,scipy,sentencepiece,setuptools,six,sklearn,soundfile,sweep,sweep_wmt_en2de_transformer_big_common,tabulate,torch,torchaudio,tqdm,unidecode,utils,videoreader,wav2vec_cluster_faiss,wget,yaml
3 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | exclude: 'build|stubs'
2 |
3 | default_language_version:
4 | python: python3
5 |
6 | repos:
7 | - repo: https://github.com/pre-commit/pre-commit-hooks
8 | rev: v4.1.0
9 | hooks:
10 | - id: trailing-whitespace
11 | - id: check-ast
12 | - id: check-merge-conflict
13 | - id: no-commit-to-branch
14 | args: ['--branch=master']
15 | - id: check-added-large-files
16 | args: ['--maxkb=500']
17 | - id: end-of-file-fixer
18 |
19 | - repo: https://github.com/ambv/black
20 | rev: 22.3.0
21 | hooks:
22 | - id: black
23 | language_version: python3.8
24 |
25 | - repo: https://gitlab.com/pycqa/flake8
26 | rev: 3.9.2
27 | hooks:
28 | - id: flake8
29 | args: [
30 | # only error for syntax errors and undefined names
31 | "--select=E9,F63,F7,F82",
32 | ]
33 |
34 | - repo: https://github.com/pycqa/isort
35 | rev: 5.10.1
36 | hooks:
37 | - id: isort
38 | exclude: README.md
39 | additional_dependencies: [toml]
40 | args: ["--profile", "black"]
41 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Facebook, Inc. and its affiliates.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include fairseq/version.txt
2 |
--------------------------------------------------------------------------------
/RELEASE.md:
--------------------------------------------------------------------------------
1 | # Creating a New Release
2 |
3 | In order to create a new release:
4 |
5 | 1. Navigate to the [Fairseq Workflows](https://github.com/facebookresearch/fairseq/actions) and find the one named _Fairseq Release_.
6 |
7 | 2. Under _Run Workflow_ choose the branch `main` and for _Release Type_ enter either `major`, `minor`, or `patch`.
8 |
9 | 3. A branch named `$new_version-release` will be created in which the `version.txt` file is updated. Merge those changes into `main`.
10 |
11 | 4. Make sure that a [new PyPI package](https://pypi.org/project/fairseq/) has been uploaded.
12 |
13 | 5. Make sure that a [new GitHub release](https://github.com/facebookresearch/fairseq/releases) has been created.
14 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = python -msphinx
7 | SPHINXPROJ = fairseq
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/docs/_static/theme_overrides.css:
--------------------------------------------------------------------------------
1 | .wy-table-responsive table td kbd {
2 | white-space: nowrap;
3 | }
4 | .wy-table-responsive table td {
5 | white-space: normal !important;
6 | }
7 | .wy-table-responsive {
8 | overflow: visible !important;
9 | }
10 |
--------------------------------------------------------------------------------
/docs/aqc_superb.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Speech-Lab-IITM/data2vec-aqc/c80cc996669e12c9c32201a9ff1515300e64b13a/docs/aqc_superb.jpg
--------------------------------------------------------------------------------
/docs/criterions.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | .. _Criterions:
5 |
6 | Criterions
7 | ==========
8 |
9 | Criterions compute the loss function given the model and batch, roughly::
10 |
11 | loss = criterion(model, batch)
12 |
13 | .. automodule:: fairseq.criterions
14 | :members:
15 |
16 | .. autoclass:: fairseq.criterions.FairseqCriterion
17 | :members:
18 | :undoc-members:
19 |
20 | .. autoclass:: fairseq.criterions.adaptive_loss.AdaptiveLoss
21 | :members:
22 | :undoc-members:
23 | .. autoclass:: fairseq.criterions.composite_loss.CompositeLoss
24 | :members:
25 | :undoc-members:
26 | .. autoclass:: fairseq.criterions.cross_entropy.CrossEntropyCriterion
27 | :members:
28 | :undoc-members:
29 | .. autoclass:: fairseq.criterions.label_smoothed_cross_entropy.LabelSmoothedCrossEntropyCriterion
30 | :members:
31 | :undoc-members:
32 |
--------------------------------------------------------------------------------
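
The `loss = criterion(model, batch)` contract above corresponds to the `forward` method of a `FairseqCriterion` subclass, which returns the loss, a sample size used for gradient normalization, and a logging dict. A minimal sketch, assuming a hypothetical criterion name `my_l2_loss` and a model whose first output is a plain prediction tensor:

```python
import torch.nn.functional as F

from fairseq.criterions import FairseqCriterion, register_criterion


@register_criterion("my_l2_loss")  # hypothetical name, for illustration only
class MyL2Criterion(FairseqCriterion):
    def forward(self, model, sample, reduce=True):
        net_output = model(**sample["net_input"])
        # Assumes net_output[0] is a prediction tensor shaped like the target.
        loss = F.mse_loss(net_output[0], sample["target"], reduction="sum")
        sample_size = sample["target"].numel()
        logging_output = {"loss": loss.item(), "sample_size": sample_size}
        return loss, sample_size, logging_output
```
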
/docs/data.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | .. module:: fairseq.data
5 |
6 | Data Loading and Utilities
7 | ==========================
8 |
9 | .. _datasets:
10 |
11 | Datasets
12 | --------
13 |
14 | **Datasets** define the data format and provide helpers for creating
15 | mini-batches.
16 |
17 | .. autoclass:: fairseq.data.FairseqDataset
18 | :members:
19 | .. autoclass:: fairseq.data.LanguagePairDataset
20 | :members:
21 | .. autoclass:: fairseq.data.MonolingualDataset
22 | :members:
23 |
24 | **Helper Datasets**
25 |
26 | These datasets wrap other :class:`fairseq.data.FairseqDataset` instances and
27 | provide additional functionality:
28 |
29 | .. autoclass:: fairseq.data.BacktranslationDataset
30 | :members:
31 | .. autoclass:: fairseq.data.ConcatDataset
32 | :members:
33 | .. autoclass:: fairseq.data.ResamplingDataset
34 | :members:
35 | .. autoclass:: fairseq.data.RoundRobinZipDatasets
36 | :members:
37 | .. autoclass:: fairseq.data.TransformEosDataset
38 | :members:
39 |
40 |
41 | Dictionary
42 | ----------
43 |
44 | .. autoclass:: fairseq.data.Dictionary
45 | :members:
46 |
47 |
48 | Iterators
49 | ---------
50 |
51 | .. autoclass:: fairseq.data.CountingIterator
52 | :members:
53 | .. autoclass:: fairseq.data.EpochBatchIterator
54 | :members:
55 | .. autoclass:: fairseq.data.GroupedIterator
56 | :members:
57 | .. autoclass:: fairseq.data.ShardedIterator
58 | :members:
59 |
--------------------------------------------------------------------------------
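
A `FairseqDataset` mainly needs indexing, the sizing hooks used for batching and filtering, and a `collater` that turns a list of samples into a mini-batch. A minimal sketch, assuming in-memory 1-D token tensors and using `fairseq.data.data_utils.collate_tokens` for padding:

```python
import torch

from fairseq.data import FairseqDataset, data_utils


class TensorListDataset(FairseqDataset):
    """Wraps an in-memory list of 1-D token tensors."""

    def __init__(self, tensors, pad_idx):
        self.tensors = tensors
        self.pad_idx = pad_idx

    def __getitem__(self, index):
        return self.tensors[index]

    def __len__(self):
        return len(self.tensors)

    def num_tokens(self, index):
        # Used when batching by --max-tokens.
        return self.tensors[index].numel()

    def size(self, index):
        # Used when filtering by --max-positions.
        return self.tensors[index].numel()

    def collater(self, samples):
        # Pads a list of 1-D tensors into a (batch, time) tensor.
        return data_utils.collate_tokens(samples, self.pad_idx)


ds = TensorListDataset([torch.tensor([4, 5, 6]), torch.tensor([7])], pad_idx=1)
print(ds.collater([ds[0], ds[1]]))  # tensor([[4, 5, 6], [7, 1, 1]])
```
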
/docs/data2vec-aqc_final.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Speech-Lab-IITM/data2vec-aqc/c80cc996669e12c9c32201a9ff1515300e64b13a/docs/data2vec-aqc_final.png
--------------------------------------------------------------------------------
/docs/docutils.conf:
--------------------------------------------------------------------------------
1 | [writers]
2 | option-limit=0
3 |
--------------------------------------------------------------------------------
/docs/fairseq.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Speech-Lab-IITM/data2vec-aqc/c80cc996669e12c9c32201a9ff1515300e64b13a/docs/fairseq.gif
--------------------------------------------------------------------------------
/docs/fairseq_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Speech-Lab-IITM/data2vec-aqc/c80cc996669e12c9c32201a9ff1515300e64b13a/docs/fairseq_logo.png
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. fairseq documentation master file, created by
2 | sphinx-quickstart on Fri Aug 17 21:45:30 2018.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | :github_url: https://github.com/pytorch/fairseq
7 |
8 |
9 | fairseq documentation
10 | =====================
11 |
12 | Fairseq is a sequence modeling toolkit written in `PyTorch
13 | <https://pytorch.org/>`_ that allows researchers and developers to
14 | train custom models for translation, summarization, language modeling and other
15 | text generation tasks.
16 |
17 | .. toctree::
18 | :maxdepth: 1
19 | :caption: Getting Started
20 |
21 | getting_started
22 | command_line_tools
23 |
24 | .. toctree::
25 | :maxdepth: 1
26 | :caption: Extending Fairseq
27 |
28 | overview
29 | tutorial_simple_lstm
30 | tutorial_classifying_names
31 |
32 | .. toctree::
33 | :maxdepth: 2
34 | :caption: Library Reference
35 |
36 | tasks
37 | models
38 | criterions
39 | optim
40 | lr_scheduler
41 | data
42 | modules
43 |
44 |
45 | Indices and tables
46 | ==================
47 |
48 | * :ref:`genindex`
49 | * :ref:`search`
50 |
--------------------------------------------------------------------------------
/docs/lr_scheduler.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | .. _Learning Rate Schedulers:
5 |
6 | Learning Rate Schedulers
7 | ========================
8 |
9 | Learning Rate Schedulers update the learning rate over the course of training.
10 | Learning rates can be updated after each update via :func:`step_update` or at
11 | epoch boundaries via :func:`step`.
12 |
13 | .. automodule:: fairseq.optim.lr_scheduler
14 | :members:
15 |
16 | .. autoclass:: fairseq.optim.lr_scheduler.FairseqLRScheduler
17 | :members:
18 | :undoc-members:
19 |
20 | .. autoclass:: fairseq.optim.lr_scheduler.cosine_lr_scheduler.CosineSchedule
21 | :members:
22 | :undoc-members:
23 | .. autoclass:: fairseq.optim.lr_scheduler.fixed_schedule.FixedSchedule
24 | :members:
25 | :undoc-members:
26 | .. autoclass:: fairseq.optim.lr_scheduler.inverse_square_root_schedule.InverseSquareRootSchedule
27 | :members:
28 | :undoc-members:
29 | .. autoclass:: fairseq.optim.lr_scheduler.reduce_lr_on_plateau.ReduceLROnPlateau
30 | :members:
31 | :undoc-members:
32 | .. autoclass:: fairseq.optim.lr_scheduler.triangular_lr_scheduler.TriangularSchedule
33 | :members:
34 | :undoc-members:
35 |
--------------------------------------------------------------------------------
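
Concretely, a scheduler subclasses `FairseqLRScheduler` and drives the wrapped optimizer through `set_lr`. A minimal sketch, assuming the dataclass-style `(cfg, optimizer)` constructor, a hypothetical `my_linear` name, and a fixed decay horizon chosen only for illustration:

```python
from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


@register_lr_scheduler("my_linear")  # hypothetical name, for illustration only
class LinearDecaySchedule(FairseqLRScheduler):
    def __init__(self, cfg, optimizer):
        super().__init__(cfg, optimizer)
        self.base_lr = cfg.lr[0]       # assumes cfg.lr is a list, as elsewhere
        self.total_updates = 10_000    # assumed fixed horizon for this sketch

    def step_update(self, num_updates):
        # Called after every optimizer update.
        decay = max(0.0, 1.0 - num_updates / self.total_updates)
        self.optimizer.set_lr(self.base_lr * decay)
        return self.optimizer.get_lr()

    def step(self, epoch, val_loss=None):
        # Called at epoch boundaries; per-update scheduling does the real work.
        super().step(epoch, val_loss)
        return self.optimizer.get_lr()
```
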
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=python -msphinx
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 | set SPHINXPROJ=fairseq
13 |
14 | if "%1" == "" goto help
15 |
16 | %SPHINXBUILD% >NUL 2>NUL
17 | if errorlevel 9009 (
18 | echo.
19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed,
20 | echo.then set the SPHINXBUILD environment variable to point to the full
21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the
22 | echo.Sphinx directory to PATH.
23 | echo.
24 | echo.If you don't have Sphinx installed, grab it from
25 | echo.http://sphinx-doc.org/
26 | exit /b 1
27 | )
28 |
29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
30 | goto end
31 |
32 | :help
33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
34 |
35 | :end
36 | popd
37 |
--------------------------------------------------------------------------------
/docs/modules.rst:
--------------------------------------------------------------------------------
1 | Modules
2 | =======
3 |
4 | Fairseq provides several stand-alone :class:`torch.nn.Module` classes that may
5 | be helpful when implementing a new :class:`~fairseq.models.BaseFairseqModel`.
6 |
7 | .. automodule:: fairseq.modules
8 | :members:
9 | :undoc-members:
10 |
--------------------------------------------------------------------------------
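
A hedged usage sketch for one of these stand-alone modules, `MultiheadAttention`, assuming fairseq's `(time, batch, channel)` tensor layout and the `self_attention=True` constructor flag:

```python
import torch

from fairseq.modules import MultiheadAttention

mha = MultiheadAttention(embed_dim=16, num_heads=4, self_attention=True)
x = torch.randn(10, 2, 16)  # (time, batch, channel)
attn, attn_weights = mha(query=x, key=x, value=x)
print(attn.shape)  # torch.Size([10, 2, 16])
```
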
/docs/optim.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | .. _optimizers:
5 |
6 | Optimizers
7 | ==========
8 |
9 | Optimizers update the Model parameters based on the gradients.
10 |
11 | .. automodule:: fairseq.optim
12 | :members:
13 |
14 | .. autoclass:: fairseq.optim.FairseqOptimizer
15 | :members:
16 | :undoc-members:
17 |
18 | .. autoclass:: fairseq.optim.adadelta.Adadelta
19 | :members:
20 | :undoc-members:
21 | .. autoclass:: fairseq.optim.adagrad.Adagrad
22 | :members:
23 | :undoc-members:
24 | .. autoclass:: fairseq.optim.adafactor.FairseqAdafactor
25 | :members:
26 | :undoc-members:
27 | .. autoclass:: fairseq.optim.adam.FairseqAdam
28 | :members:
29 | :undoc-members:
30 | .. autoclass:: fairseq.optim.fp16_optimizer.FP16Optimizer
31 | :members:
32 | :undoc-members:
33 | .. autoclass:: fairseq.optim.nag.FairseqNAG
34 | :members:
35 | :undoc-members:
36 | .. autoclass:: fairseq.optim.sgd.SGD
37 | :members:
38 | :undoc-members:
39 |
--------------------------------------------------------------------------------
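
A registered optimizer typically wraps a `torch.optim` optimizer: it sets `self._optimizer` and exposes an `optimizer_config` property mapping config fields onto the torch constructor's keyword arguments. A minimal sketch under those assumptions, with a hypothetical `my_sgd` name:

```python
import torch

from fairseq.optim import FairseqOptimizer, register_optimizer


@register_optimizer("my_sgd")  # hypothetical name, for illustration only
class MySGD(FairseqOptimizer):
    def __init__(self, cfg, params):
        super().__init__(cfg)
        self._optimizer = torch.optim.SGD(params, **self.optimizer_config)

    @property
    def optimizer_config(self):
        # Maps fairseq config fields onto torch.optim.SGD keyword arguments.
        return {
            "lr": self.cfg.lr[0],
            "momentum": getattr(self.cfg, "momentum", 0.0),
            "weight_decay": getattr(self.cfg, "weight_decay", 0.0),
        }
```
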
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx<2.0
2 | sphinx-argparse
3 |
--------------------------------------------------------------------------------
/docs/tasks.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | .. module:: fairseq.tasks
5 |
6 | .. _Tasks:
7 |
8 | Tasks
9 | =====
10 |
11 | Tasks store dictionaries and provide helpers for loading/iterating over
12 | Datasets, initializing the Model/Criterion and calculating the loss.
13 |
14 | Tasks can be selected via the ``--task`` command-line argument. Once selected, a
15 | task may expose additional command-line arguments for further configuration.
16 |
17 | Example usage::
18 |
19 | # setup the task (e.g., load dictionaries)
20 | task = fairseq.tasks.setup_task(args)
21 |
22 | # build model and criterion
23 | model = task.build_model(args)
24 | criterion = task.build_criterion(args)
25 |
26 | # load datasets
27 | task.load_dataset('train')
28 | task.load_dataset('valid')
29 |
30 | # iterate over mini-batches of data
31 | batch_itr = task.get_batch_iterator(
32 | task.dataset('train'), max_tokens=4096,
33 | )
34 | for batch in batch_itr:
35 | # compute the loss
36 | loss, sample_size, logging_output = task.get_loss(
37 | model, criterion, batch,
38 | )
39 | loss.backward()
40 |
41 |
42 | Translation
43 | -----------
44 |
45 | .. autoclass:: fairseq.tasks.translation.TranslationTask
46 |
47 | .. _language modeling:
48 |
49 | Language Modeling
50 | -----------------
51 |
52 | .. autoclass:: fairseq.tasks.language_modeling.LanguageModelingTask
53 |
54 |
55 | Adding new tasks
56 | ----------------
57 |
58 | .. autofunction:: fairseq.tasks.register_task
59 | .. autoclass:: fairseq.tasks.FairseqTask
60 | :members:
61 | :undoc-members:
62 |
--------------------------------------------------------------------------------
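
A minimal sketch of the registration half of this, assuming a hypothetical `my_task` name and the dataclass-style constructor: `setup_task` builds the task (e.g. loading dictionaries) and `load_dataset` fills `self.datasets[split]` with a `FairseqDataset`:

```python
from fairseq.tasks import FairseqTask, register_task


@register_task("my_task")  # hypothetical name, for illustration only
class MyTask(FairseqTask):
    @classmethod
    def setup_task(cls, cfg, **kwargs):
        # Load dictionaries etc. here before constructing the task.
        return cls(cfg)

    def load_dataset(self, split, **kwargs):
        # Populate self.datasets[split] with a FairseqDataset instance.
        raise NotImplementedError
```
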
/examples/.gitignore:
--------------------------------------------------------------------------------
1 | !*/*.sh
2 | !*/*.md
3 |
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | try:
7 | from fairseq.version import __version__ # noqa
8 | except ImportError:
9 | pass
10 |
--------------------------------------------------------------------------------
/examples/language_model/README.conv.md:
--------------------------------------------------------------------------------
1 | # Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)
2 |
3 | ## Example usage
4 |
5 | First download and preprocess the data following the main [language modeling README](README.md).
6 |
7 | Then to train a convolutional LM using the `fconv_lm_dauphin_wikitext103`
8 | architecture:
9 | ```bash
10 | fairseq-train --task language_modeling \
11 | data-bin/wikitext-103 \
12 | --save-dir checkpoints/fconv_wikitext-103 \
13 | --arch fconv_lm_dauphin_wikitext103 \
14 | --adaptive-softmax-cutoff 10000,20000,200000 \
15 | --dropout 0.2 \
16 | --criterion adaptive_loss \
17 | --optimizer nag --clip-norm 0.1 --weight-decay 5e-06 \
18 | --lr 1.0 --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \
19 | --max-tokens 1024 --tokens-per-sample 1024 \
20 | --ddp-backend legacy_ddp \
21 | --max-epoch 35
22 | ```
23 |
24 | And evaluate with:
25 | ```bash
26 | fairseq-eval-lm data-bin/wikitext-103 --path checkpoints/fconv_wikitext-103/checkpoint_best.pt
27 | ```
28 |
29 | ## Citation
30 |
31 | ```bibtex
32 | @inproceedings{dauphin2017language,
33 | title={Language Modeling with Gated Convolutional Networks},
34 | author={Dauphin, Yann N and Fan, Angela and Auli, Michael and Grangier, David},
35 | booktitle={Proceedings of the 34th International Conference on Machine Learning-Volume 70},
36 | pages={933--941},
37 | year={2017},
38 | organization={JMLR}
39 | }
40 | ```
41 |
--------------------------------------------------------------------------------
/examples/language_model/prepare-wikitext-103.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh
3 |
4 | URLS=(
5 | "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip"
6 | )
7 | FILES=(
8 | "wikitext-103-v1.zip"
9 | )
10 |
11 | for ((i=0;i<${#URLS[@]};++i)); do
12 | file=${FILES[i]}
13 | if [ -f $file ]; then
14 | echo "$file already exists, skipping download"
15 | else
16 | url=${URLS[i]}
17 | wget "$url"
18 | if [ -f $file ]; then
19 | echo "$url successfully downloaded."
20 | else
21 | echo "$url not successfully downloaded."
22 | exit 1
23 | fi
24 | if [ ${file: -4} == ".tgz" ]; then
25 | tar zxvf $file
26 | elif [ ${file: -4} == ".tar" ]; then
27 | tar xvf $file
28 | elif [ ${file: -4} == ".zip" ]; then
29 | unzip $file
30 | fi
31 | fi
32 | done
33 | cd ..
34 |
--------------------------------------------------------------------------------
/examples/speech_recognition/__init__.py:
--------------------------------------------------------------------------------
1 | from . import criterions, models, tasks # noqa
2 |
--------------------------------------------------------------------------------
/examples/speech_recognition/criterions/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 |
4 |
5 | # ASG loss requires flashlight bindings
6 | files_to_skip = set()
7 | try:
8 | import flashlight.lib.sequence.criterion
9 | except ImportError:
10 | files_to_skip.add("ASG_loss.py")
11 |
12 | for file in sorted(os.listdir(os.path.dirname(__file__))):
13 | if file.endswith(".py") and not file.startswith("_") and file not in files_to_skip:
14 | criterion_name = file[: file.find(".py")]
15 | importlib.import_module(
16 | "examples.speech_recognition.criterions." + criterion_name
17 | )
18 |
--------------------------------------------------------------------------------
/examples/speech_recognition/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .asr_dataset import AsrDataset
7 |
8 |
9 | __all__ = [
10 | "AsrDataset",
11 | ]
12 |
--------------------------------------------------------------------------------
/examples/speech_recognition/kaldi/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Speech-Lab-IITM/data2vec-aqc/c80cc996669e12c9c32201a9ff1515300e64b13a/examples/speech_recognition/kaldi/__init__.py
--------------------------------------------------------------------------------
/examples/speech_recognition/kaldi/config/kaldi_initializer.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | data_dir: ???
4 | fst_dir: ???
5 | in_labels: ???
6 | kaldi_root: ???
7 | lm_arpa: ???
8 | blank_symbol:
9 |
--------------------------------------------------------------------------------
/examples/speech_recognition/models/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 |
4 |
5 | for file in sorted(os.listdir(os.path.dirname(__file__))):
6 | if file.endswith(".py") and not file.startswith("_"):
7 | model_name = file[: file.find(".py")]
8 | importlib.import_module("examples.speech_recognition.models." + model_name)
9 |
--------------------------------------------------------------------------------
/examples/speech_recognition/new/README.md:
--------------------------------------------------------------------------------
1 | # Flashlight Decoder
2 |
3 | This script runs decoding for pre-trained speech recognition models.
4 |
5 | ## Usage
6 |
7 | Assuming a few variables:
8 |
9 | ```bash
10 | checkpoint=
11 | data=
12 | lm_model=
13 | lexicon=
14 | ```
15 |
16 | Example usage for decoding a fine-tuned Wav2Vec model:
17 |
18 | ```bash
19 | python $FAIRSEQ_ROOT/examples/speech_recognition/new/infer.py --multirun \
20 | task=audio_pretraining \
21 | task.data=$data \
22 | task.labels=ltr \
23 | common_eval.path=$checkpoint \
24 | decoding.type=kenlm \
25 | decoding.lexicon=$lexicon \
26 | decoding.lmpath=$lm_model \
27 | dataset.gen_subset=dev_clean,dev_other,test_clean,test_other
28 | ```
29 |
30 | Example usage for using Ax to sweep WER parameters (requires `pip install hydra-ax-sweeper`):
31 |
32 | ```bash
33 | python $FAIRSEQ_ROOT/examples/speech_recognition/new/infer.py --multirun \
34 | hydra/sweeper=ax \
35 | task=audio_pretraining \
36 | task.data=$data \
37 | task.labels=ltr \
38 | common_eval.path=$checkpoint \
39 | decoding.type=kenlm \
40 | decoding.lexicon=$lexicon \
41 | decoding.lmpath=$lm_model \
42 | dataset.gen_subset=dev_other
43 | ```
44 |
--------------------------------------------------------------------------------
/examples/speech_recognition/new/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Speech-Lab-IITM/data2vec-aqc/c80cc996669e12c9c32201a9ff1515300e64b13a/examples/speech_recognition/new/__init__.py
--------------------------------------------------------------------------------
/examples/speech_recognition/new/conf/hydra/sweeper/ax.yaml:
--------------------------------------------------------------------------------
1 | # @package hydra.sweeper
2 | _target_: hydra_plugins.hydra_ax_sweeper.ax_sweeper.AxSweeper
3 | max_batch_size: null
4 | ax_config:
5 | max_trials: 128
6 | early_stop:
7 | minimize: true
8 | max_epochs_without_improvement: 10
9 | epsilon: 0.025
10 | experiment:
11 | name: ${dataset.gen_subset}
12 | objective_name: wer
13 | minimize: true
14 | parameter_constraints: null
15 | outcome_constraints: null
16 | status_quo: null
17 | client:
18 | verbose_logging: false
19 | random_seed: null
20 | params:
21 | decoding.lmweight:
22 | type: range
23 | bounds: [0.0, 5.0]
24 | decoding.wordscore:
25 | type: range
26 | bounds: [-5.0, 5.0]
27 | decoding.silweight:
28 | type: range
29 | bounds: [ -8.0, 0.0 ]
30 |
--------------------------------------------------------------------------------
/examples/speech_recognition/new/conf/infer.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | defaults:
4 | - task: null
5 | - model: null
6 |
7 | hydra:
8 | run:
9 | dir: ${common_eval.results_path}/${dataset.gen_subset}
10 | sweep:
11 | dir: /checkpoint/${env:USER}/${env:PREFIX}/${common_eval.results_path}
12 | subdir: ${dataset.gen_subset}
13 | common_eval:
14 | results_path: null
15 | path: null
16 | post_process: letter
17 | quiet: true
18 | dataset:
19 | max_tokens: 3000000
20 | gen_subset: test
21 | distributed_training:
22 | distributed_world_size: 1
23 | decoding:
24 | beam: 5
25 | type: viterbi
26 |
--------------------------------------------------------------------------------
/examples/speech_recognition/new/decoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Speech-Lab-IITM/data2vec-aqc/c80cc996669e12c9c32201a9ff1515300e64b13a/examples/speech_recognition/new/decoders/__init__.py
--------------------------------------------------------------------------------
/examples/speech_recognition/new/decoders/decoder.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | from typing import Union
9 |
10 | from fairseq.data.dictionary import Dictionary
11 |
12 | from .base_decoder import BaseDecoder
13 | from .decoder_config import DecoderConfig, FlashlightDecoderConfig
14 |
15 |
16 | def Decoder(
17 | cfg: Union[DecoderConfig, FlashlightDecoderConfig], tgt_dict: Dictionary
18 | ) -> BaseDecoder:
19 |
20 | if cfg.type == "viterbi":
21 | from .viterbi_decoder import ViterbiDecoder
22 |
23 | return ViterbiDecoder(tgt_dict)
24 | if cfg.type == "kenlm":
25 | from .flashlight_decoder import KenLMDecoder
26 |
27 | return KenLMDecoder(cfg, tgt_dict)
28 | if cfg.type == "fairseqlm":
29 | from .flashlight_decoder import FairseqLMDecoder
30 |
31 | return FairseqLMDecoder(cfg, tgt_dict)
32 | raise NotImplementedError(f"Invalid decoder type: {cfg.type}")
33 |
--------------------------------------------------------------------------------
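
For context, a hedged usage sketch of the factory above, assuming `DecoderConfig` (defined in `decoder_config.py`, not shown here) is a dataclass exposing the `type` field the dispatch reads:

```python
from fairseq.data.dictionary import Dictionary

from examples.speech_recognition.new.decoders.decoder import Decoder
from examples.speech_recognition.new.decoders.decoder_config import DecoderConfig

tgt_dict = Dictionary()  # in practice, loaded via Dictionary.load("dict.ltr.txt")
cfg = DecoderConfig(type="viterbi")  # assumes `type` is a DecoderConfig field
decoder = Decoder(cfg, tgt_dict)     # -> ViterbiDecoder instance
```
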
/examples/speech_recognition/new/decoders/viterbi_decoder.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | import torch
9 |
10 | from typing import List, Dict
11 |
12 | from .base_decoder import BaseDecoder
13 |
14 |
15 | class ViterbiDecoder(BaseDecoder):
16 | def decode(
17 | self,
18 | emissions: torch.FloatTensor,
19 | ) -> List[List[Dict[str, torch.LongTensor]]]:
20 | def get_pred(e):
21 | toks = e.argmax(dim=-1).unique_consecutive()
22 | return toks[toks != self.blank]
23 |
24 | return [[{"tokens": get_pred(x), "score": 0}] for x in emissions]
25 |
--------------------------------------------------------------------------------
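
`get_pred` above is greedy CTC decoding: take the per-frame argmax, collapse consecutive repeats, then drop blanks. A self-contained illustration of the same logic on a dummy emission tensor, with blank index 0 assumed for the sketch:

```python
import torch

blank = 0  # assumed blank index for this sketch
# (time=6, vocab=4) one-hot emissions whose per-frame argmax is [1, 1, 0, 2, 2, 3]
emissions = torch.eye(4)[torch.tensor([1, 1, 0, 2, 2, 3])]

toks = emissions.argmax(dim=-1).unique_consecutive()  # -> [1, 0, 2, 3]
pred = toks[toks != blank]                            # -> [1, 2, 3]
print(pred.tolist())  # [1, 2, 3]
```
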
/examples/speech_recognition/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 |
4 |
5 | for file in sorted(os.listdir(os.path.dirname(__file__))):
6 | if file.endswith(".py") and not file.startswith("_"):
7 | task_name = file[: file.find(".py")]
8 | importlib.import_module("examples.speech_recognition.tasks." + task_name)
9 |
--------------------------------------------------------------------------------
/examples/wav2vec/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Speech-Lab-IITM/data2vec-aqc/c80cc996669e12c9c32201a9ff1515300e64b13a/examples/wav2vec/__init__.py
--------------------------------------------------------------------------------
/examples/wav2vec/config/finetuning/base_100h.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | fp16: true
5 | log_format: json
6 | log_interval: 200
7 |
8 | checkpoint:
9 | no_epoch_checkpoints: true
10 | best_checkpoint_metric: wer
11 |
12 | task:
13 | _name: audio_finetuning
14 | data: ???
15 | normalize: true
16 | labels: ltr
17 |
18 | dataset:
19 | num_workers: 6
20 | max_tokens: 3200000
21 | skip_invalid_size_inputs_valid_test: true
22 | valid_subset: dev_other
23 |
24 | distributed_training:
25 | ddp_backend: legacy_ddp
26 | distributed_world_size: 2
27 |
28 | criterion:
29 | _name: ctc
30 | zero_infinity: true
31 |
32 | optimization:
33 | max_update: 80000
34 | lr: [0.00003]
35 | sentence_avg: true
36 | update_freq: [4]
37 |
38 | optimizer:
39 | _name: adam
40 | adam_betas: (0.9,0.98)
41 | adam_eps: 1e-08
42 |
43 | lr_scheduler:
44 | _name: tri_stage
45 | phase_ratio: [0.1, 0.4, 0.5]
46 | final_lr_scale: 0.05
47 |
48 | model:
49 | _name: wav2vec_ctc
50 | w2v_path: ???
51 | apply_mask: true
52 | mask_prob: 0.65
53 | mask_channel_prob: 0.5
54 | mask_channel_length: 64
55 | layerdrop: 0.1
56 | activation_dropout: 0.1
57 | feature_grad_mult: 0.0
58 | freeze_finetune_updates: 0
59 |
--------------------------------------------------------------------------------
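
These fine-tuning configs are Hydra/OmegaConf YAML, and `???` marks mandatory fields that must be supplied as command-line overrides. A hedged sketch of loading one for inspection, with the path assumed relative to the repo root:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/wav2vec/config/finetuning/base_100h.yaml")
print(cfg["model"]["_name"])      # wav2vec_ctc
print(cfg["optimization"]["lr"])  # [3e-05]
# cfg["task"]["data"] and cfg["model"]["w2v_path"] are "???" (MISSING) and
# must be overridden, e.g. task.data=/path/to/manifests when launching
# fairseq-hydra-train with --config-name base_100h.
```
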
/examples/wav2vec/config/finetuning/base_10h.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | fp16: true
5 | log_format: json
6 | log_interval: 200
7 |
8 | checkpoint:
9 | save_interval: 50
10 | save_interval_updates: 10000
11 | keep_interval_updates: 1
12 | no_epoch_checkpoints: true
13 | best_checkpoint_metric: wer
14 |
15 | task:
16 | _name: audio_finetuning
17 | data: ???
18 | normalize: false
19 | labels: ltr
20 |
21 | dataset:
22 | num_workers: 6
23 | max_tokens: 3200000
24 | skip_invalid_size_inputs_valid_test: true
25 | validate_after_updates: 10000
26 | validate_interval: 50
27 | valid_subset: dev_other
28 |
29 | distributed_training:
30 | ddp_backend: legacy_ddp
31 | distributed_world_size: 2
32 |
33 | criterion:
34 | _name: ctc
35 | zero_infinity: true
36 |
37 | optimization:
38 | max_update: 20000
39 | lr: [0.00005]
40 | sentence_avg: true
41 | update_freq: [4]
42 |
43 | optimizer:
44 | _name: adam
45 | adam_betas: (0.9,0.98)
46 | adam_eps: 1e-08
47 |
48 | lr_scheduler:
49 | _name: tri_stage
50 | phase_ratio: [0.1, 0.4, 0.5]
51 | final_lr_scale: 0.05
52 |
53 | model:
54 | _name: wav2vec_ctc
55 | w2v_path: ???
56 | apply_mask: true
57 | mask_prob: 0.65
58 | mask_channel_prob: 0.5
59 | mask_channel_length: 64
60 | layerdrop: 0.05
61 | activation_dropout: 0.1
62 | feature_grad_mult: 0.0
63 | freeze_finetune_updates: 10000
64 |
--------------------------------------------------------------------------------
/examples/wav2vec/config/finetuning/base_10m.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | fp16: true
5 | log_format: json
6 | log_interval: 200
7 |
8 | checkpoint:
9 | save_interval: 1000
10 | save_interval_updates: 50
11 | keep_interval_updates: 1
12 | no_epoch_checkpoints: true
13 | best_checkpoint_metric: wer
14 |
15 | task:
16 | _name: audio_finetuning
17 | data: ???
18 | normalize: false
19 | labels: ltr
20 |
21 | dataset:
22 | num_workers: 6
23 | max_tokens: 3200000
24 | skip_invalid_size_inputs_valid_test: true
25 | validate_after_updates: 10000
26 | validate_interval: 1000
27 | valid_subset: dev_other
28 |
29 | distributed_training:
30 | ddp_backend: legacy_ddp
31 | distributed_world_size: 2
32 |
33 | criterion:
34 | _name: ctc
35 | zero_infinity: true
36 |
37 | optimization:
38 | max_update: 13000
39 | lr: [0.00005]
40 | sentence_avg: true
41 | update_freq: [4]
42 |
43 | optimizer:
44 | _name: adam
45 | adam_betas: (0.9,0.98)
46 | adam_eps: 1e-08
47 |
48 | lr_scheduler:
49 | _name: tri_stage
50 | phase_ratio: [0.1, 0.4, 0.5]
51 | final_lr_scale: 0.05
52 |
53 | model:
54 | _name: wav2vec_ctc
55 | w2v_path: ???
56 | apply_mask: true
57 | mask_prob: 0.65
58 | mask_channel_prob: 0.25
59 | mask_channel_length: 64
60 | layerdrop: 0.1
61 | activation_dropout: 0.1
62 | feature_grad_mult: 0.0
63 | freeze_finetune_updates: 10000
64 |
--------------------------------------------------------------------------------
/examples/wav2vec/config/finetuning/base_1h.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | fp16: true
5 | log_format: json
6 | log_interval: 200
7 |
8 | checkpoint:
9 | save_interval: 50
10 | save_interval_updates: 1000
11 | keep_interval_updates: 1
12 | no_epoch_checkpoints: true
13 | best_checkpoint_metric: wer
14 |
15 | task:
16 | _name: audio_finetuning
17 | data: ???
18 | normalize: false
19 | labels: ltr
20 |
21 | dataset:
22 | num_workers: 6
23 | max_tokens: 3200000
24 | skip_invalid_size_inputs_valid_test: true
25 | validate_after_updates: 10000
26 | validate_interval: 1000
27 | valid_subset: dev_other
28 |
29 | distributed_training:
30 | ddp_backend: legacy_ddp
31 | distributed_world_size: 2
32 |
33 | criterion:
34 | _name: ctc
35 | zero_infinity: true
36 |
37 | optimization:
38 | max_update: 13000
39 | lr: [0.00005]
40 | sentence_avg: true
41 | update_freq: [4]
42 |
43 | optimizer:
44 | _name: adam
45 | adam_betas: (0.9,0.98)
46 | adam_eps: 1e-08
47 |
48 | lr_scheduler:
49 | _name: tri_stage
50 | phase_ratio: [0.1, 0.4, 0.5]
51 | final_lr_scale: 0.05
52 |
53 | model:
54 | _name: wav2vec_ctc
55 | w2v_path: ???
56 | apply_mask: true
57 | mask_prob: 0.65
58 | mask_channel_prob: 0.25
59 | mask_channel_length: 64
60 | layerdrop: 0.1
61 | activation_dropout: 0.1
62 | feature_grad_mult: 0.0
63 | freeze_finetune_updates: 10000
64 |
--------------------------------------------------------------------------------
/examples/wav2vec/config/finetuning/base_960h.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | fp16: true
5 | log_format: json
6 | log_interval: 200
7 |
8 | checkpoint:
9 | no_epoch_checkpoints: true
10 | best_checkpoint_metric: wer
11 |
12 | task:
13 | _name: audio_finetuning
14 | data: ???
15 | normalize: false
16 | labels: ltr
17 |
18 | dataset:
19 | num_workers: 6
20 | max_tokens: 3200000
21 | skip_invalid_size_inputs_valid_test: true
22 | valid_subset: dev_other
23 |
24 | distributed_training:
25 | ddp_backend: legacy_ddp
26 | distributed_world_size: 8
27 |
28 | criterion:
29 | _name: ctc
30 | zero_infinity: true
31 |
32 | optimization:
33 | max_update: 320000
34 | lr: [0.0001]
35 | sentence_avg: true
36 |
37 | optimizer:
38 | _name: adam
39 | adam_betas: (0.9,0.98)
40 | adam_eps: 1e-08
41 |
42 | lr_scheduler:
43 | _name: tri_stage
44 | phase_ratio: [0.1, 0.4, 0.5]
45 | final_lr_scale: 0.05
46 |
47 | model:
48 | _name: wav2vec_ctc
49 | w2v_path: ???
50 | apply_mask: true
51 | mask_prob: 0.5
52 | mask_channel_prob: 0.1
53 | mask_channel_length: 64
54 | layerdrop: 0.1
55 | activation_dropout: 0.1
56 | feature_grad_mult: 0.0
57 | freeze_finetune_updates: 0
58 |
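For these raw-audio tasks max_tokens is measured in waveform samples, so (assuming the 16 kHz input rate wav2vec 2.0 expects) the configs are easiest to compare in seconds of audio per optimizer update; note that base_960h (8 GPUs, update_freq 1) and base_1h (2 GPUs, update_freq 4) land on the same effective batch:

SAMPLE_RATE = 16_000  # Hz, assumed input rate

def batch_seconds(max_tokens, world_size, update_freq=1):
    # seconds of audio contributing to a single optimizer update
    return max_tokens / SAMPLE_RATE * world_size * update_freq

print(batch_seconds(3_200_000, world_size=8))                 # base_960h: 1600.0
print(batch_seconds(3_200_000, world_size=2, update_freq=4))  # base_1h:   1600.0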
--------------------------------------------------------------------------------
/examples/wav2vec/config/finetuning/vox_100h.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | fp16: true
5 | log_format: json
6 | log_interval: 200
7 |
8 | checkpoint:
9 | no_epoch_checkpoints: true
10 | best_checkpoint_metric: wer
11 |
12 | task:
13 | _name: audio_finetuning
14 | data: ???
15 | normalize: true
16 | labels: ltr
17 |
18 | dataset:
19 | num_workers: 6
20 | max_tokens: 1280000
21 | skip_invalid_size_inputs_valid_test: true
22 | valid_subset: dev_other
23 |
24 | distributed_training:
25 | ddp_backend: legacy_ddp
26 | distributed_world_size: 4
27 |
28 | criterion:
29 | _name: ctc
30 | zero_infinity: true
31 |
32 | optimization:
33 | max_update: 80000
34 | lr: [0.00003]
35 | sentence_avg: true
36 | update_freq: [5]
37 |
38 | optimizer:
39 | _name: adam
40 | adam_betas: (0.9,0.98)
41 | adam_eps: 1e-08
42 |
43 | lr_scheduler:
44 | _name: tri_stage
45 | phase_ratio: [0.1, 0.4, 0.5]
46 | final_lr_scale: 0.05
47 |
48 | model:
49 | _name: wav2vec_ctc
50 | w2v_path: ???
51 | apply_mask: true
52 | mask_prob: 0.5
53 | mask_channel_prob: 0.5
54 | mask_channel_length: 64
55 | layerdrop: 0.1
56 | activation_dropout: 0.1
57 | feature_grad_mult: 0.0
58 | freeze_finetune_updates: 10000
59 |
--------------------------------------------------------------------------------
/examples/wav2vec/config/finetuning/vox_10h.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | fp16: true
5 | log_format: json
6 | log_interval: 200
7 |
8 | checkpoint:
9 | save_interval: 50
10 | save_interval_updates: 10000
11 | keep_interval_updates: 1
12 | no_epoch_checkpoints: true
13 | best_checkpoint_metric: wer
14 |
15 | task:
16 | _name: audio_finetuning
17 | data: ???
18 | normalize: true
19 | labels: ltr
20 |
21 | dataset:
22 | num_workers: 6
23 | max_tokens: 1280000
24 | skip_invalid_size_inputs_valid_test: true
25 | validate_after_updates: 10000
26 | validate_interval: 50
27 | valid_subset: dev_other
28 |
29 | distributed_training:
30 | ddp_backend: legacy_ddp
31 | distributed_world_size: 4
32 |
33 | criterion:
34 | _name: ctc
35 | zero_infinity: true
36 |
37 | optimization:
38 | max_update: 20000
39 | lr: [0.0001]
40 | sentence_avg: true
41 | update_freq: [5]
42 |
43 | optimizer:
44 | _name: adam
45 | adam_betas: (0.9,0.98)
46 | adam_eps: 1e-08
47 |
48 | lr_scheduler:
49 | _name: tri_stage
50 | phase_ratio: [0.1, 0.4, 0.5]
51 | final_lr_scale: 0.05
52 |
53 | model:
54 | _name: wav2vec_ctc
55 | w2v_path: ???
56 | apply_mask: true
57 | mask_prob: 0.75
58 | mask_channel_prob: 0.25
59 | mask_channel_length: 64
60 | layerdrop: 0.1
61 | activation_dropout: 0.1
62 | feature_grad_mult: 0.0
63 | freeze_finetune_updates: 10000
64 |
--------------------------------------------------------------------------------
/examples/wav2vec/config/finetuning/vox_10m.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | fp16: true
5 | log_format: json
6 | log_interval: 200
7 |
8 | checkpoint:
9 | save_interval: 1000
10 | save_interval_updates: 50
11 | keep_interval_updates: 1
12 | no_epoch_checkpoints: true
13 | best_checkpoint_metric: wer
14 |
15 | task:
16 | _name: audio_finetuning
17 | data: ???
18 | normalize: true
19 | labels: ltr
20 |
21 | dataset:
22 | num_workers: 6
23 | max_tokens: 1280000
24 | skip_invalid_size_inputs_valid_test: true
25 | validate_after_updates: 10000
26 | validate_interval: 1000
27 | valid_subset: dev_other
28 |
29 | distributed_training:
30 | ddp_backend: legacy_ddp
31 | distributed_world_size: 4
32 |
33 | criterion:
34 | _name: ctc
35 | zero_infinity: true
36 |
37 | optimization:
38 | max_update: 13000
39 | lr: [0.0001]
40 | sentence_avg: true
41 | update_freq: [5]
42 |
43 | optimizer:
44 | _name: adam
45 | adam_betas: (0.9,0.98)
46 | adam_eps: 1e-08
47 |
48 | lr_scheduler:
49 | _name: tri_stage
50 | phase_ratio: [0.1, 0.4, 0.5]
51 | final_lr_scale: 0.05
52 |
53 | model:
54 | _name: wav2vec_ctc
55 | w2v_path: ???
56 | apply_mask: true
57 | mask_prob: 0.65
58 | mask_channel_prob: 0.25
59 | mask_channel_length: 64
60 | layerdrop: 0.1
61 | activation_dropout: 0.1
62 | feature_grad_mult: 0.0
63 | freeze_finetune_updates: 10000
64 |
--------------------------------------------------------------------------------
/examples/wav2vec/config/finetuning/vox_1h.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | fp16: true
5 | log_format: json
6 | log_interval: 200
7 |
8 | checkpoint:
9 | save_interval: 1000
10 | save_interval_updates: 50
11 | keep_interval_updates: 1
12 | no_epoch_checkpoints: true
13 | best_checkpoint_metric: wer
14 |
15 | task:
16 | _name: audio_finetuning
17 | data: ???
18 | normalize: true
19 | labels: ltr
20 |
21 | dataset:
22 | num_workers: 6
23 | max_tokens: 1280000
24 | skip_invalid_size_inputs_valid_test: true
25 | validate_after_updates: 10000
26 | validate_interval: 1000
27 | valid_subset: dev_other
28 |
29 | distributed_training:
30 | ddp_backend: legacy_ddp
31 | distributed_world_size: 4
32 |
33 | criterion:
34 | _name: ctc
35 | zero_infinity: true
36 |
37 | optimization:
38 | max_update: 13000
39 | lr: [0.0003]
40 | sentence_avg: true
41 | update_freq: [5]
42 |
43 | optimizer:
44 | _name: adam
45 | adam_betas: (0.9,0.98)
46 | adam_eps: 1e-08
47 |
48 | lr_scheduler:
49 | _name: tri_stage
50 | phase_ratio: [0.1, 0.4, 0.5]
51 | final_lr_scale: 0.05
52 |
53 | model:
54 | _name: wav2vec_ctc
55 | w2v_path: ???
56 | apply_mask: true
57 | mask_prob: 0.75
58 | mask_channel_prob: 0.25
59 | mask_channel_length: 64
60 | layerdrop: 0.1
61 | activation_dropout: 0.1
62 | feature_grad_mult: 0.0
63 | freeze_finetune_updates: 10000
64 |
--------------------------------------------------------------------------------
/examples/wav2vec/config/finetuning/vox_960h.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | fp16: true
5 | log_format: json
6 | log_interval: 200
7 |
8 | checkpoint:
9 | no_epoch_checkpoints: true
10 | best_checkpoint_metric: wer
11 |
12 | task:
13 | _name: audio_finetuning
14 | data: ???
15 | normalize: true
16 | labels: ltr
17 |
18 | dataset:
19 | num_workers: 6
20 | max_tokens: 1280000
21 | skip_invalid_size_inputs_valid_test: true
22 | valid_subset: dev_other
23 |
24 | distributed_training:
25 | ddp_backend: legacy_ddp
26 | distributed_world_size: 24
27 |
28 | criterion:
29 | _name: ctc
30 | zero_infinity: true
31 |
32 | optimization:
33 | max_update: 320000
34 | lr: [0.00003]
35 | sentence_avg: true
36 |
37 | optimizer:
38 | _name: adam
39 | adam_betas: (0.9,0.98)
40 | adam_eps: 1e-08
41 |
42 | lr_scheduler:
43 | _name: tri_stage
44 | phase_ratio: [0.1, 0.4, 0.5]
45 | final_lr_scale: 0.05
46 |
47 | model:
48 | _name: wav2vec_ctc
49 | w2v_path: ???
50 | apply_mask: true
51 | mask_prob: 0.5
52 | mask_channel_prob: 0.25
53 | mask_channel_length: 64
54 | layerdrop: 0.1
55 | activation_dropout: 0.1
56 | feature_grad_mult: 0.0
57 | freeze_finetune_updates: 10000
58 |
--------------------------------------------------------------------------------
/examples/wav2vec/config/pretraining/wav2vec2_base_librispeech.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | fp16: true
5 | log_format: json
6 | log_interval: 200
7 |
8 | checkpoint:
9 | save_interval_updates: 25000
10 | keep_interval_updates: 1
11 | no_epoch_checkpoints: true
12 |
13 | task:
14 | _name: audio_pretraining
15 | data: ???
16 | max_sample_size: 250000
17 | min_sample_size: 32000
18 | normalize: false
19 |
20 | dataset:
21 | num_workers: 6
22 | max_tokens: 1400000
23 | skip_invalid_size_inputs_valid_test: true
24 |
25 | distributed_training:
26 | distributed_world_size: 64
27 | ddp_backend: legacy_ddp
28 |
29 | criterion:
30 | _name: wav2vec
31 | infonce: true
32 | log_keys: ["prob_perplexity","code_perplexity","temp"]
33 | loss_weights: [0.1, 10]
34 |
35 | optimization:
36 | max_update: 400000
37 | lr: [0.0005]
38 |
39 | optimizer:
40 | _name: adam
41 | adam_betas: (0.9,0.98)
42 | adam_eps: 1e-06
43 | weight_decay: 0.01
44 |
45 | lr_scheduler:
46 | _name: polynomial_decay
47 | warmup_updates: 32000
48 |
49 | model:
50 | _name: wav2vec2
51 | quantize_targets: true
52 | final_dim: 256
53 | encoder_layerdrop: 0.05
54 | dropout_input: 0.1
55 | dropout_features: 0.1
56 | feature_grad_mult: 0.1
57 | encoder_embed_dim: 768
58 |
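With infonce: true, the wav2vec criterion scores each masked timestep's prediction against the true quantized latent plus sampled distractors. A toy single-timestep version of that contrastive objective, for intuition only (plain PyTorch; the dimension, distractor count, and temperature here are illustrative, not the config's values):

import torch
import torch.nn.functional as F

def info_nce(pred, target, distractors, temperature=0.1):
    # candidates: true target at index 0, followed by K distractors
    candidates = torch.cat([target.unsqueeze(0), distractors], dim=0)
    logits = F.cosine_similarity(pred.unsqueeze(0), candidates, dim=-1) / temperature
    # the model must assign index 0 (the true target) the highest similarity
    return F.cross_entropy(logits.unsqueeze(0), torch.zeros(1, dtype=torch.long))

pred, target = torch.randn(256), torch.randn(256)
distractors = torch.randn(100, 256)
print(info_nce(pred, target, distractors).item())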
--------------------------------------------------------------------------------
/examples/wav2vec/config/pretraining/wav2vec2_conformer_base_librispeech.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | fp16: true
5 | log_format: json
6 | log_interval: 200
7 |
8 | checkpoint:
9 | save_interval_updates: 25000
10 | keep_interval_updates: 1
11 | no_epoch_checkpoints: true
12 |
13 | task:
14 | _name: audio_pretraining
15 | data: ???
16 | max_sample_size: 250000
17 | min_sample_size: 32000
18 | normalize: false
19 |
20 | dataset:
21 | num_workers: 6
22 | max_tokens: 1400000
23 | skip_invalid_size_inputs_valid_test: true
24 |
25 | distributed_training:
26 | distributed_world_size: 64
27 | ddp_backend: legacy_ddp
28 |
29 | criterion:
30 | _name: wav2vec
31 | infonce: true
32 | log_keys: ["prob_perplexity","code_perplexity","temp"]
33 | loss_weights: [0.1, 10]
34 |
35 | optimization:
36 | max_update: 400000
37 | lr: [0.0005]
38 |
39 | optimizer:
40 | _name: adam
41 | adam_betas: (0.9,0.98)
42 | adam_eps: 1e-06
43 | weight_decay: 0.01
44 |
45 | lr_scheduler:
46 | _name: polynomial_decay
47 | warmup_updates: 32000
48 |
49 | model:
50 | _name: wav2vec2
51 | quantize_targets: true
52 | final_dim: 256
53 | encoder_layerdrop: 0.05
54 | dropout_input: 0.1
55 | dropout_features: 0.1
56 | feature_grad_mult: 0.1
57 | encoder_embed_dim: 768
58 | layer_type: conformer
59 | attn_type: espnet
60 | pos_enc_type: rel_pos
61 |
--------------------------------------------------------------------------------
/examples/wav2vec/config/pretraining/wav2vec2_conformer_large_librivox.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | fp16: true
5 | log_format: json
6 | log_interval: 200
7 |
8 | checkpoint:
9 | save_interval_updates: 25000
10 | keep_interval_updates: 1
11 | no_epoch_checkpoints: true
12 |
13 | task:
14 | _name: audio_pretraining
15 | data: ???
16 | max_sample_size: 320000
17 | min_sample_size: 32000
18 | normalize: true
19 |
20 | dataset:
21 | num_workers: 6
22 | max_tokens: 1200000
23 | skip_invalid_size_inputs_valid_test: true
24 |
25 | distributed_training:
26 | distributed_world_size: 128
27 | ddp_backend: legacy_ddp
28 |
29 | criterion:
30 | _name: wav2vec
31 | infonce: true
32 | log_keys: ["prob_perplexity","code_perplexity","temp"]
33 | loss_weights: [0.1, 0]
34 |
35 | optimization:
36 | max_update: 1000000
37 | lr: [0.005]
38 |
39 | optimizer:
40 | _name: adam
41 | adam_betas: (0.9,0.98)
42 | adam_eps: 1e-06
43 | weight_decay: 0.01
44 |
45 | lr_scheduler:
46 | _name: polynomial_decay
47 | warmup_updates: 32000
48 |
49 | model:
50 | _name: wav2vec2
51 | quantize_targets: true
52 | extractor_mode: layer_norm
53 | layer_norm_first: true
54 | final_dim: 768
55 | latent_temp: [2.0,0.1,0.999995]
56 | encoder_layerdrop: 0.00
57 | dropout_input: 0.0
58 | dropout_features: 0.0
59 | dropout: 0.0
60 | attention_dropout: 0.0
61 | conv_bias: true
62 |
63 | encoder_layers: 24
64 | encoder_embed_dim: 1024
65 | encoder_ffn_embed_dim: 4096
66 | encoder_attention_heads: 16
67 |
68 | feature_grad_mult: 1.0
69 |
70 | layer_type: conformer
71 | attn_type: espnet
72 | pos_enc_type: rel_pos
73 |
--------------------------------------------------------------------------------
/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | fp16: true
5 | log_format: json
6 | log_interval: 200
7 |
8 | checkpoint:
9 | save_interval_updates: 25000
10 | keep_interval_updates: 1
11 | no_epoch_checkpoints: true
12 |
13 | task:
14 | _name: audio_pretraining
15 | data: ???
16 | max_sample_size: 320000
17 | min_sample_size: 32000
18 | normalize: true
19 |
20 | dataset:
21 | batch_size: 4
22 | num_workers: 6
23 | max_tokens: 1200000
24 | skip_invalid_size_inputs_valid_test: true
25 |
26 | distributed_training:
27 | distributed_world_size: 128
28 | ddp_backend: legacy_ddp
29 |
30 | criterion:
31 | _name: wav2vec
32 | infonce: true
33 | log_keys: ["prob_perplexity","code_perplexity","temp"]
34 | loss_weights: [0.1, 0]
35 |
36 | optimization:
37 | max_update: 1000000
38 | lr: [0.005]
39 |
40 | optimizer:
41 | _name: adam
42 | adam_betas: (0.9,0.98)
43 | adam_eps: 1e-06
44 | weight_decay: 0.01
45 |
46 | lr_scheduler:
47 | _name: polynomial_decay
48 | warmup_updates: 32000
49 |
50 | model:
51 | _name: wav2vec2
52 | quantize_targets: true
53 | extractor_mode: layer_norm
54 | layer_norm_first: true
55 | final_dim: 768
56 | latent_temp: [2.0,0.1,0.999995]
57 | encoder_layerdrop: 0.00
58 | dropout_input: 0.0
59 | dropout_features: 0.0
60 | dropout: 0.0
61 | attention_dropout: 0.0
62 | conv_bias: true
63 |
64 | encoder_layers: 24
65 | encoder_embed_dim: 1024
66 | encoder_ffn_embed_dim: 4096
67 | encoder_attention_heads: 16
68 |
69 | feature_grad_mult: 1.0
70 |
71 |
--------------------------------------------------------------------------------
/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu-pod.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | tpu: true
5 | fp16: false
6 | log_format: json
7 | log_interval: 10
8 |
9 | checkpoint:
10 | save_interval_updates: 25000
11 | keep_interval_updates: 1
12 | no_epoch_checkpoints: true
13 |
14 | task:
15 | _name: audio_pretraining
16 | data: ???
17 | max_sample_size: 250000
18 | min_sample_size: 32000
19 | normalize: true
20 | num_batch_buckets: 3
21 | precompute_mask_indices: true
22 | enable_padding: true
23 |
24 | dataset:
25 | num_workers: 6
26 | max_tokens: 1200000
27 | skip_invalid_size_inputs_valid_test: true
28 |
29 | distributed_training:
30 | distributed_world_size: 128
31 | ddp_backend: legacy_ddp
32 |
33 | criterion:
34 | _name: wav2vec
35 | infonce: true
36 | log_keys: ["prob_perplexity","code_perplexity","temp"]
37 | loss_weights: [0.1, 0]
38 |
39 | optimization:
40 | max_update: 1000000
41 | lr: [0.005]
42 |
43 | optimizer:
44 | _name: adam
45 | adam_betas: (0.9,0.98)
46 | adam_eps: 1e-06
47 | weight_decay: 0.01
48 |
49 | lr_scheduler:
50 | _name: polynomial_decay
51 | warmup_updates: 32000
52 |
53 | model:
54 | _name: wav2vec2
55 | quantize_targets: true
56 | extractor_mode: layer_norm
57 | layer_norm_first: true
58 | final_dim: 768
59 | latent_temp: [2.0,0.1,0.999995]
60 | encoder_layerdrop: 0.00
61 | dropout_input: 0.0
62 | dropout_features: 0.0
63 | dropout: 0.0
64 | attention_dropout: 0.0
65 | conv_bias: true
66 |
67 | encoder_layers: 24
68 | encoder_embed_dim: 1024
69 | encoder_ffn_embed_dim: 4096
70 | encoder_attention_heads: 16
71 |
72 | feature_grad_mult: 1.0
73 |
--------------------------------------------------------------------------------
/examples/wav2vec/scripts/binarize_manifest.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # usage: bash binarize_manifest <dest_dir> <train_split> <valid_split> <fairseq_root>
4 |
5 | DEST_DIR=$1
6 | TRAIN_SPLIT=$2
7 | VALID_SPLIT=$3
8 | FAIRSEQ_ROOT=$4
9 |
10 | mkdir -p $DEST_DIR
11 |
12 | # split file path and lengths into separate files
13 | cut -f1 $TRAIN_SPLIT.tsv > $DEST_DIR/train_fnames.txt
14 | cut -f1 $VALID_SPLIT.tsv > $DEST_DIR/valid_fnames.txt
15 | cut -f2 $TRAIN_SPLIT.tsv > $DEST_DIR/train.lengths
16 | cut -f2 $VALID_SPLIT.tsv > $DEST_DIR/valid.lengths
17 |
18 | # copy root directory
19 | head -1 $TRAIN_SPLIT.tsv > $DEST_DIR/train.root
20 | head -1 $VALID_SPLIT.tsv > $DEST_DIR/valid.root
21 |
22 | # remove root directory
23 | sed -i '1d' $DEST_DIR/train_fnames.txt
24 | sed -i '1d' $DEST_DIR/valid_fnames.txt
25 | sed -i '1d' $DEST_DIR/train.lengths
26 | sed -i '1d' $DEST_DIR/valid.lengths
27 |
28 | # insert spaces between characters
29 | sed -i -e 's/\(.\)/\1 /g' $DEST_DIR/train_fnames.txt
30 | sed -i -e 's/\(.\)/\1 /g' $DEST_DIR/valid_fnames.txt
31 |
32 | # run preprocessor
33 | PYTHONPATH=$FAIRSEQ_ROOT python $FAIRSEQ_ROOT/fairseq_cli/preprocess.py --dataset-impl mmap --trainpref $DEST_DIR/train_fnames.txt --validpref $DEST_DIR/valid_fnames.txt --workers 60 --only-source --destdir $DEST_DIR
34 |
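The {train,valid}.tsv files this script slices up are wav2vec manifests: the first line is the audio root directory, and every following line is a tab-separated relative path and length in samples, which is exactly what the cut/head/sed pipeline above pulls apart. A hypothetical example manifest (paths and lengths invented):

lines = [
    "/datasets/LibriSpeech/train-clean-100",  # root line (copied to *.root)
    "103/1240/103-1240-0000.flac\t225360",    # cut -f1 -> fname, cut -f2 -> length
    "103/1240/103-1240-0001.flac\t255120",
]
with open("train.tsv", "w") as f:
    f.write("\n".join(lines) + "\n")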
--------------------------------------------------------------------------------
/examples/wav2vec/unsupervised/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Speech-Lab-IITM/data2vec-aqc/c80cc996669e12c9c32201a9ff1515300e64b13a/examples/wav2vec/unsupervised/__init__.py
--------------------------------------------------------------------------------
/examples/wav2vec/unsupervised/config/finetuning/w2v_finetune.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | fp16: true
5 | log_format: json
6 | log_interval: 200
7 | tensorboard_logdir: tb
8 |
9 | checkpoint:
10 | no_epoch_checkpoints: true
11 | save_interval_updates: 20000
12 |
13 | task:
14 | _name: audio_finetuning
15 | data: ???
16 | normalize: true
17 | labels: ltr
18 |
19 | dataset:
20 | num_workers: 6
21 | max_tokens: 800000
22 | skip_invalid_size_inputs_valid_test: true
23 | train_subset: train
24 | valid_subset: valid
25 |
26 | distributed_training:
27 | ddp_backend: legacy_ddp
28 | distributed_world_size: 8
29 | find_unused_parameters: True
30 |
31 | criterion:
32 | _name: ctc
33 | zero_infinity: true
34 | post_process: letter
35 |
36 | optimization:
37 | max_update: 80000
38 | lr: [0.00003]
39 | sentence_avg: true
40 | update_freq: [1]
41 |
42 | optimizer:
43 | _name: adam
44 | adam_betas: (0.9,0.98)
45 | adam_eps: 1e-08
46 |
47 | lr_scheduler:
48 | _name: tri_stage
49 | phase_ratio: [0.1, 0.4, 0.5]
50 | final_lr_scale: 0.05
51 |
52 | model:
53 | _name: wav2vec_ctc
54 | w2v_path: ???
55 | apply_mask: true
56 | mask_prob: 0.25
57 | mask_channel_prob: 0.1
58 | mask_channel_length: 64
59 | layerdrop: 0.1
60 | activation_dropout: 0.1
61 | feature_grad_mult: 0.0
62 | freeze_finetune_updates: 0
63 |
--------------------------------------------------------------------------------
/examples/wav2vec/unsupervised/config/generate/viterbi.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | fairseq:
4 | task:
5 | _name: unpaired_audio_text
6 | labels: phn
7 | data: ???
8 | sort_by_length: false
9 | shuffle: false
10 | text_data: ''
11 |
12 | common_eval:
13 | path: ???
14 | quiet: true
15 |
16 | dataset:
17 | gen_subset: valid
18 | batch_size: 1
19 |
20 | w2l_decoder: VITERBI
21 | post_process: silence
22 |
--------------------------------------------------------------------------------
/examples/wav2vec/unsupervised/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .extracted_features_dataset import ExtractedFeaturesDataset
7 | from .random_input_dataset import RandomInputDataset
8 |
9 |
10 | __all__ = [
11 | "ExtractedFeaturesDataset",
12 | "RandomInputDataset",
13 | ]
14 |
--------------------------------------------------------------------------------
/examples/wav2vec/unsupervised/kaldi_self_train/st/cmd.sh:
--------------------------------------------------------------------------------
1 | # you can change cmd.sh depending on what type of queue you are using.
2 | # If you have no queueing system and want to run on a local machine, you
3 | # can change all instances of 'queue.pl' to 'run.pl' (but be careful and run
4 | # commands one by one: most recipes will exhaust the memory on your
5 | # machine). queue.pl works with GridEngine (qsub). slurm.pl works
6 | # with slurm. Different queues are configured differently, with different
7 | # queue names and different ways of specifying things like memory;
8 | # to account for these differences you can create and edit the file
9 | # conf/queue.conf to match your queue's configuration. Search for
10 | # conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information,
11 | # or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl.
12 |
13 | export train_cmd="run.pl --mem 2G"
14 | export decode_cmd="run.pl --mem 4G"
15 | export mkgraph_cmd="run.pl --mem 8G"
16 |
--------------------------------------------------------------------------------
/examples/wav2vec/unsupervised/kaldi_self_train/st/decode_phone.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # decode into phones (and prepare a new data directory for HMM outputs)
4 |
5 | . ./path.sh
6 |
7 | set -eu
8 |
9 | out_dir= # same as in train.sh
10 | dec_lmparam= # LM hyperparameters (e.g., 7.0.0)
11 | dec_exp=
12 | dec_script=
13 | dec_splits="train valid"
14 | dec_data_dir=$out_dir/dec_data # where to write HMM output
15 |
16 | data_dir=${out_dir}/data
17 |
18 | local/decode.sh --nj 40 --graph_name graph \
19 | --val_sets "$dec_splits" --decode_script $dec_script \
20 | $out_dir/exp/$dec_exp $data_dir $data_dir/lang_test
21 |
22 | if [ ! -z $dec_lmparam ]; then
23 | for x in $dec_splits; do
24 | mkdir -p $dec_data_dir/$x
25 | cp $data_dir/$x/{feats.scp,cmvn.scp,utt2spk,spk2utt} $dec_data_dir/$x/
26 |
27 | tra=$out_dir/exp/$dec_exp/decode_${x}/scoring/${dec_lmparam}.tra
28 | cat $tra | utils/int2sym.pl -f 2- $data_dir/lang/words.txt | \
29 | sed 's:<UNK>::g' | sed 's:<SIL>::g' > $dec_data_dir/${x}/text
30 | utils/fix_data_dir.sh $dec_data_dir/${x}
31 | echo "WER on ${x} is" $(compute-wer ark:$data_dir/${x}_gt/text ark:$dec_data_dir/$x/text | cut -d" " -f2-)
32 | done
33 | fi
34 |
--------------------------------------------------------------------------------
/examples/wav2vec/unsupervised/kaldi_self_train/st/decode_word_step2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # prepare a new data directory of HMM word output
4 |
5 | . ./path.sh
6 |
7 | set -eu
8 |
9 | out_dir= # same as in train.sh
10 | dec_lmparam= # LM hyperparameters (e.g., 7.0.0)
11 |
12 | dec_exp=tri3b # what HMM stage to decode (e.g., tri3b)
13 | dec_suffix=word
14 | dec_splits="train valid"
15 | dec_data_dir=$out_dir/dec_data_word # where to write HMM output
16 |
17 | data_dir=$out_dir/data
18 | wrd_data_dir=$out_dir/data_word
19 |
20 | for x in $dec_splits; do
21 | mkdir -p $dec_data_dir/$x
22 | cp $data_dir/$x/{feats.scp,cmvn.scp,utt2spk,spk2utt} $dec_data_dir/$x/
23 |
24 | tra=$out_dir/exp/$dec_exp/decode${dec_suffix}_${x}/scoring/${dec_lmparam}.tra
25 | cat $tra | utils/int2sym.pl -f 2- $data_dir/lang_word/words.txt | \
26 | sed 's:<UNK>::g' | sed 's:<SIL>::g' > $dec_data_dir/$x/text
27 | utils/fix_data_dir.sh $dec_data_dir/$x
28 | echo "WER on $x is" $(compute-wer ark:$wrd_data_dir/${x}_gt/text ark:$dec_data_dir/$x/text | cut -d" " -f2-)
29 | done
30 |
31 |
--------------------------------------------------------------------------------
/examples/wav2vec/unsupervised/kaldi_self_train/st/local/copy_aligned_text.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | for idx, line in enumerate(sys.stdin):
4 | print(f"utt{idx:010d} {line}", end='')
--------------------------------------------------------------------------------
/examples/wav2vec/unsupervised/kaldi_self_train/st/local/decode.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -u
4 |
5 | val_sets="dev_other"
6 | graph_name=graph
7 | decode_suffix=""
8 | decode_script="steps/decode_fmllr.sh"
9 | decode_args=""
10 | nj=60
11 |
12 | . ./cmd.sh
13 | . ./path.sh
14 | . parse_options.sh
15 |
16 | set -x
17 | exp_dir=$1
18 | data_root=$2
19 | lang_test=$3
20 |
21 | graph=$exp_dir/$graph_name
22 |
23 | if [ ! -d $graph ]; then
24 | utils/mkgraph.sh $lang_test $exp_dir $graph
25 | fi
26 |
27 | for part in $val_sets; do
28 | dec_dir=$exp_dir/decode${decode_suffix}_${part}
29 | if [ ! -d $dec_dir ]; then
30 | echo "decoding $part for $exp_dir"
31 | $decode_script --nj $nj --cmd "$decode_cmd" $decode_args \
32 | $graph $data_root/$part $dec_dir &
33 | else
34 | echo "$dec_dir exists. skip"
35 | fi
36 | done
37 |
38 | wait
39 |
--------------------------------------------------------------------------------
/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_lang.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | sil_prob=0.5
4 | num_sil_states=3
5 | num_nonsil_states=1
6 |
7 | . ./cmd.sh
8 | . ./path.sh
9 | . parse_options.sh
10 |
11 | set -eux
12 |
13 | dict=$1
14 | data_dir=$2
15 |
16 | dict_dir=$data_dir/local/dict
17 | tmplm_dir=$data_dir/local/lang_tmp
18 | lm_dir=$data_dir/lang
19 |
20 | mkdir -p $dict_dir $tmplm_dir $lm_dir
21 |
22 | # prepare dict
23 | echo "SIL" > $dict_dir/silence_phones.txt
24 | echo "SIL" > $dict_dir/optional_silence.txt
25 | awk '{print $1}' $dict > $dict_dir/nonsilence_phones.txt
26 |
27 | echo "SIL SIL" > $dict_dir/lexicon.txt
28 | echo "<SIL> SIL" >> $dict_dir/lexicon.txt
29 | awk '{print $1" "$1}' $dict >> $dict_dir/lexicon.txt
30 |
31 | echo "SIL" > $dict_dir/extra_questions.txt
32 | awk '{printf $1" "} END {printf "\n"}' $dict >> $dict_dir/extra_questions.txt
33 |
34 | # prepare lang
35 | utils/prepare_lang.sh --sil-prob $sil_prob --position-dependent-phones false \
36 | --num_sil_states $num_sil_states --num_nonsil_states $num_nonsil_states \
37 | $dict_dir "<SIL>" $tmplm_dir $lm_dir
38 |
--------------------------------------------------------------------------------
/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_lang_word.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | num_sil_states=3
4 | num_nonsil_states=1
5 |
6 | . ./cmd.sh
7 | . ./path.sh
8 | . parse_options.sh
9 |
10 | set -eux
11 |
12 | dict=$1
13 | data_dir=$2
14 | lexicon=$3
15 |
16 | dict_dir=$data_dir/local/dict_word
17 | tmplm_dir=$data_dir/local/lang_tmp_word
18 | lm_dir=$data_dir/lang_word
19 |
20 | mkdir -p $dict_dir $tmplm_dir $lm_dir
21 |
22 | # prepare dict
23 | echo "SIL" > $dict_dir/silence_phones.txt
24 | echo "SIL" > $dict_dir/optional_silence.txt
25 | awk '{print $1}' $dict > $dict_dir/nonsilence_phones.txt
26 |
27 | (echo "!SIL SIL"; echo "<UNK> SIL";) | cat - $lexicon > $dict_dir/lexicon.txt
28 |
29 | echo "SIL" > $dict_dir/extra_questions.txt
30 | awk '{printf $1" "} END {printf "\n"}' $dict >> $dict_dir/extra_questions.txt
31 |
32 | # prepare lang
33 | utils/prepare_lang.sh --position-dependent-phones false \
34 | --num_sil_states $num_sil_states --num_nonsil_states $num_nonsil_states \
35 | $dict_dir "<UNK>" $tmplm_dir $lm_dir
36 |
--------------------------------------------------------------------------------
/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_lm.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | langdir=""
4 | lmdir=""
5 |
6 | . ./cmd.sh
7 | . ./path.sh
8 | . parse_options.sh
9 |
10 | arpa_lm=$1
11 | data=$2
12 |
13 | if [ -z $langdir ]; then
14 | langdir=$data/lang
15 | fi
16 | if [ -z $lmdir ]; then
17 | lmdir=$data/lang_test
18 | fi
19 |
20 | if [ ! -d $langdir ]; then
21 | echo "$langdir not found. run local/prepare_lang.sh first" && exit 1
22 | fi
23 |
24 | mkdir -p $lmdir
25 | cp -r $langdir/* $lmdir
26 |
27 | if [[ "$arpa_lm" == *.gz ]]; then
28 | gunzip -c $arpa_lm | arpa2fst --disambig-symbol=#0 --read-symbol-table=$lmdir/words.txt - $lmdir/G.fst
29 | else
30 | arpa2fst --disambig-symbol=#0 --read-symbol-table=$lmdir/words.txt $arpa_lm $lmdir/G.fst
31 | fi
32 | fstisstochastic $lmdir/G.fst
33 | utils/validate_lang.pl $lmdir || exit 1
34 |
35 | echo "done preparing lm ($lmdir)"
36 |
--------------------------------------------------------------------------------
/examples/wav2vec/unsupervised/kaldi_self_train/st/local/show_wer.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | split="dev_other"
4 | ref_data=""
5 | get_best_wer=true
6 | dec_name="decode"
7 | graph_name="graph"
8 |
9 | . ./cmd.sh
10 | . ./path.sh
11 | . parse_options.sh
12 |
13 | exp_root=$1
14 |
15 | set -eu
16 |
17 | echo "==== WER w.r.t. pseudo transcript"
18 | for x in $exp_root/*/${dec_name}_${split}*; do grep WER $x/wer_* 2>/dev/null | utils/best_wer.sh; done
19 |
20 |
21 | if [ ! -z $ref_data ]; then
22 | echo "==== WER w.r.t. real transcript (select based on pseudo WER)"
23 | ref_txt=$ref_data/$split/text
24 | for x in $exp_root/*/${dec_name}_${split}*; do
25 | lang=$(dirname $x)/$graph_name
26 |
27 | lmwt=$(
28 | grep WER $x/wer_* 2>/dev/null | utils/best_wer.sh |
29 | sed 's/.*wer_\(.*\)$/\1/g' | sed 's/_/./g'
30 | )
31 | tra=$x/scoring/$lmwt.tra
32 | cat $tra | utils/int2sym.pl -f 2- $lang/words.txt | sed 's:<UNK>::g' | sed 's:<SIL>::g' | \
33 | compute-wer --text --mode=present \
34 | ark:$ref_txt ark,p:- 2> /dev/null | grep WER | xargs -I{} echo {} $tra
35 | done
36 | fi
37 |
38 | if [ ! -z $ref_data ] && $get_best_wer; then
39 | echo "==== WER w.r.t. real transcript (select based on true WER)"
40 | ref_txt=$ref_data/$split/text
41 | for x in $exp_root/*/${dec_name}_${split}*; do
42 | lang=$(dirname $x)/$graph_name
43 |
44 | for tra in $x/scoring/*.tra; do
45 | cat $tra | utils/int2sym.pl -f 2- $lang/words.txt | sed 's:<UNK>::g' | sed 's:<SIL>::g' | \
46 | compute-wer --text --mode=present \
47 | ark:$ref_txt ark,p:- 2> /dev/null | grep WER | xargs -I{} echo {} $tra
48 | done | sort -k2n | head -n1
49 | done
50 | fi
51 |
52 | exit 0;
53 |
--------------------------------------------------------------------------------
/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select_decode.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | split="dev_other"
4 | ref_txt="" # ground truth transcript path
5 | psd_txt="" # pseudo transcript path
6 | get_best_wer=true
7 | dec_name="decode"
8 | graph_name="graph"
9 | kenlm_path=/checkpoint/abaevski/data/speech/libri/librispeech_lm_novox.phnc_o6.bin
10 |
11 | . ./cmd.sh
12 | . ./path.sh
13 | . parse_options.sh
14 |
15 | exp_root=$1
16 | unsup_args=""
17 | if [ $# -ge 2 ]; then
18 | unsup_args=$2
19 | fi
20 |
21 | set -eu
22 |
23 | if [ ! -z $ref_txt ] && $get_best_wer; then
24 | echo "==== WER w.r.t. real transcript (select based on unsupervised metric)"
25 | for x in $exp_root/*/${dec_name}_${split}*; do
26 | lang=$(dirname $x)/$graph_name
27 |
28 | (
29 | for tra in $x/scoring/*.tra; do
30 | cat $tra | utils/int2sym.pl -f 2- $lang/words.txt | sed 's:<UNK>::g' | sed 's:<SIL>::g' > $tra.txt
31 | python local/unsup_select.py $psd_txt $tra.txt --kenlm_path $kenlm_path --gt_tra $ref_txt $unsup_args
32 | done 2>/dev/null | grep "score=" | sed 's/=/ /g' | sed 's/;//g' | sort -k3n | head -n1
33 | ) &
34 | done
35 | fi
36 | wait
37 |
38 |
--------------------------------------------------------------------------------
/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select_decode_word.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | split="dev_other"
4 | ref_txt="" # ground truth transcript path
5 | psd_txt="" # pseudo transcript path
6 | get_best_wer=true
7 | dec_name="decode"
8 | graph_name="graph"
9 | kenlm_path=/checkpoint/abaevski/data/speech/libri/librispeech_lm_novox.phnc_o6.bin
10 | phonemize_lexicon=""
11 |
12 | . ./cmd.sh
13 | . ./path.sh
14 | . parse_options.sh
15 | . /private/home/wnhsu/unsup_asr/fairseq-py-unsup/env.sh
16 |
17 | exp_root=$1
18 |
19 | set -eu
20 |
21 | if [ ! -z $ref_txt ] && $get_best_wer; then
22 | echo "==== WER w.r.t. real transcript (select based on unsupervised metric)"
23 | for x in $exp_root/*/${dec_name}_${split}*; do
24 | lang=$(dirname $x)/$graph_name
25 |
26 | for tra in $x/scoring/*.tra; do
27 | cat $tra | utils/int2sym.pl -f 2- $lang/words.txt | sed 's:<UNK>::g' > $tra.txt
28 | python local/unsup_select.py $psd_txt $tra.txt \
29 | --kenlm_path $kenlm_path --gt_tra $ref_txt --phonemize \
30 | --phonemize_lexicon "$phonemize_lexicon"
31 | done | grep "score=" | sed 's/=/ /g' | sed 's/;//g' | sort -k3n | head -n1
32 | done
33 | fi
34 |
35 |
36 |
--------------------------------------------------------------------------------
/examples/wav2vec/unsupervised/kaldi_self_train/st/path.sh:
--------------------------------------------------------------------------------
1 | export KALDI_ROOT=`pwd`/../../..
2 | export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH
3 | [ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1
4 | . $KALDI_ROOT/tools/config/common_path.sh
5 | export LC_ALL=C
6 |
--------------------------------------------------------------------------------
/examples/wav2vec/unsupervised/kaldi_self_train/st/steps:
--------------------------------------------------------------------------------
1 | ../../wsj/s5/steps
--------------------------------------------------------------------------------
/examples/wav2vec/unsupervised/kaldi_self_train/st/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -eu
4 |
5 | w2v_dir= # contains features `{train,valid}.{npy,lengths}`, real transcripts `{train,valid}.${label}`, and dict `dict.${label}.txt`
6 | lab_dir= # contains pseudo labels `{train,valid}.txt`
7 | out_dir= # output root
8 | arpa_lm= # phone LM
9 | arpa_lm_bin= # (binary) phone LM for KenLM, used in unsupervised selection
10 |
11 | label=phnc
12 | train_name="train"
13 | valid_name="valid"
14 | data_dir=${out_dir}/data
15 |
16 | mkdir -p ${out_dir}/exp
17 | local/prepare_lang.sh $w2v_dir/dict.${label}.txt $data_dir
18 | local/prepare_lm.sh $arpa_lm $data_dir
19 |
20 | for x in $train_name $valid_name; do
21 | x_gt=${x}_gt
22 |
23 | # prepare pseudo data
24 | python local/prepare_data_from_w2v.py $w2v_dir $data_dir $x
25 | steps/compute_cmvn_stats.sh $data_dir/$x $out_dir/exp/make_feat/$x $out_dir/feats/$x
26 | python local/copy_aligned_text.py < $lab_dir/$x.txt > $data_dir/$x/text
27 |
28 | # prepare ground truth data
29 | mkdir $data_dir/$x_gt
30 | cp $data_dir/$x/{feats.scp,cmvn.scp,utt2spk,spk2utt} $data_dir/$x_gt/
31 | python local/copy_aligned_text.py < $w2v_dir/$x.$label > $data_dir/$x_gt/text
32 | done
33 |
34 | local/train_subset_lgbeam.sh \
35 | --out_root ${out_dir} --out_name exp --train $train_name --valid $valid_name \
36 | --mono_size 2000 --tri1_size 5000 --tri2b_size -1 --tri3b_size -1 \
37 | --stage 1 --max_stage 3 $data_dir $data_dir/lang $data_dir/lang_test
38 |
39 | local/unsup_select_decode.sh \
40 | --split $valid_name --kenlm_path $arpa_lm_bin \
41 | --ref_txt $data_dir/${valid_name}_gt/text \
42 | --psd_txt $data_dir/${valid_name}/text \
43 | $out_dir/exp
44 |
--------------------------------------------------------------------------------
/examples/wav2vec/unsupervised/kaldi_self_train/st/utils:
--------------------------------------------------------------------------------
1 | ../../wsj/s5/utils
--------------------------------------------------------------------------------
/examples/wav2vec/unsupervised/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .wav2vec_u import Wav2vec_U
7 |
8 |
9 | __all__ = [
10 | "Wav2vec_U",
11 | ]
12 |
--------------------------------------------------------------------------------
/examples/wav2vec/unsupervised/scripts/copy_labels.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -u
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import sys
8 |
9 | for idx, line in enumerate(sys.stdin):
10 | print(f"utt{idx:010d} {line}", end="")
11 |
--------------------------------------------------------------------------------
/examples/wav2vec/unsupervised/scripts/filter_lexicon.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -u
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 | import sys
9 |
10 | from fairseq.data import Dictionary
11 |
12 |
13 | def get_parser():
14 | parser = argparse.ArgumentParser(
15 | description="filters a lexicon given a unit dictionary"
16 | )
17 | parser.add_argument("-d", "--unit-dict", help="unit dictionary", required=True)
18 | return parser
19 |
20 |
21 | def main():
22 | parser = get_parser()
23 | args = parser.parse_args()
24 |
25 | d = Dictionary.load(args.unit_dict)
26 | symbols = set(d.symbols)
27 |
28 | for line in sys.stdin:
29 | items = line.rstrip().split()
30 | skip = len(items) < 2
31 | for x in items[1:]:
32 | if x not in symbols:
33 | skip = True
34 | break
35 | if not skip:
36 | print(line, end="")
37 |
38 |
39 | if __name__ == "__main__":
40 | main()
41 |
--------------------------------------------------------------------------------
/examples/wav2vec/unsupervised/scripts/filter_tsv.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -u
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 | import argparse
9 | import sys
10 |
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument("--tsv", required=True, type=str)
14 | parser.add_argument("--no-skip", action="store_true")
15 | parser.add_argument("--keep", action="store_true")
16 | params = parser.parse_args()
17 |
18 |
19 | def get_fname(line):
20 | p = os.path.basename(line.split("\t")[0])
21 | p = os.path.splitext(p)[0]
22 | return p
23 |
24 |
25 | # filenames to exclude
26 | seen = set()
27 | with open(params.tsv) as f:
28 | if not params.no_skip:
29 | root = next(f).rstrip()
30 | for line in f:
31 | seen.add(get_fname(line))
32 |
33 | for i, line in enumerate(sys.stdin):
34 | exists = get_fname(line) in seen
35 | keep = (exists and params.keep) or (not exists and not params.keep)
36 | if i == 0 or keep:
37 | print(line, end="")
38 |
--------------------------------------------------------------------------------
/examples/wav2vec/unsupervised/scripts/g2p_wrd_to_phn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -u
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 | import sys
9 |
10 | from g2p_en import G2p
11 |
12 |
13 | def main():
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument(
16 | "--compact",
17 | action="store_true",
18 | help="if set, compacts phones",
19 | )
20 | args = parser.parse_args()
21 |
22 | compact = args.compact
23 |
24 | wrd_to_phn = {}
25 | g2p = G2p()
26 | for line in sys.stdin:
27 | words = line.strip().split()
28 | phones = []
29 | for w in words:
30 | if w not in wrd_to_phn:
31 | wrd_to_phn[w] = g2p(w)
32 | if compact:
33 | wrd_to_phn[w] = [
34 | p[:-1] if p[-1].isnumeric() else p for p in wrd_to_phn[w]
35 | ]
36 | phones.extend(wrd_to_phn[w])
37 | try:
38 | print(" ".join(phones))
39 | except Exception:
40 | print(wrd_to_phn, words, phones, file=sys.stderr)
41 | raise
42 |
43 |
44 | if __name__ == "__main__":
45 | main()
46 |
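The --compact option strips the lexical stress digits g2p_en appends to ARPAbet vowels, so AH0/AH1/AH2 all collapse to AH. The transformation in isolation:

phones = ["HH", "AH0", "L", "OW1"]  # g2p_en output for "hello" (with stress markers)
compact = [p[:-1] if p[-1].isnumeric() else p for p in phones]
print(compact)  # ['HH', 'AH', 'L', 'OW']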
--------------------------------------------------------------------------------
/examples/wav2vec/unsupervised/scripts/ltr_to_wrd.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -u
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import sys
8 |
9 |
10 | def main():
11 | for line in sys.stdin:
12 | print(line.replace(" ", "").replace("|", " ").strip())
13 |
14 |
15 | if __name__ == "__main__":
16 | main()
17 |
--------------------------------------------------------------------------------
/examples/wav2vec/unsupervised/scripts/normalize_text.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import regex
8 | import sys
9 |
10 |
11 | def main():
12 | filter_r = regex.compile(r"[^\p{L}\p{N}\p{M}\' \-]")
13 |
14 | for line in sys.stdin:
15 | line = line.strip()
16 | line = filter_r.sub(" ", line)
17 | line = " ".join(line.split())
18 | print(line)
19 |
20 |
21 | if __name__ == "__main__":
22 | main()
23 |
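The character class keeps letters (\p{L}), digits (\p{N}), combining marks (\p{M}), apostrophes, hyphens, and spaces; everything else becomes a space, and whitespace runs are then collapsed. For example:

import regex

filter_r = regex.compile(r"[^\p{L}\p{N}\p{M}\' \-]")
line = "He said: “well, okay!”"
print(" ".join(filter_r.sub(" ", line).split()))  # -> He said well okay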
--------------------------------------------------------------------------------
/examples/wav2vec/unsupervised/scripts/pca.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -u
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 | import os
9 | import os.path as osp
10 | import numpy as np
11 |
12 | import faiss
13 |
14 |
15 |
16 | def get_parser():
17 | parser = argparse.ArgumentParser(
18 | description="compute a pca matrix given an array of numpy features"
19 | )
20 | # fmt: off
21 | parser.add_argument('data', help='numpy file containing features')
22 | parser.add_argument('--output', help='where to save the pca matrix', required=True)
23 | parser.add_argument('--dim', type=int, help='dim for pca reduction', required=True)
24 | parser.add_argument('--eigen-power', type=float, default=0, help='eigen power, -0.5 for whitening')
25 |
26 | return parser
27 |
28 |
29 | def main():
30 | parser = get_parser()
31 | args = parser.parse_args()
32 |
33 | print("Reading features")
34 | x = np.load(args.data, mmap_mode="r")
35 |
36 | print("Computing PCA")
37 | pca = faiss.PCAMatrix(x.shape[-1], args.dim, args.eigen_power)
38 | pca.train(x)
39 | b = faiss.vector_to_array(pca.b)
40 | A = faiss.vector_to_array(pca.A).reshape(pca.d_out, pca.d_in)
41 |
42 | os.makedirs(args.output, exist_ok=True)
43 |
44 | prefix = str(args.dim)
45 | if args.eigen_power != 0:
46 | prefix += f"_{args.eigen_power}"
47 |
48 | np.save(osp.join(args.output, f"{prefix}_pca_A"), A.T)
49 | np.save(osp.join(args.output, f"{prefix}_pca_b"), b)
50 |
51 |
52 | if __name__ == "__main__":
53 | main()
54 |
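Since the script stores A transposed, applying the learned projection downstream is a single matmul plus bias, matching faiss's y = Ax + b convention. A sketch of the consumer side (file names assume --dim 512 and the default eigen power; all paths are illustrative):

import numpy as np

x = np.load("features.npy", mmap_mode="r")  # (n, d_in) input features
A_T = np.load("pca/512_pca_A.npy")          # (d_in, d_out), saved as A.T above
b = np.load("pca/512_pca_b.npy")            # (d_out,) bias term

y = x @ A_T + b                             # projected features, (n, d_out)
np.save("features_pca.npy", y)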
--------------------------------------------------------------------------------
/examples/wav2vec/unsupervised/scripts/wrd_to_ltr.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -u
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import sys
8 |
9 |
10 | def main():
11 | for line in sys.stdin:
12 | print(" ".join(list(line.strip().replace(" ", "|"))) + " |")
13 |
14 |
15 | if __name__ == "__main__":
16 | main()
17 |
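Together with ltr_to_wrd.py above, this pins down the letter ("ltr") target format the finetuning configs use (labels: ltr): space-separated characters with "|" as the word delimiter. A round-trip check using the same two transformations:

def wrd_to_ltr(line):
    return " ".join(list(line.strip().replace(" ", "|"))) + " |"

def ltr_to_wrd(line):
    return line.replace(" ", "").replace("|", " ").strip()

s = "HELLO WORLD"
ltr = wrd_to_ltr(s)
print(ltr)                  # H E L L O | W O R L D |
assert ltr_to_wrd(ltr) == s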
--------------------------------------------------------------------------------
/examples/wav2vec/unsupervised/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .unpaired_audio_text import UnpairedAudioText
7 |
8 |
9 | __all__ = [
10 | "UnpairedAudioText",
11 | ]
12 |
--------------------------------------------------------------------------------
/examples/wav2vec/xlsr/config/finetune.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | common:
4 | fp16: true
5 | log_format: json
6 | log_interval: 200
7 | tensorboard_logdir: tb
8 |
9 | checkpoint:
10 | save_interval: 1000
11 | save_interval_updates: 1000
12 | keep_interval_updates: 1
13 | no_epoch_checkpoints: true
14 | best_checkpoint_metric: wer
15 |
16 | task:
17 | _name: audio_finetuning
18 | data: ???
19 | normalize: true
20 | labels: ltr
21 |
22 | dataset:
23 | num_workers: 6
24 | max_tokens: 1280000
25 | skip_invalid_size_inputs_valid_test: true
26 | validate_after_updates: 10000
27 | validate_interval_updates: 1000
28 | valid_subset: valid
29 |
30 | distributed_training:
31 | ddp_backend: legacy_ddp
32 | distributed_world_size: 4
33 |
34 | criterion:
35 | _name: ctc
36 | zero_infinity: true
37 |
38 | optimization:
39 | max_update: ???
40 | lr: [0.0003]
41 | sentence_avg: true
42 | update_freq: [5]
43 |
44 | optimizer:
45 | _name: adam
46 | adam_betas: (0.9,0.98)
47 | adam_eps: 1e-08
48 |
49 | lr_scheduler:
50 | _name: tri_stage
51 | phase_ratio: [0.1, 0.4, 0.5]
52 | final_lr_scale: 0.05
53 |
54 | model:
55 | _name: wav2vec_ctc
56 | w2v_path: ???
57 | apply_mask: true
58 | mask_prob: 0.75
59 | mask_channel_prob: 0.25
60 | mask_channel_length: 64
61 | layerdrop: 0.1
62 | activation_dropout: 0.1
63 | feature_grad_mult: 0.0
64 | freeze_finetune_updates: 10000
65 |
66 | checkpoint_activations: false
67 |
--------------------------------------------------------------------------------
/fairseq/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | """isort:skip_file"""
6 |
7 | import os
8 | import sys
9 |
10 | try:
11 | from .version import __version__ # noqa
12 | except ImportError:
13 | version_txt = os.path.join(os.path.dirname(__file__), "version.txt")
14 | with open(version_txt) as f:
15 | __version__ = f.read().strip()
16 |
17 | __all__ = ["pdb"]
18 |
19 | # backwards compatibility to support `from fairseq.X import Y`
20 | from fairseq.distributed import utils as distributed_utils
21 | from fairseq.logging import meters, metrics, progress_bar # noqa
22 |
23 | sys.modules["fairseq.distributed_utils"] = distributed_utils
24 | sys.modules["fairseq.meters"] = meters
25 | sys.modules["fairseq.metrics"] = metrics
26 | sys.modules["fairseq.progress_bar"] = progress_bar
27 |
28 | # initialize hydra
29 | from fairseq.dataclass.initialize import hydra_init
30 |
31 | hydra_init()
32 |
33 | import fairseq.criterions # noqa
34 | import fairseq.distributed # noqa
35 | import fairseq.models # noqa
36 | import fairseq.modules # noqa
37 | import fairseq.optim # noqa
38 | import fairseq.optim.lr_scheduler # noqa
39 | import fairseq.pdb # noqa
40 | import fairseq.scoring # noqa
41 | import fairseq.tasks # noqa
42 | import fairseq.token_generation_constraints # noqa
43 |
44 | import fairseq.benchmark # noqa
45 | import fairseq.model_parallel # noqa
46 |
--------------------------------------------------------------------------------
/fairseq/benchmark/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # import models/tasks to register them
7 | from . import dummy_dataset, dummy_lm, dummy_masked_lm, dummy_model, dummy_mt # noqa
8 |
--------------------------------------------------------------------------------
/fairseq/benchmark/dummy_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from fairseq.data import FairseqDataset
3 |
4 |
5 | class DummyDataset(FairseqDataset):
6 | def __init__(self, batch, num_items, item_size):
7 | super().__init__()
8 | self.batch = batch
9 | self.num_items = num_items
10 | self.item_size = item_size
11 |
12 | def __getitem__(self, index):
13 | return index
14 |
15 | def __len__(self):
16 | return self.num_items
17 |
18 | def collater(self, samples):
19 | return self.batch
20 |
21 | @property
22 | def sizes(self):
23 | return np.array([self.item_size] * self.num_items)
24 |
25 | def num_tokens(self, index):
26 | return self.item_size
27 |
28 | def size(self, index):
29 | return self.item_size
30 |
31 | def ordered_indices(self):
32 | return np.arange(self.num_items)
33 |
34 | @property
35 | def supports_prefetch(self):
36 | return False
37 |
--------------------------------------------------------------------------------
/fairseq/clib/cuda/ngram_repeat_block_cuda.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright (c) Microsoft Corporation.
3 | Licensed under the MIT License.
4 | */
5 |
6 | #include <torch/extension.h>
7 | #include <vector>
8 |
9 | /*
10 | CPP Binding for CUDA OP
11 | */
12 |
13 | // CUDA forward declarations
14 | torch::Tensor ngram_repeat_block_cuda_forward(
15 | torch::Tensor tokens,
16 | torch::Tensor lprobs,
17 | int bsz,
18 | int step,
19 | int beam_size,
20 | int no_repeat_ngram_size);
21 |
22 | #define CHECK_CUDA(x) \
23 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
24 | #define CHECK_CONTIGUOUS(x) \
25 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
26 | #define CHECK_INPUT(x) \
27 | CHECK_CUDA(x); \
28 | CHECK_CONTIGUOUS(x)
29 |
30 | // Input check and call to CUDA OP
31 | // Backward method not required
32 | torch::Tensor ngram_repeat_block_forward(
33 | torch::Tensor tokens,
34 | torch::Tensor lprobs,
35 | int bsz,
36 | int step,
37 | int beam_size,
38 | int no_repeat_ngram_size) {
39 | CHECK_INPUT(tokens);
40 | CHECK_INPUT(lprobs);
41 | assert(bsz > 0);
42 | assert(step >= 0);
43 | assert(beam_size > 0);
44 | assert(no_repeat_ngram_size > 0);
45 |
46 | return ngram_repeat_block_cuda_forward(
47 | tokens, lprobs, bsz, step, beam_size, no_repeat_ngram_size);
48 | }
49 |
50 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
51 | m.def(
52 | "forward",
53 | &ngram_repeat_block_forward,
54 | "No Repeat Ngram Block forward (CUDA)");
55 | }
56 |
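For intuition, the banning rule the kernel applies per hypothesis: a token is banned if appending it would repeat an n-gram already present in the sequence. A plain-Python reference for a single hypothesis (a sketch, not the batched CUDA logic):

def block_repeat_ngrams(tokens, lprobs, no_repeat_ngram_size):
    n = no_repeat_ngram_size
    if len(tokens) < n - 1:
        return lprobs
    # current (n-1)-gram suffix of the hypothesis
    prefix = tuple(tokens[-(n - 1):]) if n > 1 else tuple()
    # ban every token that previously followed this suffix
    banned = {
        tokens[i + n - 1]
        for i in range(len(tokens) - n + 1)
        if tuple(tokens[i:i + n - 1]) == prefix
    }
    for t in banned:
        lprobs[t] = float("-inf")
    return lprobs

print(block_repeat_ngrams([5, 3, 9, 5, 3], [0.0] * 10, 3))  # bans token 9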
--------------------------------------------------------------------------------
/fairseq/clib/libbleu/module.cpp:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright 2017-present, Facebook, Inc.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #include <Python.h>
10 |
11 | static PyMethodDef method_def[] = {{NULL, NULL, 0, NULL}}; // NOLINT
12 |
13 | static struct PyModuleDef module_def = {
14 | PyModuleDef_HEAD_INIT,
15 | "libbleu", /* name of module */
16 | // NOLINTNEXTLINE
17 | NULL, /* module documentation, may be NULL */
18 | -1, /* size of per-interpreter state of the module,
19 | or -1 if the module keeps state in global variables. */
20 | method_def}; // NOLINT
21 |
22 | #if PY_MAJOR_VERSION == 2
23 | PyMODINIT_FUNC init_libbleu()
24 | #else
25 | PyMODINIT_FUNC PyInit_libbleu()
26 | #endif
27 | {
28 | PyObject* m = PyModule_Create(&module_def);
29 | if (!m) {
30 | return NULL;
31 | }
32 | return m;
33 | }
34 |
--------------------------------------------------------------------------------
/fairseq/clib/libnat_cuda/edit_dist.h:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright 2017-present, Facebook, Inc.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #pragma once
10 |
11 | #include <torch/extension.h>
12 |
13 | torch::Tensor LevenshteinDistanceCuda(
14 | torch::Tensor source,
15 | torch::Tensor target,
16 | torch::Tensor source_length,
17 | torch::Tensor target_length);
18 |
19 | torch::Tensor GenerateDeletionLabelCuda(
20 | torch::Tensor source,
21 | torch::Tensor operations);
22 |
23 | std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabelCuda(
24 | torch::Tensor source,
25 | torch::Tensor operations);
26 |
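These declarations batch the textbook edit-distance dynamic program over padded sentence pairs; per pair, the recurrence is the standard one. A pure-Python reference (a sketch for intuition, not the CUDA implementation):

def levenshtein(a, b):
    # prev[j] holds the distance between a[:i-1] and b[:j]
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, 1):
        cur = [i]
        for j, y in enumerate(b, 1):
            cur.append(min(prev[j] + 1,          # deletion
                           cur[j - 1] + 1,       # insertion
                           prev[j - 1] + (x != y)))  # substitution / match
        prev = cur
    return prev[-1]

print(levenshtein("kitten", "sitting"))  # 3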
--------------------------------------------------------------------------------
/fairseq/config/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
--------------------------------------------------------------------------------
/fairseq/config/config.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | hydra:
4 | run:
5 | dir: .
6 |
7 | defaults:
8 | - _self_
9 | - task: null
10 | - model: null
11 | - criterion: cross_entropy
12 | - optimizer: null
13 | - lr_scheduler: fixed
14 | - bpe: null
15 | - tokenizer: null
16 | - scoring: null
17 | - generation: null
18 | - common_eval: null
19 | - eval_lm: null
20 |
--------------------------------------------------------------------------------
/fairseq/config/model/transformer_lm/transformer_lm_baevski_gbw.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | activation_fn: "relu"
3 | dropout: 0.1
4 | attention_dropout: 0.1
5 | activation_dropout: 0.0
6 | relu_dropout: 0.0
7 | decoder_embed_dim: 512
8 | decoder_output_dim: 512
9 | decoder_input_dim: 512
10 | decoder_ffn_embed_dim: 4096
11 | decoder_layers: 12
12 | decoder_attention_heads: 16
13 | decoder_normalize_before: true
14 | no_decoder_final_norm: true
15 | adaptive_softmax_cutoff: null
16 | adaptive_softmax_dropout: 0
17 | adaptive_softmax_factor: 4
18 | no_token_positional_embeddings: false
19 | share_decoder_input_output_embed: false
20 | character_embeddings: false
21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
22 | character_embedding_dim: 4
23 | char_embedder_highway_layers: 2
24 | adaptive_input: false
25 | adaptive_input_factor: 4
26 | adaptive_input_cutoff: null
27 | tie_adaptive_weights: false
28 | tie_adaptive_proj: false
29 | decoder_learned_pos: false
30 | decoder_layerdrop: 0
31 | decoder_layers_to_keep: null
32 | layernorm_embedding: false
33 | no_scale_embedding: false
34 | quant_noise_pq: 0
35 | quant_noise_pq_block_size: 8
36 | quant_noise_scalar: 0
37 |
--------------------------------------------------------------------------------
/fairseq/config/model/transformer_lm/transformer_lm_baevski_wiki103.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | activation_fn: "relu"
3 | dropout: 0.3
4 | attention_dropout: 0.1
5 | activation_dropout: 0.1
6 | relu_dropout: 0.1
7 | decoder_embed_dim: 1024
8 | decoder_output_dim: 1024
9 | decoder_input_dim: 1024
10 | decoder_ffn_embed_dim: 4096
11 | decoder_layers: 16
12 | decoder_attention_heads: 8
13 | decoder_normalize_before: true
14 | no_decoder_final_norm: true
15 | adaptive_softmax_cutoff: "20000,60000"
16 | adaptive_softmax_dropout: 0.2
17 | adaptive_softmax_factor: 4
18 | no_token_positional_embeddings: false
19 | share_decoder_input_output_embed: false
20 | character_embeddings: false
21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
22 | character_embedding_dim: 4
23 | char_embedder_highway_layers: 2
24 | adaptive_input: true
25 | adaptive_input_factor: 4
26 | adaptive_input_cutoff: "20000,60000"
27 | tie_adaptive_weights: true
28 | tie_adaptive_proj: true
29 | decoder_learned_pos: false
30 | decoder_layerdrop: 0
31 | decoder_layers_to_keep: null
32 | layernorm_embedding: false
33 | no_scale_embedding: false
34 | quant_noise_pq: 0
35 | quant_noise_pq_block_size: 8
36 | quant_noise_scalar: 0
37 |
--------------------------------------------------------------------------------
/fairseq/config/model/transformer_lm/transformer_lm_big.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | activation_fn: "relu"
3 | dropout: 0.1
4 | attention_dropout: 0.0
5 | activation_dropout: 0.0
6 | relu_dropout: 0.0
7 | decoder_embed_dim: 1024
8 | decoder_output_dim: 1024
9 | decoder_input_dim: 1024
10 | decoder_ffn_embed_dim: 4096
11 | decoder_layers: 12
12 | decoder_attention_heads: 16
13 | decoder_normalize_before: true
14 | no_decoder_final_norm: false
15 | adaptive_softmax_cutoff: null
16 | adaptive_softmax_dropout: 0
17 | adaptive_softmax_factor: 4
18 | no_token_positional_embeddings: false
19 | share_decoder_input_output_embed: false
20 | character_embeddings: false
21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
22 | character_embedding_dim: 4
23 | char_embedder_highway_layers: 2
24 | adaptive_input: false
25 | adaptive_input_factor: 4
26 | adaptive_input_cutoff: null
27 | tie_adaptive_weights: false
28 | tie_adaptive_proj: false
29 | decoder_learned_pos: false
30 | decoder_layerdrop: 0
31 | decoder_layers_to_keep: null
32 | layernorm_embedding: false
33 | no_scale_embedding: false
34 | quant_noise_pq: 0
35 | quant_noise_pq_block_size: 8
36 | quant_noise_scalar: 0
37 |
--------------------------------------------------------------------------------
/fairseq/config/model/transformer_lm/transformer_lm_gbw.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | activation_fn: "relu"
3 | dropout: 0.1
4 | attention_dropout: 0.1
5 | activation_dropout: 0.0
6 | relu_dropout: 0.0
7 | decoder_embed_dim: 512
8 | decoder_output_dim: 512
9 | decoder_input_dim: 512
10 | decoder_ffn_embed_dim: 4096
11 | decoder_layers: 12
12 | decoder_attention_heads: 16
13 | decoder_normalize_before: true
14 | no_decoder_final_norm: true
15 | adaptive_softmax_cutoff: null
16 | adaptive_softmax_dropout: 0
17 | adaptive_softmax_factor: 4
18 | no_token_positional_embeddings: false
19 | share_decoder_input_output_embed: false
20 | character_embeddings: false
21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
22 | character_embedding_dim: 4
23 | char_embedder_highway_layers: 2
24 | adaptive_input: false
25 | adaptive_input_factor: 4
26 | adaptive_input_cutoff: null
27 | tie_adaptive_weights: false
28 | tie_adaptive_proj: false
29 | decoder_learned_pos: false
30 | decoder_layerdrop: 0
31 | decoder_layers_to_keep: null
32 | layernorm_embedding: false
33 | no_scale_embedding: false
34 | quant_noise_pq: 0
35 | quant_noise_pq_block_size: 8
36 | quant_noise_scalar: 0
37 |
--------------------------------------------------------------------------------
/fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | activation_fn: "gelu"
3 | dropout: 0.1
4 | attention_dropout: 0.1
5 | activation_dropout: 0.0
6 | relu_dropout: 0.0
7 | decoder_embed_dim: 768
8 | decoder_output_dim: 768
9 | decoder_input_dim: 768
10 | decoder_ffn_embed_dim: 3072
11 | decoder_layers: 12
12 | decoder_attention_heads: 12
13 | decoder_normalize_before: true
14 | no_decoder_final_norm: false
15 | adaptive_softmax_cutoff: null
16 | adaptive_softmax_dropout: 0
17 | adaptive_softmax_factor: 4
18 | no_token_positional_embeddings: false
19 | share_decoder_input_output_embed: false
20 | character_embeddings: false
21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
22 | character_embedding_dim: 4
23 | char_embedder_highway_layers: 2
24 | adaptive_input: false
25 | adaptive_input_factor: 4
26 | adaptive_input_cutoff: null
27 | tie_adaptive_weights: false
28 | tie_adaptive_proj: false
29 | decoder_learned_pos: false
30 | decoder_layerdrop: 0
31 | decoder_layers_to_keep: null
32 | layernorm_embedding: false
33 | no_scale_embedding: false
34 | quant_noise_pq: 0
35 | quant_noise_pq_block_size: 8
36 | quant_noise_scalar: 0
37 |
--------------------------------------------------------------------------------
/fairseq/config/model/transformer_lm/transformer_lm_gpt2_big.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | activation_fn: "gelu"
3 | dropout: 0.1
4 | attention_dropout: 0.1
5 | activation_dropout: 0.0
6 | relu_dropout: 0.0
7 | decoder_embed_dim: 1600
8 | decoder_output_dim: 1600
9 | decoder_input_dim: 1600
10 | decoder_ffn_embed_dim: 6400
11 | decoder_layers: 48
12 | decoder_attention_heads: 25
13 | decoder_normalize_before: true
14 | no_decoder_final_norm: false
15 | adaptive_softmax_cutoff: null
16 | adaptive_softmax_dropout: 0
17 | adaptive_softmax_factor: 4
18 | no_token_positional_embeddings: false
19 | share_decoder_input_output_embed: false
20 | character_embeddings: false
21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
22 | character_embedding_dim: 4
23 | char_embedder_highway_layers: 2
24 | adaptive_input: false
25 | adaptive_input_factor: 4
26 | adaptive_input_cutoff: null
27 | tie_adaptive_weights: false
28 | tie_adaptive_proj: false
29 | decoder_learned_pos: false
30 | decoder_layerdrop: 0
31 | decoder_layers_to_keep: null
32 | layernorm_embedding: false
33 | no_scale_embedding: false
34 | quant_noise_pq: 0
35 | quant_noise_pq_block_size: 8
36 | quant_noise_scalar: 0
37 |
--------------------------------------------------------------------------------
/fairseq/config/model/transformer_lm/transformer_lm_gpt2_medium.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | activation_fn: "gelu"
3 | dropout: 0.1
4 | attention_dropout: 0.1
5 | activation_dropout: 0.0
6 | relu_dropout: 0.0
7 | decoder_embed_dim: 1280
8 | decoder_output_dim: 1280
9 | decoder_input_dim: 1280
10 | decoder_ffn_embed_dim: 5120
11 | decoder_layers: 36
12 | decoder_attention_heads: 20
13 | decoder_normalize_before: true
14 | no_decoder_final_norm: false
15 | adaptive_softmax_cutoff: null
16 | adaptive_softmax_dropout: 0
17 | adaptive_softmax_factor: 4
18 | no_token_positional_embeddings: false
19 | share_decoder_input_output_embed: false
20 | character_embeddings: false
21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
22 | character_embedding_dim: 4
23 | char_embedder_highway_layers: 2
24 | adaptive_input: false
25 | adaptive_input_factor: 4
26 | adaptive_input_cutoff: null
27 | tie_adaptive_weights: false
28 | tie_adaptive_proj: false
29 | decoder_learned_pos: false
30 | decoder_layerdrop: 0
31 | decoder_layers_to_keep: null
32 | layernorm_embedding: false
33 | no_scale_embedding: false
34 | quant_noise_pq: 0
35 | quant_noise_pq_block_size: 8
36 | quant_noise_scalar: 0
37 |
--------------------------------------------------------------------------------
/fairseq/config/model/transformer_lm/transformer_lm_gpt2_small.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | activation_fn: "gelu"
3 | dropout: 0.1
4 | attention_dropout: 0.1
5 | activation_dropout: 0.0
6 | relu_dropout: 0.0
7 | decoder_embed_dim: 1024
8 | decoder_output_dim: 1024
9 | decoder_input_dim: 1024
10 | decoder_ffn_embed_dim: 4096
11 | decoder_layers: 24
12 | decoder_attention_heads: 16
13 | decoder_normalize_before: true
14 | no_decoder_final_norm: false
15 | adaptive_softmax_cutoff: null
16 | adaptive_softmax_dropout: 0
17 | adaptive_softmax_factor: 4
18 | no_token_positional_embeddings: false
19 | share_decoder_input_output_embed: false
20 | character_embeddings: false
21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
22 | character_embedding_dim: 4
23 | char_embedder_highway_layers: 2
24 | adaptive_input: false
25 | adaptive_input_factor: 4
26 | adaptive_input_cutoff: null
27 | tie_adaptive_weights: false
28 | tie_adaptive_proj: false
29 | decoder_learned_pos: false
30 | decoder_layerdrop: 0
31 | decoder_layers_to_keep: null
32 | layernorm_embedding: false
33 | no_scale_embedding: false
34 | quant_noise_pq: 0
35 | quant_noise_pq_block_size: 8
36 | quant_noise_scalar: 0
37 |
--------------------------------------------------------------------------------
/fairseq/config/model/transformer_lm/transformer_lm_wiki103.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | activation_fn: "relu"
3 | dropout: 0.3
4 | attention_dropout: 0.1
5 | activation_dropout: 0.1
6 | relu_dropout: 0.1
7 | decoder_embed_dim: 1024
8 | decoder_output_dim: 1024
9 | decoder_input_dim: 1024
10 | decoder_ffn_embed_dim: 4096
11 | decoder_layers: 16
12 | decoder_attention_heads: 8
13 | decoder_normalize_before: true
14 | no_decoder_final_norm: true
15 | adaptive_softmax_cutoff: "20000,60000"
16 | adaptive_softmax_dropout: 0.2
17 | adaptive_softmax_factor: 4
18 | no_token_positional_embeddings: false
19 | share_decoder_input_output_embed: false
20 | character_embeddings: false
21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
22 | character_embedding_dim: 4
23 | char_embedder_highway_layers: 2
24 | adaptive_input: true
25 | adaptive_input_factor: 4
26 | adaptive_input_cutoff: "20000,60000"
27 | tie_adaptive_weights: true
28 | tie_adaptive_proj: true
29 | decoder_learned_pos: false
30 | decoder_layerdrop: 0
31 | decoder_layers_to_keep: null
32 | layernorm_embedding: false
33 | no_scale_embedding: false
34 | quant_noise_pq: 0
35 | quant_noise_pq_block_size: 8
36 | quant_noise_scalar: 0
37 |
--------------------------------------------------------------------------------
/fairseq/config/model/wav2vec/vq_wav2vec_gumbel.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | activation: gelu
3 | vq_type: gumbel
4 | vq_depth: 2
5 | combine_groups: true
6 |
--------------------------------------------------------------------------------
/fairseq/config/model/wav2vec2/wav2vec2_base.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | quantize_targets: true
4 | final_dim: 256
5 | encoder_layerdrop: 0.05
6 | dropout_input: 0.1
7 | dropout_features: 0.1
8 | feature_grad_mult: 0.1
9 |
--------------------------------------------------------------------------------
/fairseq/config/model/wav2vec2/wav2vec2_large.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | quantize_targets: true
4 | extractor_mode: layer_norm
5 | layer_norm_first: true
6 | final_dim: 768
7 | latent_temp: [2.0,0.1,0.999995]
8 | encoder_layerdrop: 0.0
9 | dropout_input: 0.0
10 | dropout_features: 0.0
11 | dropout: 0.0
12 | attention_dropout: 0.0
13 | conv_bias: true
14 |
15 | encoder_layers: 24
16 | encoder_embed_dim: 1024
17 | encoder_ffn_embed_dim: 4096
18 | encoder_attention_heads: 16
19 |
20 | feature_grad_mult: 1.0
21 |
--------------------------------------------------------------------------------
/fairseq/criterions/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | """isort:skip_file"""
6 |
7 | import importlib
8 | import os
9 |
10 | from fairseq import registry
11 | from fairseq.criterions.fairseq_criterion import ( # noqa
12 | FairseqCriterion,
13 | LegacyFairseqCriterion,
14 | )
15 | from omegaconf import DictConfig
16 |
17 |
18 | (
19 | build_criterion_,
20 | register_criterion,
21 | CRITERION_REGISTRY,
22 | CRITERION_DATACLASS_REGISTRY,
23 | ) = registry.setup_registry(
24 | "--criterion", base_class=FairseqCriterion, default="cross_entropy"
25 | )
26 |
27 |
28 | def build_criterion(cfg: DictConfig, task):
29 | return build_criterion_(cfg, task)
30 |
31 |
32 | # automatically import any Python files in the criterions/ directory
33 | for file in sorted(os.listdir(os.path.dirname(__file__))):
34 | if file.endswith(".py") and not file.startswith("_"):
35 | file_name = file[: file.find(".py")]
36 | importlib.import_module("fairseq.criterions." + file_name)
37 |
--------------------------------------------------------------------------------
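The registry call above is what makes "--criterion <name>" resolve to a class.
A hedged sketch of extending it; the name "demo_l2" and the toy loss are
invented for illustration and do not ship with this repo:

import torch.nn.functional as F

from fairseq.criterions import FairseqCriterion, register_criterion


@register_criterion("demo_l2")  # hypothetical; selected via --criterion demo_l2
class DemoL2Criterion(FairseqCriterion):
    """Toy criterion: squared error between model probs and one-hot targets."""

    def forward(self, model, sample, reduce=True):
        net_output = model(**sample["net_input"])
        probs = model.get_normalized_probs(net_output, log_probs=False)
        target = model.get_targets(sample, net_output)
        one_hot = F.one_hot(target, probs.size(-1)).float()
        loss = ((probs - one_hot) ** 2).sum()
        sample_size = sample["ntokens"]
        logging_output = {"loss": loss.data, "ntokens": sample_size,
                          "sample_size": sample_size}
        return loss, sample_size, logging_output
--------------------------------------------------------------------------------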
/fairseq/data/append_token_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy as np
7 | import torch
8 |
9 | from . import BaseWrapperDataset
10 |
11 |
12 | class AppendTokenDataset(BaseWrapperDataset):
13 | def __init__(self, dataset, token=None):
14 | super().__init__(dataset)
15 | self.token = token
16 | if token is not None:
17 | self._sizes = np.array(dataset.sizes) + 1
18 | else:
19 | self._sizes = dataset.sizes
20 |
21 | def __getitem__(self, idx):
22 | item = self.dataset[idx]
23 | if self.token is not None:
24 | item = torch.cat([item, item.new([self.token])])
25 | return item
26 |
27 | @property
28 | def sizes(self):
29 | return self._sizes
30 |
31 | def num_tokens(self, index):
32 | n = self.dataset.num_tokens(index)
33 | if self.token is not None:
34 | n += 1
35 | return n
36 |
37 | def size(self, index):
38 | n = self.dataset.size(index)
39 | if self.token is not None:
40 | n += 1
41 | return n
42 |
--------------------------------------------------------------------------------
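Minimal usage sketch; the token id 2 (a common EOS index) and the toy tensors
are illustrative:

import numpy as np
import torch

from fairseq.data import AppendTokenDataset, ListDataset

base = ListDataset([torch.tensor([5, 6]), torch.tensor([7])],
                   sizes=np.array([2, 1]))
ds = AppendTokenDataset(base, token=2)
print(ds[0])     # tensor([5, 6, 2])
print(ds.sizes)  # [3 2] -- each size grows by one for the appended token
--------------------------------------------------------------------------------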
/fairseq/data/audio/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Speech-Lab-IITM/data2vec-aqc/c80cc996669e12c9c32201a9ff1515300e64b13a/fairseq/data/audio/__init__.py
--------------------------------------------------------------------------------
/fairseq/data/audio/feature_transforms/delta_deltas.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from fairseq.data.audio.feature_transforms import (
4 | AudioFeatureTransform,
5 | register_audio_feature_transform,
6 | )
7 |
8 |
9 | @register_audio_feature_transform("delta_deltas")
10 | class DeltaDeltas(AudioFeatureTransform):
11 | """Expand delta-deltas features from spectrum."""
12 |
13 | @classmethod
14 | def from_config_dict(cls, config=None):
15 | _config = {} if config is None else config
16 | return DeltaDeltas(_config.get("win_length", 5))
17 |
18 | def __init__(self, win_length=5):
19 | self.win_length = win_length
20 |
21 | def __repr__(self):
22 | return self.__class__.__name__
23 |
24 | def __call__(self, spectrogram):
25 | from torchaudio.functional import compute_deltas
26 |
27 | assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."
28 | # spectrogram is T x F, while compute_deltas takes (…, F, T)
29 | spectrogram = torch.from_numpy(spectrogram).transpose(0, 1)
30 | delta = compute_deltas(spectrogram)
31 | delta_delta = compute_deltas(delta)
32 |
33 | out_feat = np.concatenate(
34 | [spectrogram, delta.numpy(), delta_delta.numpy()], axis=0
35 | )
36 | out_feat = np.transpose(out_feat)
37 | return out_feat
38 |
--------------------------------------------------------------------------------
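Illustrative call (requires torchaudio); the shapes are arbitrary. The output
stacks the base features with their deltas and delta-deltas along the feature
axis, tripling the feature dimension:

import numpy as np

from fairseq.data.audio.feature_transforms.delta_deltas import DeltaDeltas

spec = np.random.rand(100, 80).astype(np.float32)  # T x F
out = DeltaDeltas()(spec)
print(out.shape)  # (100, 240): [features | deltas | delta-deltas]
--------------------------------------------------------------------------------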
/fairseq/data/audio/feature_transforms/global_cmvn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from fairseq.data.audio.feature_transforms import (
3 | AudioFeatureTransform,
4 | register_audio_feature_transform,
5 | )
6 |
7 |
8 | @register_audio_feature_transform("global_cmvn")
9 | class GlobalCMVN(AudioFeatureTransform):
10 | """Global CMVN (cepstral mean and variance normalization). The global mean
11 | and variance need to be pre-computed and stored in NumPy format (.npz)."""
12 |
13 | @classmethod
14 | def from_config_dict(cls, config=None):
15 | _config = {} if config is None else config
16 | return GlobalCMVN(_config.get("stats_npz_path"))
17 |
18 | def __init__(self, stats_npz_path):
19 | self.stats_npz_path = stats_npz_path
20 | stats = np.load(stats_npz_path)
21 | self.mean, self.std = stats["mean"], stats["std"]
22 |
23 | def __repr__(self):
24 | return self.__class__.__name__ + f'(stats_npz_path="{self.stats_npz_path}")'
25 |
26 | def __call__(self, x):
27 | x = np.subtract(x, self.mean)
28 | x = np.divide(x, self.std)
29 | return x
30 |
--------------------------------------------------------------------------------
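A hedged end-to-end sketch: compute global statistics once, store them in the
expected .npz layout (keys "mean" and "std"), then apply the transform. The
file name and data are arbitrary:

import numpy as np

from fairseq.data.audio.feature_transforms.global_cmvn import GlobalCMVN

feats = np.random.rand(1000, 80).astype(np.float32)
np.savez("stats.npz", mean=feats.mean(axis=0), std=feats.std(axis=0))

cmvn = GlobalCMVN("stats.npz")
normed = cmvn(feats)
print(normed.mean(axis=0)[:3])  # ~0 in every feature dimension
--------------------------------------------------------------------------------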
/fairseq/data/audio/feature_transforms/utterance_cmvn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from fairseq.data.audio.feature_transforms import (
4 | AudioFeatureTransform,
5 | register_audio_feature_transform,
6 | )
7 |
8 |
9 | @register_audio_feature_transform("utterance_cmvn")
10 | class UtteranceCMVN(AudioFeatureTransform):
11 | """Utterance-level CMVN (cepstral mean and variance normalization)"""
12 |
13 | @classmethod
14 | def from_config_dict(cls, config=None):
15 | _config = {} if config is None else config
16 | return UtteranceCMVN(
17 | _config.get("norm_means", True),
18 | _config.get("norm_vars", True),
19 | )
20 |
21 | def __init__(self, norm_means=True, norm_vars=True):
22 | self.norm_means, self.norm_vars = norm_means, norm_vars
23 |
24 | def __repr__(self):
25 | return (
26 | self.__class__.__name__
27 | + f"(norm_means={self.norm_means}, norm_vars={self.norm_vars})"
28 | )
29 |
30 | def __call__(self, x):
31 | mean = x.mean(axis=0)
32 | square_sums = (x**2).sum(axis=0)
33 |
34 | if self.norm_means:
35 | x = np.subtract(x, mean)
36 | if self.norm_vars:
37 | var = square_sums / x.shape[0] - mean**2
38 | std = np.sqrt(np.maximum(var, 1e-10))
39 | x = np.divide(x, std)
40 |
41 | return x
42 |
--------------------------------------------------------------------------------
/fairseq/data/colorize_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 |
8 | from . import BaseWrapperDataset
9 |
10 |
11 | class ColorizeDataset(BaseWrapperDataset):
12 | """Adds 'colors' property to net input that is obtained from the provided color getter for use by models"""
13 |
14 | def __init__(self, dataset, color_getter):
15 | super().__init__(dataset)
16 | self.color_getter = color_getter
17 |
18 | def collater(self, samples):
19 | base_collate = super().collater(samples)
20 | if len(base_collate) > 0:
21 | base_collate["net_input"]["colors"] = torch.tensor(
22 | list(self.color_getter(self.dataset, s["id"]) for s in samples),
23 | dtype=torch.long,
24 | )
25 | return base_collate
26 |
--------------------------------------------------------------------------------
/fairseq/data/encoders/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 |
7 | import importlib
8 | import os
9 |
10 | from fairseq import registry
11 |
12 |
13 | build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY, _ = registry.setup_registry(
14 | "--tokenizer",
15 | default=None,
16 | )
17 |
18 |
19 | build_bpe, register_bpe, BPE_REGISTRY, _ = registry.setup_registry(
20 | "--bpe",
21 | default=None,
22 | )
23 |
24 |
25 | # automatically import any Python files in the encoders/ directory
26 | for file in sorted(os.listdir(os.path.dirname(__file__))):
27 | if file.endswith(".py") and not file.startswith("_"):
28 | module = file[: file.find(".py")]
29 | importlib.import_module("fairseq.data.encoders." + module)
30 |
--------------------------------------------------------------------------------
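These two registries mirror the criterion registry: any class in this
directory decorated with register_tokenizer or register_bpe becomes selectable
by name. A sketch of a custom tokenizer; "reverse_words" is an invented name
that does not ship with this repo:

from fairseq.data.encoders import register_tokenizer
from fairseq.dataclass import FairseqDataclass


@register_tokenizer("reverse_words", dataclass=FairseqDataclass)
class ReverseWordsTokenizer(object):
    """Toy tokenizer: reversing the word order twice restores the input."""

    def __init__(self, *unused):
        pass

    def encode(self, x: str) -> str:
        return " ".join(reversed(x.split()))

    def decode(self, x: str) -> str:
        return " ".join(reversed(x.split()))
--------------------------------------------------------------------------------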
/fairseq/data/encoders/byte_bpe.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 |
7 | from dataclasses import dataclass, field
8 |
9 | from fairseq import file_utils
10 | from fairseq.data.encoders import register_bpe
11 | from fairseq.data.encoders.byte_utils import (
12 | SPACE,
13 | SPACE_ESCAPE,
14 | byte_encode,
15 | smart_byte_decode,
16 | )
17 | from fairseq.dataclass import FairseqDataclass
18 |
19 |
20 | @dataclass
21 | class ByteBpeConfig(FairseqDataclass):
22 | sentencepiece_model_path: str = field(
23 | default="???", metadata={"help": "path to sentencepiece model"}
24 | )
25 |
26 |
27 | @register_bpe("byte_bpe", dataclass=ByteBpeConfig)
28 | class ByteBPE(object):
29 | def __init__(self, cfg):
30 | vocab = file_utils.cached_path(cfg.sentencepiece_model_path)
31 | try:
32 | import sentencepiece as spm
33 |
34 | self.sp = spm.SentencePieceProcessor()
35 | self.sp.Load(vocab)
36 | except ImportError:
37 | raise ImportError(
38 | "Please install sentencepiece with: pip install sentencepiece"
39 | )
40 |
41 | def encode(self, x: str) -> str:
42 | byte_encoded = byte_encode(x)
43 | return SPACE.join(self.sp.EncodeAsPieces(byte_encoded))
44 |
45 | @staticmethod
46 | def decode(x: str) -> str:
47 | unescaped = x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE)
48 | return smart_byte_decode(unescaped)
49 |
--------------------------------------------------------------------------------
/fairseq/data/encoders/bytes.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 |
7 | from fairseq.data.encoders import register_bpe
8 | from fairseq.data.encoders.byte_utils import (
9 | SPACE,
10 | SPACE_ESCAPE,
11 | byte_encode,
12 | smart_byte_decode,
13 | )
14 |
15 |
16 | @register_bpe("bytes")
17 | class Bytes(object):
18 | def __init__(self, *unused):
19 | pass
20 |
21 | @staticmethod
22 | def add_args(parser):
23 | pass
24 |
25 | @staticmethod
26 | def encode(x: str) -> str:
27 | encoded = byte_encode(x)
28 | escaped = encoded.replace(SPACE, SPACE_ESCAPE)
29 | return SPACE.join(list(escaped))
30 |
31 | @staticmethod
32 | def decode(x: str) -> str:
33 | unescaped = x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE)
34 | return smart_byte_decode(unescaped)
35 |
--------------------------------------------------------------------------------
/fairseq/data/encoders/characters.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 |
7 | from fairseq.data.encoders import register_bpe
8 |
9 |
10 | SPACE = chr(32)
11 | SPACE_ESCAPE = chr(9601)
12 |
13 |
14 | @register_bpe("characters")
15 | class Characters(object):
16 | def __init__(self, *unused):
17 | pass
18 |
19 | @staticmethod
20 | def add_args(parser):
21 | pass
22 |
23 | @staticmethod
24 | def encode(x: str) -> str:
25 | escaped = x.replace(SPACE, SPACE_ESCAPE)
26 | return SPACE.join(list(escaped))
27 |
28 | @staticmethod
29 | def decode(x: str) -> str:
30 | return x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE)
31 |
--------------------------------------------------------------------------------
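Round-trip sketch: spaces are escaped with U+2581 ("▁") before the string is
split into characters, which is what lets decode() restore them:

from fairseq.data.encoders.characters import Characters

enc = Characters.encode("ab c")
print(enc)  # 'a b ▁ c'
assert Characters.decode(enc) == "ab c"
--------------------------------------------------------------------------------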
/fairseq/data/encoders/fastbpe.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from dataclasses import dataclass, field
7 |
8 | from fairseq import file_utils
9 | from fairseq.data.encoders import register_bpe
10 | from fairseq.dataclass import FairseqDataclass
11 |
12 |
13 | @dataclass
14 | class fastBPEConfig(FairseqDataclass):
15 | bpe_codes: str = field(default="???", metadata={"help": "path to fastBPE BPE"})
16 |
17 |
18 | @register_bpe("fastbpe", dataclass=fastBPEConfig)
19 | class fastBPE(object):
20 | def __init__(self, cfg):
21 | if cfg.bpe_codes is None:
22 | raise ValueError("--bpe-codes is required for --bpe=fastbpe")
23 | codes = file_utils.cached_path(cfg.bpe_codes)
24 | try:
25 | import fastBPE
26 |
27 | self.bpe = fastBPE.fastBPE(codes)
28 | self.bpe_symbol = "@@ "
29 | except ImportError:
30 | raise ImportError("Please install fastBPE with: pip install fastBPE")
31 |
32 | def encode(self, x: str) -> str:
33 | return self.bpe.apply([x])[0]
34 |
35 | def decode(self, x: str) -> str:
36 | return (x + " ").replace(self.bpe_symbol, "").rstrip()
37 |
--------------------------------------------------------------------------------
/fairseq/data/encoders/gpt2_bpe.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from dataclasses import dataclass, field
7 |
8 | from fairseq import file_utils
9 | from fairseq.data.encoders import register_bpe
10 | from fairseq.dataclass import FairseqDataclass
11 |
12 | from .gpt2_bpe_utils import get_encoder
13 |
14 |
15 | DEFAULT_ENCODER_JSON = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json"
16 | DEFAULT_VOCAB_BPE = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe"
17 |
18 |
19 | @dataclass
20 | class GPT2BPEConfig(FairseqDataclass):
21 | gpt2_encoder_json: str = field(
22 | default=DEFAULT_ENCODER_JSON, metadata={"help": "path to encoder.json"}
23 | )
24 | gpt2_vocab_bpe: str = field(
25 | default=DEFAULT_VOCAB_BPE, metadata={"help": "path to vocab.bpe"}
26 | )
27 |
28 |
29 | @register_bpe("gpt2", dataclass=GPT2BPEConfig)
30 | class GPT2BPE(object):
31 | def __init__(self, cfg):
32 | encoder_json = file_utils.cached_path(cfg.gpt2_encoder_json)
33 | vocab_bpe = file_utils.cached_path(cfg.gpt2_vocab_bpe)
34 | self.bpe = get_encoder(encoder_json, vocab_bpe)
35 |
36 | def encode(self, x: str) -> str:
37 | return " ".join(map(str, self.bpe.encode(x)))
38 |
39 | def decode(self, x: str) -> str:
40 | return self.bpe.decode(
41 | [int(tok) if tok not in {"<unk>", "<mask>"} else tok for tok in x.split()]
42 | )
43 |
44 | def is_beginning_of_word(self, x: str) -> bool:
45 | return self.decode(x).startswith(" ")
46 |
--------------------------------------------------------------------------------
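Usage sketch: with the default config, the first call downloads encoder.json
and vocab.bpe from the DEFAULT_* URLs above and caches them locally. The
printed ids are what GPT-2's standard vocabulary is expected to yield, but
treat them as illustrative:

from fairseq.data.encoders.gpt2_bpe import GPT2BPE, GPT2BPEConfig

bpe = GPT2BPE(GPT2BPEConfig())
ids = bpe.encode("Hello world")
print(ids)              # e.g. "15496 995"
print(bpe.decode(ids))  # "Hello world"
--------------------------------------------------------------------------------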
/fairseq/data/encoders/nltk_tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from fairseq.data.encoders import register_tokenizer
7 | from fairseq.dataclass import FairseqDataclass
8 |
9 |
10 | @register_tokenizer("nltk", dataclass=FairseqDataclass)
11 | class NLTKTokenizer(object):
12 | def __init__(self, *unused):
13 | try:
14 | from nltk.tokenize import word_tokenize
15 |
16 | self.word_tokenize = word_tokenize
17 | except ImportError:
18 | raise ImportError("Please install nltk with: pip install nltk")
19 |
20 | def encode(self, x: str) -> str:
21 | return " ".join(self.word_tokenize(x))
22 |
23 | def decode(self, x: str) -> str:
24 | return x
25 |
--------------------------------------------------------------------------------
/fairseq/data/encoders/space_tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import re
7 |
8 | from fairseq.data.encoders import register_tokenizer
9 | from fairseq.dataclass import FairseqDataclass
10 |
11 |
12 | @register_tokenizer("space", dataclass=FairseqDataclass)
13 | class SpaceTokenizer(object):
14 | def __init__(self, *unused):
15 | self.space_tok = re.compile(r"\s+")
16 |
17 | def encode(self, x: str) -> str:
18 | return self.space_tok.sub(" ", x)
19 |
20 | def decode(self, x: str) -> str:
21 | return x
22 |
--------------------------------------------------------------------------------
/fairseq/data/encoders/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | from fairseq.data import encoders
8 |
9 |
10 | def get_whole_word_mask(args, dictionary):
11 | bpe = encoders.build_bpe(args)
12 | if bpe is not None:
13 |
14 | def is_beginning_of_word(i):
15 | if i < dictionary.nspecial:
16 | # special elements are always considered beginnings
17 | return True
18 | tok = dictionary[i]
19 | if tok.startswith("madeupword"):
20 | return True
21 | try:
22 | return bpe.is_beginning_of_word(tok)
23 | except ValueError:
24 | return True
25 |
26 | mask_whole_words = torch.ByteTensor(
27 | list(map(is_beginning_of_word, range(len(dictionary))))
28 | )
29 | return mask_whole_words
30 | return None
31 |
--------------------------------------------------------------------------------
/fairseq/data/huffman/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .huffman_coder import HuffmanCodeBuilder, HuffmanCoder
7 | from .huffman_mmap_indexed_dataset import (
8 | HuffmanMMapIndex,
9 | HuffmanMMapIndexedDataset,
10 | HuffmanMMapIndexedDatasetBuilder,
11 | vocab_file_path,
12 | )
13 |
14 | __all__ = [
15 | "HuffmanCoder",
16 | "HuffmanCodeBuilder",
17 | "HuffmanMMapIndexedDatasetBuilder",
18 | "HuffmanMMapIndexedDataset",
19 | "HuffmanMMapIndex",
20 | "vocab_file_path",
21 | ]
22 |
--------------------------------------------------------------------------------
/fairseq/data/id_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 |
8 | from . import FairseqDataset
9 |
10 |
11 | class IdDataset(FairseqDataset):
12 | def __getitem__(self, index):
13 | return index
14 |
15 | def __len__(self):
16 | return 0
17 |
18 | def collater(self, samples):
19 | return torch.tensor(samples)
20 |
--------------------------------------------------------------------------------
/fairseq/data/legacy/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .block_pair_dataset import BlockPairDataset
7 | from .masked_lm_dataset import MaskedLMDataset
8 | from .masked_lm_dictionary import BertDictionary, MaskedLMDictionary
9 |
10 |
11 | __all__ = [
12 | "BertDictionary",
13 | "BlockPairDataset",
14 | "MaskedLMDataset",
15 | "MaskedLMDictionary",
16 | ]
17 |
--------------------------------------------------------------------------------
/fairseq/data/list_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from . import BaseWrapperDataset
7 |
8 |
9 | class ListDataset(BaseWrapperDataset):
10 | def __init__(self, dataset, sizes=None):
11 | super().__init__(dataset)
12 | self._sizes = sizes
13 |
14 | def __iter__(self):
15 | for x in self.dataset:
16 | yield x
17 |
18 | def collater(self, samples):
19 | return samples
20 |
21 | @property
22 | def sizes(self):
23 | return self._sizes
24 |
25 | def num_tokens(self, index):
26 | return self.sizes[index]
27 |
28 | def size(self, index):
29 | return self.sizes[index]
30 |
31 | def set_epoch(self, epoch):
32 | pass
33 |
--------------------------------------------------------------------------------
/fairseq/data/lru_cache_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from functools import lru_cache
7 |
8 | from . import BaseWrapperDataset
9 |
10 |
11 | class LRUCacheDataset(BaseWrapperDataset):
12 | def __init__(self, dataset, token=None):
13 | super().__init__(dataset)
14 |
15 | @lru_cache(maxsize=8)
16 | def __getitem__(self, index):
17 | return self.dataset[index]
18 |
19 | @lru_cache(maxsize=8)
20 | def collater(self, samples):
21 | return self.dataset.collater(samples)
22 |
--------------------------------------------------------------------------------
/fairseq/data/multilingual/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
--------------------------------------------------------------------------------
/fairseq/data/num_samples_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from . import FairseqDataset
7 |
8 |
9 | class NumSamplesDataset(FairseqDataset):
10 | def __getitem__(self, index):
11 | return 1
12 |
13 | def __len__(self):
14 | return 0
15 |
16 | def collater(self, samples):
17 | return sum(samples)
18 |
--------------------------------------------------------------------------------
/fairseq/data/numel_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy as np
7 | import torch
8 |
9 | from . import BaseWrapperDataset
10 |
11 |
12 | class NumelDataset(BaseWrapperDataset):
13 | def __init__(self, dataset, reduce=False):
14 | super().__init__(dataset)
15 | self.reduce = reduce
16 |
17 | def __getitem__(self, index):
18 | item = self.dataset[index]
19 | if torch.is_tensor(item):
20 | return torch.numel(item)
21 | else:
22 | return np.size(item)
23 |
24 | def __len__(self):
25 | return len(self.dataset)
26 |
27 | def collater(self, samples):
28 | if self.reduce:
29 | return sum(samples)
30 | else:
31 | return torch.tensor(samples)
32 |
--------------------------------------------------------------------------------
/fairseq/data/offset_tokens_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from . import BaseWrapperDataset
7 |
8 |
9 | class OffsetTokensDataset(BaseWrapperDataset):
10 | def __init__(self, dataset, offset):
11 | super().__init__(dataset)
12 | self.offset = offset
13 |
14 | def __getitem__(self, idx):
15 | return self.dataset[idx] + self.offset
16 |
--------------------------------------------------------------------------------
/fairseq/data/pad_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from fairseq.data import data_utils
7 |
8 | from . import BaseWrapperDataset
9 |
10 |
11 | class PadDataset(BaseWrapperDataset):
12 | def __init__(self, dataset, pad_idx, left_pad, pad_length=None):
13 | super().__init__(dataset)
14 | self.pad_idx = pad_idx
15 | self.left_pad = left_pad
16 | self.pad_length = pad_length
17 |
18 | def collater(self, samples):
19 | return data_utils.collate_tokens(
20 | samples, self.pad_idx, left_pad=self.left_pad, pad_to_length=self.pad_length
21 | )
22 |
23 |
24 | class LeftPadDataset(PadDataset):
25 | def __init__(self, dataset, pad_idx):
26 | super().__init__(dataset, pad_idx, left_pad=True)
27 |
28 |
29 | class RightPadDataset(PadDataset):
30 | def __init__(self, dataset, pad_idx):
31 | super().__init__(dataset, pad_idx, left_pad=False)
32 |
--------------------------------------------------------------------------------
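Collating sketch: RightPadDataset pads every batch to its longest item. The
pad index 1 follows fairseq's usual dictionary convention but is illustrative
here:

import torch

from fairseq.data import ListDataset, RightPadDataset

base = ListDataset([torch.tensor([4, 5, 6]), torch.tensor([7])])
ds = RightPadDataset(base, pad_idx=1)
print(ds.collater([ds[0], ds[1]]))
# tensor([[4, 5, 6],
#         [7, 1, 1]])
--------------------------------------------------------------------------------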
/fairseq/data/prepend_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy as np
7 | import torch
8 |
9 | from . import BaseWrapperDataset
10 |
11 |
12 | class PrependDataset(BaseWrapperDataset):
13 | def __init__(self, dataset, prepend_getter, ensure_first_token_is=None):
14 | super().__init__(dataset)
15 | self.prepend_getter = prepend_getter
16 | self.ensure_first_token = ensure_first_token_is
17 |
18 | def __getitem__(self, idx):
19 | item = self.dataset[idx]
20 | is_tuple = isinstance(item, tuple)
21 | src = item[0] if is_tuple else item
22 |
23 | assert self.ensure_first_token is None or src[0] == self.ensure_first_token
24 | prepend_idx = self.prepend_getter(self.dataset, idx)
25 | assert isinstance(prepend_idx, int)
26 | src[0] = prepend_idx
27 | item = tuple((src,) + item[1:]) if is_tuple else src
28 | return item
29 |
--------------------------------------------------------------------------------
/fairseq/data/prepend_token_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy as np
7 | import torch
8 |
9 | from . import BaseWrapperDataset
10 |
11 |
12 | class PrependTokenDataset(BaseWrapperDataset):
13 | def __init__(self, dataset, token=None):
14 | super().__init__(dataset)
15 | self.token = token
16 | if token is not None:
17 | self._sizes = np.array(dataset.sizes) + 1
18 | else:
19 | self._sizes = dataset.sizes
20 |
21 | def __getitem__(self, idx):
22 | item = self.dataset[idx]
23 | if self.token is not None:
24 | item = torch.cat([item.new([self.token]), item])
25 | return item
26 |
27 | @property
28 | def sizes(self):
29 | return self._sizes
30 |
31 | def num_tokens(self, index):
32 | n = self.dataset.num_tokens(index)
33 | if self.token is not None:
34 | n += 1
35 | return n
36 |
37 | def size(self, index):
38 | n = self.dataset.size(index)
39 | if self.token is not None:
40 | n += 1
41 | return n
42 |
--------------------------------------------------------------------------------
/fairseq/data/raw_label_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 |
8 | from . import FairseqDataset
9 |
10 |
11 | class RawLabelDataset(FairseqDataset):
12 | def __init__(self, labels):
13 | super().__init__()
14 | self.labels = labels
15 |
16 | def __getitem__(self, index):
17 | return self.labels[index]
18 |
19 | def __len__(self):
20 | return len(self.labels)
21 |
22 | def collater(self, samples):
23 | return torch.tensor(samples)
24 |
--------------------------------------------------------------------------------
/fairseq/data/replace_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from . import BaseWrapperDataset
7 |
8 |
9 | class ReplaceDataset(BaseWrapperDataset):
10 | """Replaces tokens found in the dataset by a specified replacement token
11 |
12 | Args:
13 | dataset (~torch.utils.data.Dataset): dataset to replace tokens in
14 | replace_map (Dict[int, int]): map from the token to replace -> its replacement token
15 | offsets (List[int]): do not replace tokens before (from the left if positive, from the right if negative)
16 | this offset; there should be as many offsets as objects returned by the underlying dataset's __getitem__ method.
17 | """
18 |
19 | def __init__(self, dataset, replace_map, offsets):
20 | super().__init__(dataset)
21 | assert len(replace_map) > 0
22 | self.replace_map = replace_map
23 | self.offsets = offsets
24 |
25 | def __getitem__(self, index):
26 | item = self.dataset[index]
27 | is_tuple = isinstance(item, tuple)
28 | srcs = item if is_tuple else [item]
29 |
30 | for offset, src in zip(self.offsets, srcs):
31 | for k, v in self.replace_map.items():
32 | src_off = src[offset:] if offset >= 0 else src[:offset]
33 | src_off.masked_fill_(src_off == k, v)
34 |
35 | item = srcs if is_tuple else srcs[0]
36 | return item
37 |
--------------------------------------------------------------------------------
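A small sketch of the offset semantics: with offsets=[1] the first position is
left untouched and replacement applies from index 1 onward. Values are toy
data; note that masked_fill_ operates on a view, so the underlying tensor is
modified in place:

import torch

from fairseq.data import ListDataset
from fairseq.data.replace_dataset import ReplaceDataset

base = ListDataset([torch.tensor([3, 1, 3])])
ds = ReplaceDataset(base, replace_map={3: 9}, offsets=[1])
print(ds[0])  # tensor([3, 1, 9]) -- the leading 3 is kept, the later 3 replaced
--------------------------------------------------------------------------------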
/fairseq/data/roll_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 |
8 | from . import BaseWrapperDataset
9 |
10 |
11 | class RollDataset(BaseWrapperDataset):
12 | def __init__(self, dataset, shifts):
13 | super().__init__(dataset)
14 | self.shifts = shifts
15 |
16 | def __getitem__(self, index):
17 | item = self.dataset[index]
18 | return torch.roll(item, self.shifts)
19 |
--------------------------------------------------------------------------------
/fairseq/data/sort_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy as np
7 |
8 | from . import BaseWrapperDataset
9 |
10 |
11 | class SortDataset(BaseWrapperDataset):
12 | def __init__(self, dataset, sort_order):
13 | super().__init__(dataset)
14 | if not isinstance(sort_order, (list, tuple)):
15 | sort_order = [sort_order]
16 | self.sort_order = sort_order
17 |
18 | assert all(len(so) == len(dataset) for so in sort_order)
19 |
20 | def ordered_indices(self):
21 | return np.lexsort(self.sort_order)
22 |
--------------------------------------------------------------------------------
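Because ordered_indices() delegates to np.lexsort, the last array in
sort_order is the primary key. A toy sketch sorting by length with a
tie-breaker:

import numpy as np

from fairseq.data import ListDataset, SortDataset

lengths = np.array([3, 1, 3])
tiebreak = np.array([1, 0, 0])
ds = SortDataset(ListDataset([10, 20, 30]), sort_order=[tiebreak, lengths])
print(ds.ordered_indices())  # [1 2 0]: shortest first, tie broken by tiebreak
--------------------------------------------------------------------------------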
/fairseq/data/strip_token_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from . import BaseWrapperDataset
7 |
8 |
9 | class StripTokenDataset(BaseWrapperDataset):
10 | def __init__(self, dataset, id_to_strip):
11 | super().__init__(dataset)
12 | self.id_to_strip = id_to_strip
13 |
14 | def __getitem__(self, index):
15 | item = self.dataset[index]
16 | while len(item) > 0 and item[-1] == self.id_to_strip:
17 | item = item[:-1]
18 | while len(item) > 0 and item[0] == self.id_to_strip:
19 | item = item[1:]
20 | return item
21 |
--------------------------------------------------------------------------------
/fairseq/dataclass/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .configs import FairseqDataclass
7 | from .constants import ChoiceEnum
8 |
9 |
10 | __all__ = [
11 | "FairseqDataclass",
12 | "ChoiceEnum",
13 | ]
14 |
--------------------------------------------------------------------------------
/fairseq/distributed/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .distributed_timeout_wrapper import DistributedTimeoutWrapper
7 | from .fully_sharded_data_parallel import (
8 | fsdp_enable_wrap,
9 | fsdp_wrap,
10 | FullyShardedDataParallel,
11 | )
12 | from .legacy_distributed_data_parallel import LegacyDistributedDataParallel
13 | from .module_proxy_wrapper import ModuleProxyWrapper
14 | from .tpu_distributed_data_parallel import TPUDistributedDataParallel
15 |
16 |
17 | __all__ = [
18 | "DistributedTimeoutWrapper",
19 | "fsdp_enable_wrap",
20 | "fsdp_wrap",
21 | "FullyShardedDataParallel",
22 | "LegacyDistributedDataParallel",
23 | "ModuleProxyWrapper",
24 | "TPUDistributedDataParallel",
25 | ]
26 |
--------------------------------------------------------------------------------
/fairseq/distributed/tpu_distributed_data_parallel.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | from torch import nn
8 |
9 | from fairseq.distributed import utils
10 |
11 |
12 | class TPUDistributedDataParallel(nn.Module):
13 | def __init__(self, module, process_group):
14 | super().__init__()
15 | self.module = module
16 | self.process_group = process_group
17 | self.world_size = utils.get_world_size(self.process_group)
18 |
19 | def forward(self, *inputs, **kwargs):
20 | return self.module(*inputs, **kwargs)
21 |
22 | def all_reduce_grads(self):
23 | gradients = []
24 | for p in self.parameters():
25 | if not p.requires_grad:
26 | continue
27 | if p.grad is None:
28 | p.grad = torch.zeros_like(p)
29 | if p.grad.requires_grad:
30 | raise RuntimeError(
31 | "TPUDistributedDataParallel only works with gradients that don't "
32 | "require grad"
33 | )
34 | gradients.append(p.grad)
35 |
36 | import torch_xla.core.xla_model as xm
37 |
38 | xm.all_reduce(
39 | "sum",
40 | gradients,
41 | scale=1.0 / self.world_size,
42 | groups=self.process_group[1],
43 | )
44 |
--------------------------------------------------------------------------------
/fairseq/logging/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Speech-Lab-IITM/data2vec-aqc/c80cc996669e12c9c32201a9ff1515300e64b13a/fairseq/logging/__init__.py
--------------------------------------------------------------------------------
/fairseq/model_parallel/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from . import criterions, models, modules # noqa
7 |
--------------------------------------------------------------------------------
/fairseq/model_parallel/criterions/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import importlib
7 | import os
8 |
9 |
10 | # automatically import any Python files in the criterions/ directory
11 | for file in sorted(os.listdir(os.path.dirname(__file__))):
12 | if file.endswith(".py") and not file.startswith("_"):
13 | module = file[: file.find(".py")]
14 | importlib.import_module("fairseq.model_parallel.criterions." + module)
15 |
--------------------------------------------------------------------------------
/fairseq/model_parallel/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import importlib
7 | import os
8 |
9 |
10 | # automatically import any Python files in the models/ directory
11 | models_dir = os.path.dirname(__file__)
12 | for file in os.listdir(models_dir):
13 | path = os.path.join(models_dir, file)
14 | if (
15 | not file.startswith("_")
16 | and not file.startswith(".")
17 | and (file.endswith(".py") or os.path.isdir(path))
18 | ):
19 | model_name = file[: file.find(".py")] if file.endswith(".py") else file
20 | module = importlib.import_module("fairseq.model_parallel.models." + model_name)
21 |
--------------------------------------------------------------------------------
/fairseq/model_parallel/models/pipeline_parallel_transformer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .model import * # noqa
7 |
--------------------------------------------------------------------------------
/fairseq/model_parallel/models/roberta/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .model import * # noqa
7 |
--------------------------------------------------------------------------------
/fairseq/model_parallel/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | """isort:skip_file"""
6 |
7 | from .multihead_attention import ModelParallelMultiheadAttention
8 | from .transformer_layer import (
9 | ModelParallelTransformerEncoderLayer,
10 | ModelParallelTransformerDecoderLayer,
11 | )
12 |
13 | __all__ = [
14 | "ModelParallelMultiheadAttention",
15 | "ModelParallelTransformerEncoderLayer",
16 | "ModelParallelTransformerDecoderLayer",
17 | ]
18 |
--------------------------------------------------------------------------------
/fairseq/models/bart/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .hub_interface import * # noqa
7 | from .model import * # noqa
8 |
--------------------------------------------------------------------------------
/fairseq/models/ema/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import importlib
7 | import os
8 |
9 | from .ema import EMA
10 |
11 |
12 | def build_ema(model, cfg, device):
13 | return EMA(model, cfg, device)
14 |
15 |
16 | # automatically import any Python files in the models/ema/ directory
17 | for file in sorted(os.listdir(os.path.dirname(__file__))):
18 | if file.endswith(".py") and not file.startswith("_"):
19 | file_name = file[: file.find(".py")]
20 | importlib.import_module("fairseq.models.ema." + file_name)
21 |
--------------------------------------------------------------------------------
/fairseq/models/hubert/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .hubert import * # noqa
7 | from .hubert_asr import * # noqa
8 |
--------------------------------------------------------------------------------
/fairseq/models/huggingface/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import importlib
7 | import os
8 |
9 |
10 | # automatically import any Python files in the models/huggingface/ directory
11 | models_dir = os.path.dirname(__file__)
12 | for file in os.listdir(models_dir):
13 | path = os.path.join(models_dir, file)
14 | if (
15 | not file.startswith("_")
16 | and not file.startswith(".")
17 | and (file.endswith(".py") or os.path.isdir(path))
18 | ):
19 | model_name = file[: file.find(".py")] if file.endswith(".py") else file
20 | module = importlib.import_module("fairseq.models.huggingface." + model_name)
21 |
--------------------------------------------------------------------------------
/fairseq/models/nat/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | """isort:skip_file"""
6 |
7 | from .fairseq_nat_model import *
8 | from .nonautoregressive_transformer import *
9 | from .nat_crf_transformer import *
10 | from .iterative_nonautoregressive_transformer import *
11 | from .cmlm_transformer import *
12 | from .levenshtein_transformer import *
13 | from .insertion_transformer import *
14 |
--------------------------------------------------------------------------------
/fairseq/models/roberta/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .hub_interface import * # noqa
7 | from .model import * # noqa
8 | from .enc_dec import * # noqa
9 | from .model_camembert import * # noqa
10 | from .model_gottbert import * # noqa
11 | from .model_xlmr import * # noqa
12 |
--------------------------------------------------------------------------------
/fairseq/models/roberta/model_gottbert.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | """
6 | GottBERT: a pure German Language Model
7 | """
8 |
9 | from fairseq.models import register_model
10 |
11 | from .hub_interface import RobertaHubInterface
12 | from .model import RobertaModel
13 |
14 |
15 | @register_model("gottbert")
16 | class GottbertModel(RobertaModel):
17 | @classmethod
18 | def hub_models(cls):
19 | return {
20 | "gottbert-base": "https://dl.gottbert.de/fairseq/models/gottbert-base.tar.gz",
21 | }
22 |
23 | @classmethod
24 | def from_pretrained(
25 | cls,
26 | model_name_or_path,
27 | checkpoint_file="model.pt",
28 | data_name_or_path=".",
29 | bpe="hf_byte_bpe",
30 | bpe_vocab="vocab.json",
31 | bpe_merges="merges.txt",
32 | bpe_add_prefix_space=False,
33 | **kwargs
34 | ):
35 | from fairseq import hub_utils
36 |
37 | x = hub_utils.from_pretrained(
38 | model_name_or_path,
39 | checkpoint_file,
40 | data_name_or_path,
41 | archive_map=cls.hub_models(),
42 | bpe=bpe,
43 | load_checkpoint_heads=True,
44 | bpe_vocab=bpe_vocab,
45 | bpe_merges=bpe_merges,
46 | bpe_add_prefix_space=bpe_add_prefix_space,
47 | **kwargs,
48 | )
49 | return RobertaHubInterface(x["args"], x["task"], x["models"][0])
50 |
--------------------------------------------------------------------------------
/fairseq/models/roberta/model_xlmr.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | """
6 | Unsupervised Cross-lingual Representation Learning at Scale
7 | """
8 |
9 | from fairseq.models import register_model
10 |
11 | from .hub_interface import RobertaHubInterface
12 | from .model import RobertaModel
13 |
14 |
15 | @register_model("xlmr")
16 | class XLMRModel(RobertaModel):
17 | @classmethod
18 | def hub_models(cls):
19 | return {
20 | "xlmr.base": "http://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.tar.gz",
21 | "xlmr.large": "http://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.tar.gz",
22 | "xlmr.xl": "http://dl.fbaipublicfiles.com/fairseq/models/xlmr/xlmr.xl.tar.gz",
23 | "xlmr.xxl": "http://dl.fbaipublicfiles.com/fairseq/models/xlmr/xlmr.xxl.tar.gz",
24 | }
25 |
26 | @classmethod
27 | def from_pretrained(
28 | cls,
29 | model_name_or_path,
30 | checkpoint_file="model.pt",
31 | data_name_or_path=".",
32 | bpe="sentencepiece",
33 | **kwargs
34 | ):
35 | from fairseq import hub_utils
36 |
37 | x = hub_utils.from_pretrained(
38 | model_name_or_path,
39 | checkpoint_file,
40 | data_name_or_path,
41 | archive_map=cls.hub_models(),
42 | bpe=bpe,
43 | load_checkpoint_heads=True,
44 | **kwargs,
45 | )
46 | return RobertaHubInterface(x["args"], x["task"], x["models"][0])
47 |
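A minimal usage sketch for the hub interface returned by `from_pretrained`; the name `xlmr.base` is resolved through `hub_models()` and fetched if needed, and `encode`/`extract_features` are standard `RobertaHubInterface` methods:

    import torch
    from fairseq.models.roberta import XLMRModel

    xlmr = XLMRModel.from_pretrained("xlmr.base", checkpoint_file="model.pt")
    xlmr.eval()  # disable dropout
    tokens = xlmr.encode("Hello world!")       # sentencepiece -> tensor of token ids
    with torch.no_grad():
        feats = xlmr.extract_features(tokens)  # (1, seq_len, hidden_dim)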
--------------------------------------------------------------------------------
/fairseq/models/speech_to_speech/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .modules import * # noqa
7 | from .s2s_transformer import * # noqa
8 | from .s2s_conformer import * # noqa
9 |
--------------------------------------------------------------------------------
/fairseq/models/speech_to_text/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .berard import * # noqa
7 | from .convtransformer import * # noqa
8 | from .multi_modality_model import * # noqa
9 | from .s2t_transformer import * # noqa
10 | from .s2t_wav_transformer import * # noqa
11 | from .xm_transformer import * # noqa
12 | from .s2t_conformer import * # noqa
13 |
--------------------------------------------------------------------------------
/fairseq/models/text_to_speech/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .tacotron2 import * # noqa
7 | from .tts_transformer import * # noqa
8 | from .fastspeech2 import * # noqa
9 |
--------------------------------------------------------------------------------
/fairseq/models/transformer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | """isort:skip_file"""
6 |
7 | from .transformer_config import (
8 | TransformerConfig,
9 | DEFAULT_MAX_SOURCE_POSITIONS,
10 | DEFAULT_MAX_TARGET_POSITIONS,
11 | DEFAULT_MIN_PARAMS_TO_WRAP,
12 | )
13 | from .transformer_decoder import TransformerDecoder, TransformerDecoderBase, Linear
14 | from .transformer_encoder import TransformerEncoder, TransformerEncoderBase
15 | from .transformer_legacy import (
16 | TransformerModel,
17 | base_architecture,
18 | tiny_architecture,
19 | transformer_iwslt_de_en,
20 | transformer_wmt_en_de,
21 | transformer_vaswani_wmt_en_de_big,
22 | transformer_vaswani_wmt_en_fr_big,
23 | transformer_wmt_en_de_big,
24 | transformer_wmt_en_de_big_t2t,
25 | )
26 | from .transformer_base import TransformerModelBase, Embedding
27 |
28 |
29 | __all__ = [
30 | "TransformerModelBase",
31 | "TransformerConfig",
32 | "TransformerDecoder",
33 | "TransformerDecoderBase",
34 | "TransformerEncoder",
35 | "TransformerEncoderBase",
36 | "TransformerModel",
37 | "Embedding",
38 | "Linear",
39 | "base_architecture",
40 | "tiny_architecture",
41 | "transformer_iwslt_de_en",
42 | "transformer_wmt_en_de",
43 | "transformer_vaswani_wmt_en_de_big",
44 | "transformer_vaswani_wmt_en_fr_big",
45 | "transformer_wmt_en_de_big",
46 | "transformer_wmt_en_de_big_t2t",
47 | "DEFAULT_MAX_SOURCE_POSITIONS",
48 | "DEFAULT_MAX_TARGET_POSITIONS",
49 | "DEFAULT_MIN_PARAMS_TO_WRAP",
50 | ]
51 |
--------------------------------------------------------------------------------
/fairseq/models/wav2vec/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .wav2vec import * # noqa
7 | from .wav2vec2 import * # noqa
8 | from .wav2vec2_asr import * # noqa
9 |
--------------------------------------------------------------------------------
/fairseq/models/wav2vec/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import math
7 | import torch.nn.functional as F
8 |
9 |
10 | def pad_to_multiple(x, multiple, dim=-1, value=0):
11 | # Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41
12 | if x is None:
13 | return None, 0
14 | tsz = x.size(dim)
15 | m = tsz / multiple
16 | remainder = math.ceil(m) * multiple - tsz
17 | if m.is_integer():
18 | return x, 0
19 | pad_offset = (0,) * (-1 - dim) * 2
20 |
21 | return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder
22 |
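A quick example of `pad_to_multiple`; note that the `pad_offset` arithmetic assumes `dim` is given as a negative index:

    import torch

    x = torch.randn(2, 7)                              # last dim not a multiple of 4
    padded, remainder = pad_to_multiple(x, 4, dim=-1)
    print(padded.shape, remainder)                     # torch.Size([2, 8]) 1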
--------------------------------------------------------------------------------
/fairseq/models/xmod/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .model import * # noqa
7 | from .transformer_layer_xmod import * # noqa
8 |
--------------------------------------------------------------------------------
/fairseq/modules/dynamicconv_layer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .dynamicconv_layer import DynamicconvLayer # noqa
7 |
--------------------------------------------------------------------------------
/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cpp:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Facebook, Inc. and its affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | #include <torch/extension.h>
9 | #include <vector>
10 |
11 | std::vector<at::Tensor>
12 | dynamicconv_cuda_forward(at::Tensor input, at::Tensor filters, int padding_l);
13 |
14 | std::vector<at::Tensor> dynamicconv_cuda_backward(
15 | at::Tensor gradOutput,
16 | int padding_l,
17 | at::Tensor input,
18 | at::Tensor filters);
19 |
20 | #define CHECK_CUDA(x) \
21 | AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
22 | #define CHECK_CONTIGUOUS(x) \
23 | AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
24 | #define CHECK_INPUT(x) \
25 | CHECK_CUDA(x); \
26 | CHECK_CONTIGUOUS(x)
27 |
28 | std::vector<at::Tensor>
29 | dynamicconv_forward(at::Tensor input, at::Tensor filters, int padding_l) {
30 | CHECK_INPUT(input);
31 | CHECK_INPUT(filters);
32 |
33 | return dynamicconv_cuda_forward(input, filters, padding_l);
34 | }
35 |
36 | std::vector<at::Tensor> dynamicconv_backward(
37 | at::Tensor gradOutput,
38 | int padding_l,
39 | at::Tensor input,
40 | at::Tensor filters) {
41 | CHECK_INPUT(gradOutput);
42 | CHECK_INPUT(input);
43 | CHECK_INPUT(filters);
44 |
45 | return dynamicconv_cuda_backward(gradOutput, padding_l, input, filters);
46 | }
47 |
48 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
49 | m.def("forward", &dynamicconv_forward, "dynamicconv forward (CUDA)");
50 | m.def("backward", &dynamicconv_backward, "dynamicconv backward (CUDA)");
51 | }
52 |
--------------------------------------------------------------------------------
/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cuh:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Facebook, Inc. and its affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | #include <ATen/ATen.h>
9 | #include <c10/cuda/CUDAStream.h>
10 |
11 | #include <cuda.h>
12 | #include <cuda_fp16.h>
13 | #include <cuda_runtime.h>
14 |
15 | #include <algorithm>
16 | #include <functional>
17 | #include <iterator>
18 | #include <stdexcept>
19 | #include <utility>
20 | #include <vector>
21 |
22 | #include <assert.h>
23 | #include <math.h>
24 | #include <stdlib.h>
25 |
26 | #define SHFL_MASK 0xffffffff
27 |
28 | template <typename scalar_t>
29 | __global__ void dynamicconv_forward_kernel(
30 | const scalar_t* input,
31 | const scalar_t* weight,
32 | int minibatch,
33 | int sequenceLength,
34 | int numFeatures,
35 | int numFiltersInBlock,
36 | int numHeads,
37 | scalar_t* output);
38 |
39 | template <typename scalar_t>
40 | __global__ void dynamicconv_backward_kernel(
41 | const scalar_t* gradOutput, // B * C * T
42 | const scalar_t* input, // B * C * T
43 | const scalar_t* weight,
44 | int minibatch,
45 | int sequenceLength,
46 | int numFeatures,
47 | int numFiltersInBlock,
48 | int numHeads,
49 | scalar_t* gradWeight,
50 | scalar_t* gradInput); // B * H * k * T
51 |
--------------------------------------------------------------------------------
/fairseq/modules/dynamicconv_layer/dynamiconv_cpu.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <vector>
3 |
4 | std::vector<float*>
5 | dynamicconv_cpu_forward(float* input, float* filters, int padding_l);
6 |
7 | std::vector<float*> dynamicconv_cpu_backward(
8 | float* gradOutput,
9 | int padding_l,
10 | float* input,
11 | float* filters);
12 |
13 | std::vector<float*>
14 | dynamicconv_forward(float* input, float* filters, int padding_l) {
15 | return dynamicconv_cpu_forward(input, filters, padding_l);
16 | }
17 |
18 | std::vector<float*> dynamicconv_backward(
19 | float* gradOutput,
20 | int padding_l,
21 | float* input,
22 | float* filters) {
23 | return dynamicconv_cpu_backward(gradOutput, padding_l, input, filters);
24 | }
25 |
26 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
27 | m.def("forward", &dynamicconv_forward, "dynamicconv forward (CPU)");
28 | m.def("backward", &dynamicconv_backward, "dynamicconv backward (CPU)");
29 | }
30 |
--------------------------------------------------------------------------------
/fairseq/modules/dynamicconv_layer/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from setuptools import setup
8 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
9 |
10 |
11 | setup(
12 | name="dynamicconv_layer",
13 | ext_modules=[
14 | CUDAExtension(
15 | name="dynamicconv_cuda",
16 | sources=[
17 | "dynamicconv_cuda.cpp",
18 | "dynamicconv_cuda_kernel.cu",
19 | ],
20 | ),
21 | ],
22 | cmdclass={"build_ext": BuildExtension},
23 | )
24 |
--------------------------------------------------------------------------------
/fairseq/modules/fp32_batch_norm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | """
6 | batch norm done in fp32 (for fp16 training)
7 | """
8 | import torch
9 | import torch.nn as nn
10 |
11 |
12 | class Fp32BatchNorm(nn.Module):
13 | def __init__(self, sync=False, *args, **kwargs):
14 | super().__init__()
15 |
16 | if sync:
17 | from fairseq.distributed import utils
18 |
19 | if utils.get_global_world_size() == 1:
20 | sync = False
21 |
22 | if sync:
23 | self.bn = nn.SyncBatchNorm(*args, **kwargs)
24 | else:
25 | self.bn = nn.BatchNorm1d(*args, **kwargs)
26 |
27 | self.sync = sync
28 |
29 | def forward(self, input):
30 | if self.bn.running_mean.dtype != torch.float:
31 | if self.sync:
32 | self.bn.running_mean = self.bn.running_mean.float()
33 | self.bn.running_var = self.bn.running_var.float()
34 | if self.bn.affine:
35 | try:
36 | self.bn.weight = self.bn.weight.float()
37 | self.bn.bias = self.bn.bias.float()
38 | except Exception:
39 | self.bn.float()
40 | else:
41 | self.bn.float()
42 |
43 | output = self.bn(input.float())
44 | return output.type_as(input)
45 |
--------------------------------------------------------------------------------
/fairseq/modules/fp32_group_norm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | """
6 | Group norm done in fp32 (for fp16 training)
7 | """
8 |
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 |
13 | class Fp32GroupNorm(nn.GroupNorm):
14 | def __init__(self, *args, **kwargs):
15 | super().__init__(*args, **kwargs)
16 |
17 | def forward(self, input):
18 | output = F.group_norm(
19 | input.float(),
20 | self.num_groups,
21 | self.weight.float() if self.weight is not None else None,
22 | self.bias.float() if self.bias is not None else None,
23 | self.eps,
24 | )
25 | return output.type_as(input)
26 |
--------------------------------------------------------------------------------
/fairseq/modules/fp32_instance_norm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | """
6 | Instance norm done in fp32 (for fp16 training)
7 | """
8 |
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 |
13 | class Fp32InstanceNorm(nn.InstanceNorm1d):
14 | def __init__(self, *args, **kwargs):
15 | self.transpose_last = "transpose_last" in kwargs and kwargs["transpose_last"]
16 | if "transpose_last" in kwargs:
17 | del kwargs["transpose_last"]
18 | super().__init__(*args, **kwargs)
19 |
20 | def forward(self, input):
21 | if self.transpose_last:
22 | input = input.transpose(1, 2)
23 | output = F.instance_norm(
24 | input.float(),
25 | running_mean=self.running_mean,
26 | running_var=self.running_var,
27 | weight=self.weight.float() if self.weight is not None else None,
28 | bias=self.bias.float() if self.bias is not None else None,
29 | use_input_stats=self.training or not self.track_running_stats,
30 | momentum=self.momentum,
31 | eps=self.eps,
32 | )
33 | if self.transpose_last:
34 | output = output.transpose(1, 2)
35 | return output.type_as(input)
36 |
--------------------------------------------------------------------------------
/fairseq/modules/gelu.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | """
6 | See "Gaussian Error Linear Units (GELUs)" by Dan Hendrycks and Kevin Gimpel with
7 | the corresponding GitHub repo: https://github.com/hendrycks/GELUs
8 | """
9 |
10 | import math
11 |
12 | import torch
13 | import torch.nn as nn
14 |
15 |
16 | def gelu_accurate(x):
17 | if not hasattr(gelu_accurate, "_a"):
18 | gelu_accurate._a = math.sqrt(2 / math.pi)
19 | return (
20 | 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
21 | )
22 |
23 |
24 | def gelu(x: torch.Tensor) -> torch.Tensor:
25 | return torch.nn.functional.gelu(x.float()).type_as(x)
26 |
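The two variants are numerically close: `gelu` is the exact erf-based form evaluated in fp32, while `gelu_accurate` is the tanh approximation from the paper. A quick check:

    import torch

    x = torch.randn(1000)
    diff = (gelu(x) - gelu_accurate(x)).abs().max()
    print(float(diff))  # small, on the order of 1e-3 or below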
--------------------------------------------------------------------------------
/fairseq/modules/grad_multiply.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 |
8 |
9 | class GradMultiply(torch.autograd.Function):
10 | @staticmethod
11 | def forward(ctx, x, scale):
12 | ctx.scale = scale
13 | res = x.new(x)
14 | return res
15 |
16 | @staticmethod
17 | def backward(ctx, grad):
18 | return grad * ctx.scale, None
19 |
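`GradMultiply` is an identity in the forward pass but scales gradients on the way back, e.g. to down-weight updates to a feature extractor (as wav2vec-style models do). A minimal check:

    import torch

    x = torch.ones(3, requires_grad=True)
    y = GradMultiply.apply(x, 0.1)   # forward: y == x
    y.sum().backward()
    print(x.grad)                    # tensor([0.1000, 0.1000, 0.1000])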
--------------------------------------------------------------------------------
/fairseq/modules/layer_drop.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | """
6 | LayerDrop as described in https://arxiv.org/abs/1909.11556.
7 | """
8 |
9 | import torch
10 | import torch.nn as nn
11 |
12 |
13 | class LayerDropModuleList(nn.ModuleList):
14 | """
15 | A LayerDrop implementation based on :class:`torch.nn.ModuleList`.
16 |
17 | We refresh the choice of which layers to drop every time we iterate
18 | over the LayerDropModuleList instance. During evaluation we always
19 | iterate over all layers.
20 |
21 | Usage::
22 |
23 | layers = LayerDropModuleList(p=0.5, modules=[layer1, layer2, layer3])
24 | for layer in layers: # this might iterate over layers 1 and 3
25 | x = layer(x)
26 | for layer in layers: # this might iterate over all layers
27 | x = layer(x)
28 | for layer in layers: # this might not iterate over any layers
29 | x = layer(x)
30 |
31 | Args:
32 | p (float): probability of dropping out each layer
33 | modules (iterable, optional): an iterable of modules to add
34 | """
35 |
36 | def __init__(self, p, modules=None):
37 | super().__init__(modules)
38 | self.p = p
39 |
40 | def __iter__(self):
41 | dropout_probs = torch.empty(len(self)).uniform_()
42 | for i, m in enumerate(super().__iter__()):
43 | if not self.training or (dropout_probs[i] > self.p):
44 | yield m
45 |
--------------------------------------------------------------------------------
/fairseq/modules/layer_norm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | try:
11 | from apex.normalization import FusedLayerNorm as _FusedLayerNorm
12 |
13 | has_fused_layernorm = True
14 |
15 | class FusedLayerNorm(_FusedLayerNorm):
16 | @torch.jit.unused
17 | def forward(self, x):
18 | if not x.is_cuda:
19 | return super().forward(x)
20 | else:
21 | with torch.cuda.device(x.device):
22 | return super().forward(x)
23 |
24 | except ImportError:
25 | has_fused_layernorm = False
26 |
27 |
28 | def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
29 | if torch.jit.is_scripting() or torch.jit.is_tracing():
30 | export = True
31 | if not export and torch.cuda.is_available() and has_fused_layernorm:
32 | return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
33 | return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
34 |
35 |
36 | class Fp32LayerNorm(nn.LayerNorm):
37 | def __init__(self, *args, **kwargs):
38 | super().__init__(*args, **kwargs)
39 |
40 | def forward(self, input):
41 | output = F.layer_norm(
42 | input.float(),
43 | self.normalized_shape,
44 | self.weight.float() if self.weight is not None else None,
45 | self.bias.float() if self.bias is not None else None,
46 | self.eps,
47 | )
48 | return output.type_as(input)
49 |
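The `LayerNorm` factory quietly prefers apex's fused kernel when CUDA is available, and `Fp32LayerNorm` keeps the statistics in fp32 when activations are fp16. A short sketch:

    import torch

    ln = LayerNorm(512)          # FusedLayerNorm if apex + CUDA available, else nn.LayerNorm
    x = torch.randn(4, 512, dtype=torch.float16)
    out = Fp32LayerNorm(512)(x)  # normalized in fp32, cast back to the input dtype
    print(out.dtype)             # torch.float16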
--------------------------------------------------------------------------------
/fairseq/modules/lightconv_layer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .lightconv_layer import LightconvLayer # noqa
7 |
--------------------------------------------------------------------------------
/fairseq/modules/lightconv_layer/lightconv_cuda.cpp:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Facebook, Inc. and its affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | #include <torch/extension.h>
9 | #include <vector>
10 |
11 | std::vector<at::Tensor>
12 | lightconv_cuda_forward(at::Tensor input, at::Tensor filters, int padding_l);
13 |
14 | std::vector<at::Tensor> lightconv_cuda_backward(
15 | at::Tensor gradOutput,
16 | int padding_l,
17 | at::Tensor input,
18 | at::Tensor filters);
19 |
20 | #define CHECK_CUDA(x) \
21 | AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
22 | #define CHECK_CONTIGUOUS(x) \
23 | AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
24 | #define CHECK_INPUT(x) \
25 | CHECK_CUDA(x); \
26 | CHECK_CONTIGUOUS(x)
27 |
28 | std::vector<at::Tensor>
29 | lightconv_forward(at::Tensor input, at::Tensor filters, int padding_l) {
30 | CHECK_INPUT(input);
31 | CHECK_INPUT(filters);
32 |
33 | return lightconv_cuda_forward(input, filters, padding_l);
34 | }
35 |
36 | std::vector<at::Tensor> lightconv_backward(
37 | at::Tensor gradOutput,
38 | int padding_l,
39 | at::Tensor input,
40 | at::Tensor filters) {
41 | CHECK_INPUT(gradOutput);
42 | CHECK_INPUT(input);
43 | CHECK_INPUT(filters);
44 |
45 | return lightconv_cuda_backward(gradOutput, padding_l, input, filters);
46 | }
47 |
48 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
49 | m.def("forward", &lightconv_forward, "lighconv forward (CUDA)");
50 | m.def("backward", &lightconv_backward, "lighconv backward (CUDA)");
51 | }
52 |
--------------------------------------------------------------------------------
/fairseq/modules/lightconv_layer/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from setuptools import setup
8 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
9 |
10 |
11 | setup(
12 | name="lightconv_layer",
13 | ext_modules=[
14 | CUDAExtension(
15 | "lightconv_cuda",
16 | [
17 | "lightconv_cuda.cpp",
18 | "lightconv_cuda_kernel.cu",
19 | ],
20 | ),
21 | ],
22 | cmdclass={"build_ext": BuildExtension},
23 | )
24 |
--------------------------------------------------------------------------------
/fairseq/modules/lstm_cell_with_zoneout.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch.nn as nn
7 |
8 |
9 | class LSTMCellWithZoneOut(nn.Module):
10 | """
11 | Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations
12 | https://arxiv.org/abs/1606.01305
13 | """
14 |
15 | def __init__(
16 | self, prob: float, input_size: int, hidden_size: int, bias: bool = True
17 | ):
18 | super(LSTMCellWithZoneOut, self).__init__()
19 | self.lstm_cell = nn.LSTMCell(input_size, hidden_size, bias=bias)
20 | self.prob = prob
21 | if prob > 1.0 or prob < 0.0:
22 | raise ValueError(
23 | "zoneout probability must be in the range from " "0.0 to 1.0."
24 | )
25 |
26 | def zoneout(self, h, next_h, prob):
27 | if isinstance(h, tuple):
28 | return tuple([self.zoneout(h[i], next_h[i], prob) for i in range(len(h))])
29 |
30 | if self.training:
31 | mask = h.new_zeros(*h.size()).bernoulli_(prob)
32 | return mask * h + (1 - mask) * next_h
33 |
34 | return prob * h + (1 - prob) * next_h
35 |
36 | def forward(self, x, h):
37 | return self.zoneout(h, self.lstm_cell(x, h), self.prob)
38 |
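Usage sketch: the cell is a drop-in for `nn.LSTMCell`. During training each hidden/cell unit is kept from the previous step with probability `prob`; at eval time old and new states are blended deterministically:

    import torch

    cell = LSTMCellWithZoneOut(prob=0.1, input_size=16, hidden_size=32)
    x = torch.randn(4, 16)                           # (batch, input_size)
    state = (torch.zeros(4, 32), torch.zeros(4, 32)) # (h_0, c_0)
    state = cell(x, state)                           # returns the zoned-out (h, c)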
--------------------------------------------------------------------------------
/fairseq/modules/positional_embedding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch.nn as nn
7 |
8 | from .learned_positional_embedding import LearnedPositionalEmbedding
9 | from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
10 |
11 |
12 | def PositionalEmbedding(
13 | num_embeddings: int,
14 | embedding_dim: int,
15 | padding_idx: int,
16 | learned: bool = False,
17 | ):
18 | if learned:
19 | # if padding_idx is specified then offset the embedding ids by
20 | # this index and adjust num_embeddings appropriately
21 | # TODO: The right place for this offset would be inside
22 | # LearnedPositionalEmbedding. Move this there for a cleaner implementation.
23 | if padding_idx is not None:
24 | num_embeddings = num_embeddings + padding_idx + 1
25 | m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)
26 | nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
27 | if padding_idx is not None:
28 | nn.init.constant_(m.weight[padding_idx], 0)
29 | else:
30 | m = SinusoidalPositionalEmbedding(
31 | embedding_dim,
32 | padding_idx,
33 | init_size=num_embeddings + padding_idx + 1,
34 | )
35 | return m
36 |
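A short example of the factory; with the default `learned=False` it builds sinusoidal embeddings sized for `num_embeddings` positions, and positions are derived from the non-pad tokens:

    import torch

    emb = PositionalEmbedding(num_embeddings=1024, embedding_dim=512, padding_idx=1)
    tokens = torch.tensor([[5, 6, 7, 1]])  # 1 is the pad index
    pos = emb(tokens)                      # (1, 4, 512)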
--------------------------------------------------------------------------------
/fairseq/modules/quantization/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Speech-Lab-IITM/data2vec-aqc/c80cc996669e12c9c32201a9ff1515300e64b13a/fairseq/modules/quantization/__init__.py
--------------------------------------------------------------------------------
/fairseq/modules/quantization/pq/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .utils import SizeTracker, get_param, attrsetter, quantize_model_ # NOQA
7 |
--------------------------------------------------------------------------------
/fairseq/modules/quantization/pq/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .qconv import PQConv2d # NOQA
7 | from .qemb import PQEmbedding # NOQA
8 | from .qlinear import PQLinear # NOQA
9 |
--------------------------------------------------------------------------------
/fairseq/modules/quantization/scalar/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .utils import quantize_model_ # NOQA
7 |
--------------------------------------------------------------------------------
/fairseq/modules/quantization/scalar/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from .qact import ActivationQuantizer # NOQA
7 | from .qconv import IntConv2d # NOQA
8 | from .qemb import IntEmbedding # NOQA
9 | from .qlinear import IntLinear # NOQA
10 |
--------------------------------------------------------------------------------
/fairseq/modules/same_pad.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 |
7 | from torch import nn
8 |
9 |
10 | class SamePad(nn.Module):
11 | def __init__(self, kernel_size, causal=False):
12 | super().__init__()
13 | if causal:
14 | self.remove = kernel_size - 1
15 | else:
16 | self.remove = 1 if kernel_size % 2 == 0 else 0
17 |
18 | def forward(self, x):
19 | if self.remove > 0:
20 | x = x[:, :, : -self.remove]
21 | return x
22 |
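`SamePad` trims the surplus frames that symmetric padding introduces so a convolution preserves the input length: one trailing step for even kernels, `kernel_size - 1` in the causal case. Example:

    import torch
    import torch.nn as nn

    conv = nn.Sequential(
        nn.Conv1d(8, 8, kernel_size=4, padding=2),  # even kernel -> 1 extra output step
        SamePad(kernel_size=4),
    )
    x = torch.randn(2, 8, 50)
    print(conv(x).shape)  # torch.Size([2, 8, 50])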
--------------------------------------------------------------------------------
/fairseq/modules/scalar_bias.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | #
6 |
7 | import torch
8 |
9 |
10 | class ScalarBias(torch.autograd.Function):
11 | """
12 | Adds a vector of scalars, used in the self-attention mechanism to allow
13 | the model to optionally attend to this vector instead of the past
14 | """
15 |
16 | @staticmethod
17 | def forward(ctx, input, dim, bias_init):
18 | size = list(input.size())
19 | size[dim] += 1
20 | output = input.new(*size).fill_(bias_init)
21 | output.narrow(dim, 1, size[dim] - 1).copy_(input)
22 | ctx.dim = dim
23 | return output
24 |
25 | @staticmethod
26 | def backward(ctx, grad):
27 | return grad.narrow(ctx.dim, 1, grad.size(ctx.dim) - 1), None, None
28 |
29 |
30 | def scalar_bias(input, dim, bias_init=0):
31 | return ScalarBias.apply(input, dim, bias_init)
32 |
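Example: `scalar_bias` prepends one constant slot along a dimension, so attention can choose to attend to it instead of the past; the gradient for that slot is dropped in backward:

    import torch

    scores = torch.randn(2, 4)
    out = scalar_bias(scores, dim=1)  # prepend one slot filled with bias_init
    print(out.shape, out[:, 0])       # torch.Size([2, 5]) tensor([0., 0.])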
--------------------------------------------------------------------------------
/fairseq/modules/transpose_last.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | """
6 | transpose last 2 dimensions of the input
7 | """
8 |
9 | import torch.nn as nn
10 |
11 |
12 | class TransposeLast(nn.Module):
13 | def __init__(self, deconstruct_idx=None):
14 | super().__init__()
15 | self.deconstruct_idx = deconstruct_idx
16 |
17 | def forward(self, x):
18 | if self.deconstruct_idx is not None:
19 | x = x[self.deconstruct_idx]
20 | return x.transpose(-2, -1)
21 |
--------------------------------------------------------------------------------
/fairseq/modules/unfold.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch.nn.functional as F
7 |
8 |
9 | def unfold1d(x, kernel_size, padding_l, pad_value=0):
10 | """unfold T x B x C to T x B x C x K"""
11 | if kernel_size > 1:
12 | T, B, C = x.size()
13 | x = F.pad(
14 | x, (0, 0, 0, 0, padding_l, kernel_size - 1 - padding_l), value=pad_value
15 | )
16 | x = x.as_strided((T, B, C, kernel_size), (B * C, C, 1, B * C))
17 | else:
18 | x = x.unsqueeze(3)
19 | return x
20 |
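Example: `unfold1d` turns a `T x B x C` sequence into per-timestep windows of size `kernel_size` without copying, via `as_strided` over the padded tensor:

    import torch

    x = torch.randn(10, 2, 8)                      # T x B x C
    win = unfold1d(x, kernel_size=3, padding_l=1)
    print(win.shape)                               # torch.Size([10, 2, 8, 3])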
--------------------------------------------------------------------------------
/fairseq/optim/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | """isort:skip_file"""
6 |
7 | import importlib
8 | import os
9 |
10 | from fairseq import registry
11 | from fairseq.optim.bmuf import FairseqBMUF # noqa
12 | from fairseq.optim.fairseq_optimizer import ( # noqa
13 | FairseqOptimizer,
14 | LegacyFairseqOptimizer,
15 | )
16 | from fairseq.optim.amp_optimizer import AMPOptimizer
17 | from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer
18 | from fairseq.optim.shard import shard_
19 | from omegaconf import DictConfig
20 |
21 | __all__ = [
22 | "AMPOptimizer",
23 | "FairseqOptimizer",
24 | "FP16Optimizer",
25 | "MemoryEfficientFP16Optimizer",
26 | "shard_",
27 | ]
28 |
29 | (
30 | _build_optimizer,
31 | register_optimizer,
32 | OPTIMIZER_REGISTRY,
33 | OPTIMIZER_DATACLASS_REGISTRY,
34 | ) = registry.setup_registry("--optimizer", base_class=FairseqOptimizer, required=True)
35 |
36 |
37 | def build_optimizer(cfg: DictConfig, params, *extra_args, **extra_kwargs):
38 | if all(isinstance(p, dict) for p in params):
39 | params = [t for p in params for t in p.values()]
40 | params = list(filter(lambda p: p.requires_grad, params))
41 | return _build_optimizer(cfg, params, *extra_args, **extra_kwargs)
42 |
43 |
44 | # automatically import any Python files in the optim/ directory
45 | for file in sorted(os.listdir(os.path.dirname(__file__))):
46 | if file.endswith(".py") and not file.startswith("_"):
47 | file_name = file[: file.find(".py")]
48 | importlib.import_module("fairseq.optim." + file_name)
49 |
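Callers go through `build_optimizer`, which also accepts parameters grouped as dicts and filters out frozen tensors. A hedged sketch (assuming a Hydra config whose `optimizer` section names a registered class such as "adam" or "sgd"):

    from fairseq.optim import build_optimizer

    # Materialize the generator first: build_optimizer iterates params twice.
    params = list(model.parameters())
    optimizer = build_optimizer(cfg.optimizer, params)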
--------------------------------------------------------------------------------
/fairseq/optim/adagrad.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch.optim
7 |
8 | from . import LegacyFairseqOptimizer, register_optimizer
9 |
10 |
11 | @register_optimizer("adagrad")
12 | class Adagrad(LegacyFairseqOptimizer):
13 | def __init__(self, args, params):
14 | super().__init__(args)
15 | self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config)
16 |
17 | @staticmethod
18 | def add_args(parser):
19 | """Add optimizer-specific arguments to the parser."""
20 | # fmt: off
21 | parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
22 | help='weight decay')
23 | # fmt: on
24 |
25 | @property
26 | def optimizer_config(self):
27 | """
28 | Return a kwarg dictionary that will be used to override optimizer
29 | args stored in checkpoints. This allows us to load a checkpoint and
30 | resume training using a different set of optimizer args, e.g., with a
31 | different learning rate.
32 | """
33 | return {
34 | "lr": self.args.lr[0],
35 | "weight_decay": self.args.weight_decay,
36 | }
37 |
38 | @property
39 | def supports_flat_params(self):
40 | return False
41 |
--------------------------------------------------------------------------------
/fairseq/optim/lr_scheduler/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | """isort:skip_file"""
6 |
7 | import importlib
8 | import os
9 |
10 | from fairseq import registry
11 | from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import ( # noqa
12 | FairseqLRScheduler,
13 | LegacyFairseqLRScheduler,
14 | )
15 | from omegaconf import DictConfig
16 |
17 |
18 | (
19 | build_lr_scheduler_,
20 | register_lr_scheduler,
21 | LR_SCHEDULER_REGISTRY,
22 | LR_SCHEDULER_DATACLASS_REGISTRY,
23 | ) = registry.setup_registry(
24 | "--lr-scheduler", base_class=FairseqLRScheduler, default="fixed"
25 | )
26 |
27 |
28 | def build_lr_scheduler(cfg: DictConfig, optimizer):
29 | return build_lr_scheduler_(cfg, optimizer)
30 |
31 |
32 | # automatically import any Python files in the optim/lr_scheduler/ directory
33 | for file in sorted(os.listdir(os.path.dirname(__file__))):
34 | if file.endswith(".py") and not file.startswith("_"):
35 | file_name = file[: file.find(".py")]
36 | importlib.import_module("fairseq.optim.lr_scheduler." + file_name)
37 |
--------------------------------------------------------------------------------
/fairseq/optim/lr_scheduler/pass_through.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from dataclasses import dataclass
7 |
8 | from fairseq.dataclass import FairseqDataclass
9 | from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler
10 |
11 |
12 | @dataclass
13 | class PassThroughScheduleConfig(FairseqDataclass):
14 | pass
15 |
16 |
17 | @register_lr_scheduler("pass_through", dataclass=PassThroughScheduleConfig)
18 | class PassThroughScheduleSchedule(FairseqLRScheduler):
19 | """Delegate lr scheduling to the optimizer."""
20 |
21 | def __init__(self, cfg: PassThroughScheduleConfig, optimizer):
22 | super().__init__(cfg, optimizer)
23 | assert (
24 | hasattr(optimizer, "lr_scheduler") and optimizer.lr_scheduler is not None
25 | ), "Pass-through schedule can only be used with optimizers with their own schedulers"
26 |
27 | def state_dict(self):
28 | return self.optimizer.lr_scheduler.state_dict()
29 |
30 | def load_state_dict(self, state_dict):
31 | self.optimizer.lr_scheduler.load_state_dict(state_dict)
32 |
33 | def step_begin_epoch(self, epoch):
34 | """Update the learning rate at the beginning of the given epoch."""
35 | return self.optimizer.lr_scheduler.step_begin_epoch(epoch)
36 |
37 | def step_update(self, num_updates):
38 | """Update the learning rate after each update."""
39 | return self.optimizer.lr_scheduler.step_update(num_updates)
40 |
--------------------------------------------------------------------------------
/fairseq/optim/sgd.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch.optim
7 |
8 | from . import LegacyFairseqOptimizer, register_optimizer
9 |
10 |
11 | @register_optimizer("sgd")
12 | class SGD(LegacyFairseqOptimizer):
13 | def __init__(self, args, params):
14 | super().__init__(args)
15 | self._optimizer = torch.optim.SGD(params, **self.optimizer_config)
16 |
17 | @staticmethod
18 | def add_args(parser):
19 | """Add optimizer-specific arguments to the parser."""
20 | # fmt: off
21 | parser.add_argument('--momentum', default=0.0, type=float, metavar='M',
22 | help='momentum factor')
23 | parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
24 | help='weight decay')
25 | # fmt: on
26 |
27 | @property
28 | def optimizer_config(self):
29 | """
30 | Return a kwarg dictionary that will be used to override optimizer
31 | args stored in checkpoints. This allows us to load a checkpoint and
32 | resume training using a different set of optimizer args, e.g., with a
33 | different learning rate.
34 | """
35 | return {
36 | "lr": self.args.lr[0],
37 | "momentum": self.args.momentum,
38 | "weight_decay": self.args.weight_decay,
39 | }
40 |
41 | @property
42 | def supports_flat_params(self):
43 | return True
44 |
--------------------------------------------------------------------------------
/fairseq/pdb.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import multiprocessing
7 | import os
8 | import pdb
9 | import sys
10 |
11 |
12 | __all__ = ["set_trace"]
13 |
14 |
15 | _stdin = [None]
16 | _stdin_lock = multiprocessing.Lock()
17 | try:
18 | _stdin_fd = sys.stdin.fileno()
19 | except Exception:
20 | _stdin_fd = None
21 |
22 |
23 | class MultiprocessingPdb(pdb.Pdb):
24 | """A Pdb wrapper that works in a multiprocessing environment.
25 |
26 | Usage: `from fairseq import pdb; pdb.set_trace()`
27 | """
28 |
29 | def __init__(self):
30 | pdb.Pdb.__init__(self, nosigint=True)
31 |
32 | def _cmdloop(self):
33 | stdin_bak = sys.stdin
34 | with _stdin_lock:
35 | try:
36 | if _stdin_fd is not None:
37 | if not _stdin[0]:
38 | _stdin[0] = os.fdopen(_stdin_fd)
39 | sys.stdin = _stdin[0]
40 | self.cmdloop()
41 | finally:
42 | sys.stdin = stdin_bak
43 |
44 |
45 | def set_trace():
46 | pdb = MultiprocessingPdb()
47 | pdb.set_trace(sys._getframe().f_back)
48 |
--------------------------------------------------------------------------------
/fairseq/scoring/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 |
7 | import importlib
8 | import os
9 | from abc import ABC, abstractmethod
10 |
11 | from fairseq import registry
12 | from omegaconf import DictConfig
13 |
14 |
15 | class BaseScorer(ABC):
16 | def __init__(self, cfg):
17 | self.cfg = cfg
18 | self.ref = []
19 | self.pred = []
20 |
21 | def add_string(self, ref, pred):
22 | self.ref.append(ref)
23 | self.pred.append(pred)
24 |
25 | @abstractmethod
26 | def score(self) -> float:
27 | pass
28 |
29 | @abstractmethod
30 | def result_string(self) -> str:
31 | pass
32 |
33 |
34 | _build_scorer, register_scorer, SCORER_REGISTRY, _ = registry.setup_registry(
35 | "--scoring", default="bleu"
36 | )
37 |
38 |
39 | def build_scorer(choice, tgt_dict):
40 | _choice = choice._name if isinstance(choice, DictConfig) else choice
41 |
42 | if _choice == "bleu":
43 | from fairseq.scoring import bleu
44 |
45 | return bleu.Scorer(
46 | bleu.BleuConfig(pad=tgt_dict.pad(), eos=tgt_dict.eos(), unk=tgt_dict.unk())
47 | )
48 | return _build_scorer(choice)
49 |
50 |
51 | # automatically import any Python files in the current directory
52 | for file in sorted(os.listdir(os.path.dirname(__file__))):
53 | if file.endswith(".py") and not file.startswith("_"):
54 | module = file[: file.find(".py")]
55 | importlib.import_module("fairseq.scoring." + module)
56 |
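A hedged sketch of adding a scorer through the registry above, mirroring the pattern of the scorers that follow; the `exact_match` name and its config class are illustrative only:

    from dataclasses import dataclass

    from fairseq.dataclass import FairseqDataclass
    from fairseq.scoring import BaseScorer, register_scorer

    @dataclass
    class ExactMatchConfig(FairseqDataclass):  # hypothetical config
        pass

    @register_scorer("exact_match", dataclass=ExactMatchConfig)  # hypothetical name
    class ExactMatchScorer(BaseScorer):
        def score(self) -> float:
            hits = sum(r == p for r, p in zip(self.ref, self.pred))
            return hits / max(len(self.ref), 1)

        def result_string(self) -> str:
            return f"ExactMatch: {self.score():.4f}"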
--------------------------------------------------------------------------------
/fairseq/scoring/bertscore.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from dataclasses import dataclass, field
7 |
8 | import numpy as np
9 |
10 | from fairseq.dataclass import FairseqDataclass
11 | from fairseq.scoring import BaseScorer, register_scorer
12 |
13 |
14 | @dataclass
15 | class BertScoreScorerConfig(FairseqDataclass):
16 | bert_score_lang: str = field(default="en", metadata={"help": "BERTScore language"})
17 |
18 |
19 | @register_scorer("bert_score", dataclass=BertScoreScorerConfig)
20 | class BertScoreScorer(BaseScorer):
21 | def __init__(self, cfg):
22 | super(BertScoreScorer, self).__init__(cfg)
23 | try:
24 | import bert_score as _bert_score
25 | except ImportError:
26 | raise ImportError("Please install BERTScore: pip install bert-score")
27 |
28 | self.cfg = cfg
29 | self._bert_score = _bert_score
30 | self.scores = None
31 |
32 | def add_string(self, ref, pred):
33 | self.ref.append(ref)
34 | self.pred.append(pred)
35 |
36 | def score(self, order=4):
37 | _, _, self.scores = self._bert_score.score(
38 | self.pred, self.ref, lang=self.cfg.bert_score_lang
39 | )
40 | self.scores = self.scores.numpy()
41 | return np.mean(self.scores)
42 |
43 | def result_string(self, order=4):
44 | return f"BERTScore: {self.score():.4f}"
45 |
--------------------------------------------------------------------------------
/fairseq/scoring/chrf.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 |
7 | from dataclasses import dataclass
8 |
9 | from fairseq.dataclass import FairseqDataclass
10 | from fairseq.scoring import BaseScorer, register_scorer
11 |
12 |
13 | @dataclass
14 | class ChrFScorerConfig(FairseqDataclass):
15 | pass
16 |
17 |
18 | @register_scorer("chrf", dataclass=ChrFScorerConfig)
19 | class ChrFScorer(BaseScorer):
20 | def __init__(self, args):
21 | super(ChrFScorer, self).__init__(args)
22 | import sacrebleu
23 |
24 | self.sacrebleu = sacrebleu
25 |
26 | def add_string(self, ref, pred):
27 | self.ref.append(ref)
28 | self.pred.append(pred)
29 |
30 | def score(self, order=4):
31 | return self.sacrebleu.corpus_chrf(self.pred, [self.ref]).score
32 |
33 | def result_string(self, order=4):
34 | if order != 4:
35 | raise NotImplementedError
36 | return self.sacrebleu.corpus_chrf(self.pred, [self.ref]).format()
37 |
--------------------------------------------------------------------------------
/fairseq/scoring/meteor.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy as np
7 | from dataclasses import dataclass
8 |
9 | from fairseq.dataclass import FairseqDataclass
10 | from fairseq.scoring import BaseScorer, register_scorer
11 |
12 |
13 | @dataclass
14 | class MeteorScorerConfig(FairseqDataclass):
15 | pass
16 |
17 |
18 | @register_scorer("meteor", dataclass=MeteorScorerConfig)
19 | class MeteorScorer(BaseScorer):
20 | def __init__(self, args):
21 | super(MeteorScorer, self).__init__(args)
22 | try:
23 | import nltk
24 | except ImportError:
25 | raise ImportError("Please install nltk to use METEOR scorer")
26 |
27 | self.nltk = nltk
28 | self.scores = []
29 |
30 | def add_string(self, ref, pred):
31 | self.ref.append(ref)
32 | self.pred.append(pred)
33 |
34 | def score(self, order=4):
35 | self.scores = [
36 | self.nltk.translate.meteor_score.single_meteor_score(r, p)
37 | for r, p in zip(self.ref, self.pred)
38 | ]
39 | return np.mean(self.scores)
40 |
41 | def result_string(self, order=4):
42 | return f"METEOR: {self.score():.4f}"
43 |
--------------------------------------------------------------------------------
/fairseq/tasks/simultaneous_translation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import logging
7 | from fairseq.tasks import register_task
8 | from fairseq.tasks.speech_to_text import SpeechToTextTask
9 | from fairseq.tasks.translation import TranslationTask, TranslationConfig
10 |
11 | try:
12 | import examples.simultaneous_translation # noqa
13 |
14 | import_successful = True
15 | except BaseException:
16 | import_successful = False
17 |
18 |
19 | logger = logging.getLogger(__name__)
20 |
21 |
22 | def check_import(flag):
23 | if not flag:
24 | raise ImportError(
25 | "'examples.simultaneous_translation' is not correctly imported. "
26 | "Please considering `pip install -e $FAIRSEQ_DIR`."
27 | )
28 |
29 |
30 | @register_task("simul_speech_to_text")
31 | class SimulSpeechToTextTask(SpeechToTextTask):
32 | def __init__(self, args, tgt_dict):
33 | check_import(import_successful)
34 | super().__init__(args, tgt_dict)
35 |
36 |
37 | @register_task("simul_text_to_text", dataclass=TranslationConfig)
38 | class SimulTextToTextTask(TranslationTask):
39 | def __init__(self, cfg, src_dict, tgt_dict):
40 | check_import(import_successful)
41 | super().__init__(cfg, src_dict, tgt_dict)
42 |
--------------------------------------------------------------------------------
/fairseq/tasks/translation_from_pretrained_xlm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from dataclasses import dataclass
7 | from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary
8 | from fairseq.tasks.translation import TranslationConfig, TranslationTask
9 |
10 | from . import register_task
11 |
12 |
13 | @dataclass
14 | class TranslationFromPretrainedXLMConfig(TranslationConfig):
15 | pass
16 |
17 |
18 | @register_task(
19 | "translation_from_pretrained_xlm", dataclass=TranslationFromPretrainedXLMConfig
20 | )
21 | class TranslationFromPretrainedXLMTask(TranslationTask):
22 | """
23 | Same as TranslationTask, except it uses the MaskedLMDictionary class so that
24 | we can load data that was binarized with the MaskedLMDictionary class.
25 |
26 | This task should be used for the entire training pipeline when we want to
27 | train an NMT model from a pretrained XLM checkpoint: binarizing NMT data,
28 | training NMT with the pretrained XLM checkpoint, and subsequent evaluation
29 | of that trained model.
30 | """
31 |
32 | @classmethod
33 | def load_dictionary(cls, filename):
34 | """Load the masked LM dictionary from the filename
35 |
36 | Args:
37 | filename (str): the filename
38 | """
39 | return MaskedLMDictionary.load(filename)
40 |
--------------------------------------------------------------------------------
/fairseq/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import re
7 |
8 |
9 | SPACE_NORMALIZER = re.compile(r"\s+")
10 |
11 |
12 | def tokenize_line(line):
13 | line = SPACE_NORMALIZER.sub(" ", line)
14 | line = line.strip()
15 | return line.split()
16 |
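Example: the normalizer collapses every whitespace run to a single space before splitting:

    print(tokenize_line("  Hello \t world\n"))  # ['Hello', 'world']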
--------------------------------------------------------------------------------
/fairseq/version.txt:
--------------------------------------------------------------------------------
1 | 0.12.2
--------------------------------------------------------------------------------
/fairseq_cli/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Speech-Lab-IITM/data2vec-aqc/c80cc996669e12c9c32201a9ff1515300e64b13a/fairseq_cli/__init__.py
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools", "wheel", "cython"]
3 | build-backend = "setuptools.build_meta"
4 |
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Speech-Lab-IITM/data2vec-aqc/c80cc996669e12c9c32201a9ff1515300e64b13a/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/compare_namespaces.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Helper script to compare two argparse.Namespace objects."""
3 |
4 | from argparse import Namespace # noqa
5 |
6 |
7 | def main():
8 |
9 | ns1 = eval(input("Namespace 1: "))
10 | ns2 = eval(input("Namespace 2: "))
11 |
12 | def keys(ns):
13 | ks = set()
14 | for k in dir(ns):
15 | if not k.startswith("_"):
16 | ks.add(k)
17 | return ks
18 |
19 | k1 = keys(ns1)
20 | k2 = keys(ns2)
21 |
22 | def print_keys(ks, ns1, ns2=None):
23 | for k in ks:
24 | if ns2 is None:
25 | print("{}\t{}".format(k, getattr(ns1, k, None)))
26 | else:
27 | print(
28 | "{}\t{}\t{}".format(k, getattr(ns1, k, None), getattr(ns2, k, None))
29 | )
30 |
31 | print("Keys unique to namespace 1:")
32 | print_keys(k1 - k2, ns1)
33 | print()
34 |
35 | print("Keys unique to namespace 2:")
36 | print_keys(k2 - k1, ns2)
37 | print()
38 |
39 | print("Overlapping keys with different values:")
40 | ks = [k for k in k1 & k2 if getattr(ns1, k, "None") != getattr(ns2, k, "None")]
41 | print_keys(ks, ns1, ns2)
42 | print()
43 |
44 |
45 | if __name__ == "__main__":
46 | main()
47 |
--------------------------------------------------------------------------------
/scripts/compound_split_bleu.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ $# -ne 1 ]; then
4 | echo "usage: $0 GENERATE_PY_OUTPUT"
5 | exit 1
6 | fi
7 |
8 | GEN=$1
9 |
10 | SYS=$GEN.sys
11 | REF=$GEN.ref
12 |
13 | if [ $(tail -n 1 $GEN | grep BLEU | wc -l) -ne 1 ]; then
14 | echo "not done generating"
15 | exit
16 | fi
17 |
18 | grep ^H $GEN | awk -F '\t' '{print $NF}' | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $SYS
19 | grep ^T $GEN | cut -f2- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $REF
20 | fairseq-score --sys $SYS --ref $REF
21 |
--------------------------------------------------------------------------------
/scripts/constraints/validate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Reads in a fairseq output file, and verifies that the constraints
9 | (C- lines) are present in the output (the first H- line). Assumes that
10 | constraints are listed prior to the first hypothesis.
11 | """
12 |
13 | import sys
14 |
15 |
16 | constraints = []
17 | found = 0
18 | total = 0
19 | for line in sys.stdin:
20 | if line.startswith("C-"):
21 | constraints.append(line.rstrip().split("\t")[1])
22 | elif line.startswith("H-"):
23 | text = line.split("\t")[2]
24 |
25 | for constraint in constraints:
26 | total += 1
27 | if constraint in text:
28 | found += 1
29 | else:
30 | print(f"No {constraint} in {text}", file=sys.stderr)
31 |
32 | constraints = []
33 |
34 | print(f"Found {found} / {total} = {100 * found / max(total, 1):.1f}%")
35 |
--------------------------------------------------------------------------------
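The stdin format the script parses: C- lines carry a tab-separated constraint, and H- lines are hypothesis records of the form id, score, text separated by tabs. A runnable sketch of that parsing on fabricated lines (not real fairseq-generate output):

    import io

    # Fabricated generate-style output: one constraint, one hypothesis.
    fake_output = io.StringIO(
        "C-0\tgreen eggs\n"
        "H-0\t-0.31\ti do not like green eggs and ham\n"
    )
    for line in fake_output:
        if line.startswith("C-"):
            constraint = line.rstrip().split("\t")[1]
        elif line.startswith("H-"):
            text = line.split("\t")[2]
            print(constraint in text)  # True -> counted as "found"
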
/scripts/convert_dictionary.lua:
--------------------------------------------------------------------------------
1 | -- Copyright (c) Facebook, Inc. and its affiliates.
2 | --
3 | -- This source code is licensed under the MIT license found in the
4 | -- LICENSE file in the root directory of this source tree.
5 | --
6 | -- Usage: convert_dictionary.lua <dict.th7>
7 | require 'fairseq'
8 | require 'torch'
9 | require 'paths'
10 |
11 | if #arg < 1 then
12 | print('usage: convert_dictionary.lua <dict.th7>')
13 | os.exit(1)
14 | end
15 | if not paths.filep(arg[1]) then
16 | print('error: file does not exist: ' .. arg[1])
17 | os.exit(1)
18 | end
19 |
20 | dict = torch.load(arg[1])
21 | dst = paths.basename(arg[1]):gsub('.th7', '.txt')
22 | assert(dst:match('.txt$'))
23 |
24 | f = io.open(dst, 'w')
25 | for idx, symbol in ipairs(dict.index_to_symbol) do
26 | if idx > dict.cutoff then
27 | break
28 | end
29 | f:write(symbol)
30 | f:write(' ')
31 | f:write(dict.index_to_freq[idx])
32 | f:write('\n')
33 | end
34 | f:close()
35 |
--------------------------------------------------------------------------------
/scripts/read_binarized.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 |
9 | from fairseq.data import Dictionary, data_utils, indexed_dataset
10 |
11 |
12 | def get_parser():
13 | parser = argparse.ArgumentParser(
14 | description="writes text from binarized file to stdout"
15 | )
16 | # fmt: off
17 | parser.add_argument('--dataset-impl', help='dataset implementation',
18 | choices=indexed_dataset.get_available_dataset_impl())
19 | parser.add_argument('--dict', metavar='FP', help='dictionary containing known words', default=None)
20 | parser.add_argument('--input', metavar='FP', required=True, help='binarized file to read')
21 | # fmt: on
22 |
23 | return parser
24 |
25 |
26 | def main():
27 | parser = get_parser()
28 | args = parser.parse_args()
29 |
30 | dictionary = Dictionary.load(args.dict) if args.dict is not None else None
31 | dataset = data_utils.load_indexed_dataset(
32 | args.input,
33 | dictionary,
34 | dataset_impl=args.dataset_impl,
35 | default="lazy",
36 | )
37 |
38 | for tensor_line in dataset:
39 | if dictionary is None:
40 | line = " ".join([str(int(x)) for x in tensor_line])
41 | else:
42 | line = dictionary.string(tensor_line)
43 |
44 | print(line)
45 |
46 |
47 | if __name__ == "__main__":
48 | main()
49 |
--------------------------------------------------------------------------------
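The same loading path can be called directly; a sketch of the programmatic equivalent, where the data-bin paths are illustrative and assume a dataset produced by fairseq-preprocess:

    from fairseq.data import Dictionary, data_utils

    dictionary = Dictionary.load("data-bin/dict.txt")  # hypothetical path
    dataset = data_utils.load_indexed_dataset(
        "data-bin/train", dictionary, default="lazy"
    )
    for tensor_line in dataset:
        print(dictionary.string(tensor_line))
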
/scripts/sacrebleu.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ $# -ne 4 ]; then
4 | echo "usage: $0 TESTSET SRCLANG TGTLANG GEN"
5 | exit 1
6 | fi
7 |
8 | TESTSET=$1
9 | SRCLANG=$2
10 | TGTLANG=$3
11 |
12 | GEN=$4
13 |
14 | if ! command -v sacremoses &> /dev/null
15 | then
16 | echo "sacremoses could not be found, please install with: pip install sacremoses"
17 | exit 1
18 | fi
19 |
20 | grep ^H $GEN \
21 | | sed 's/^H\-//' \
22 | | sort -n -k 1 \
23 | | cut -f 3 \
24 | | sacremoses detokenize \
25 | > $GEN.sorted.detok
26 |
27 | sacrebleu --test-set $TESTSET --language-pair "${SRCLANG}-${TGTLANG}" < $GEN.sorted.detok
28 |
--------------------------------------------------------------------------------
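For reference, sacrebleu also exposes a Python API; this swaps the script's CLI call for the library call, and the example sentences are made up:

    import sacrebleu

    hyps = ["the cat sat on the mat"]
    refs = [["the cat sat on the mat"]]  # one reference stream
    print(sacrebleu.corpus_bleu(hyps, refs).score)  # 100.0 for an exact match
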
/scripts/spm_decode.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | from __future__ import absolute_import, division, print_function, unicode_literals
9 |
10 | import argparse
11 |
12 | import sentencepiece as spm
13 |
14 |
15 | def main():
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument(
18 | "--model", required=True, help="sentencepiece model to use for decoding"
19 | )
20 | parser.add_argument("--input", required=True, help="input file to decode")
21 | parser.add_argument("--input_format", choices=["piece", "id"], default="piece")
22 | args = parser.parse_args()
23 |
24 | sp = spm.SentencePieceProcessor()
25 | sp.Load(args.model)
26 |
27 | if args.input_format == "piece":
28 |
29 | def decode(input):
30 | return "".join(sp.DecodePieces(input))
31 |
32 | elif args.input_format == "id":
33 |
34 | def decode(input):
35 | return "".join(sp.DecodeIds(input))
36 |
37 | else:
38 | raise NotImplementedError
39 |
40 | def tok2int(tok):
41 | # remap reference-side <unk> (represented as <<unk>>) to 0
42 | return int(tok) if tok != "<<unk>>" else 0
43 |
44 | with open(args.input, "r", encoding="utf-8") as h:
45 | for line in h:
46 | if args.input_format == "id":
47 | print(decode(list(map(tok2int, line.rstrip().split()))))
48 | elif args.input_format == "piece":
49 | print(decode(line.rstrip().split()))
50 |
51 |
52 | if __name__ == "__main__":
53 | main()
54 |
--------------------------------------------------------------------------------
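The underlying sentencepiece calls, for reference; the model path and pieces below are illustrative:

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.Load("spm.model")  # hypothetical trained model
    # Piece-format input is whitespace-separated subword pieces:
    print(sp.DecodePieces("▁Hello ▁world".split()))
    # -> "Hello world" (exact behavior depends on the model's vocabulary)
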
/scripts/spm_train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | from __future__ import absolute_import, division, print_function, unicode_literals
9 |
10 | import sys
11 |
12 | import sentencepiece as spm
13 |
14 |
15 | if __name__ == "__main__":
16 | spm.SentencePieceTrainer.Train(" ".join(sys.argv[1:]))
17 |
--------------------------------------------------------------------------------
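Since the script just joins argv into one flag string, calling the trainer directly is equivalent; corpus and prefix names here are illustrative:

    import sentencepiece as spm

    # Same effect as:
    #   python scripts/spm_train.py --input=corpus.txt --model_prefix=spm_bpe \
    #       --vocab_size=8000 --model_type=bpe
    spm.SentencePieceTrainer.Train(
        "--input=corpus.txt --model_prefix=spm_bpe "
        "--vocab_size=8000 --model_type=bpe"
    )
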
/scripts/test_fsdp.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | rm -rf fsdp_dummy
3 | mkdir -p fsdp_dummy
4 | CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train /private/home/sshleifer/data-bin/stories_mmap \
5 | --ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \
6 | --cpu-offload --checkpoint-activations \
7 | --task language_modeling --tokens-per-sample 256 --batch-size 8 \
8 | --arch transformer_lm_gpt2_tiny \
9 | --optimizer cpu_adam --adam-betas "(0.9,0.98)" \
10 | --lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \
11 | --max-update 5 --log-format json --log-interval 1 \
12 | --save-interval-updates 5 --save-dir fsdp_dummy --disable-validation \
13 | --restore-file x.pt "$@"
14 |
15 | # Now we try to load the checkpoint
16 | CUDA_VISIBLE_DEVICES=0,1 fairseq-train /private/home/sshleifer/data-bin/stories_mmap \
17 | --ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \
18 | --cpu-offload --checkpoint-activations \
19 | --task language_modeling --tokens-per-sample 256 --batch-size 8 \
20 | --arch transformer_lm_gpt2_tiny \
21 | --optimizer cpu_adam --adam-betas "(0.9,0.98)" \
22 | --lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \
23 | --max-update 2 --log-format json --log-interval 1 \
24 | --save-interval-updates 2 --save-dir fsdp_dummy
25 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 127
3 | extend-ignore = E203, W503
4 | extend-exclude = fairseq/model_parallel/megatron
5 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Speech-Lab-IITM/data2vec-aqc/c80cc996669e12c9c32201a9ff1515300e64b13a/tests/__init__.py
--------------------------------------------------------------------------------
/tests/distributed/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Speech-Lab-IITM/data2vec-aqc/c80cc996669e12c9c32201a9ff1515300e64b13a/tests/distributed/__init__.py
--------------------------------------------------------------------------------
/tests/distributed/test_distributed_timeout_wrapper.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import logging
7 | import signal
8 | import time
9 | import unittest
10 |
11 | import torch
12 | from torch import nn
13 |
14 | from fairseq.distributed import DistributedTimeoutWrapper
15 |
16 |
17 | class ModuleWithDelay(nn.Module):
18 | def __init__(self, delay):
19 | super().__init__()
20 | self.delay = delay
21 |
22 | def forward(self, x):
23 | time.sleep(self.delay)
24 | return x
25 |
26 |
27 | class TestDistributedTimeoutWrapper(unittest.TestCase):
28 | def setUp(self):
29 | logging.disable(logging.CRITICAL)
30 |
31 | def tearDown(self):
32 | logging.disable(logging.NOTSET)
33 |
34 | def test_no_timeout(self):
35 | module = DistributedTimeoutWrapper(ModuleWithDelay(1), 0, signal.SIGINT)
36 | module(torch.rand(5))
37 | module.stop_timeout()
38 |
39 | def test_timeout_safe(self):
40 | module = DistributedTimeoutWrapper(ModuleWithDelay(1), 10, signal.SIGINT)
41 | module(torch.rand(5))
42 | module.stop_timeout()
43 |
44 | def test_timeout_killed(self):
45 | with self.assertRaises(KeyboardInterrupt):
46 | module = DistributedTimeoutWrapper(ModuleWithDelay(5), 1, signal.SIGINT)
47 | module(torch.rand(5))
48 | module.stop_timeout()
49 |
50 |
51 | if __name__ == "__main__":
52 | unittest.main()
53 |
--------------------------------------------------------------------------------
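Usage mirrors the tests above: wrap any nn.Module, and a forward pass that stalls past the timeout (in seconds) kills the process with the given signal, while a timeout of 0 disables the watchdog. A minimal sketch:

    import signal

    import torch
    from torch import nn

    from fairseq.distributed import DistributedTimeoutWrapper

    # Allow up to 30 seconds per forward pass before raising KeyboardInterrupt.
    module = DistributedTimeoutWrapper(nn.Linear(5, 5), 30, signal.SIGINT)
    out = module(torch.rand(2, 5))
    module.stop_timeout()  # shut down the watchdog thread when done
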
/tests/gpu/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Speech-Lab-IITM/data2vec-aqc/c80cc996669e12c9c32201a9ff1515300e64b13a/tests/gpu/__init__.py
--------------------------------------------------------------------------------
/tests/gpu/transformer_quantization_config.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # This file defines example configuration arguments for quantizing
7 | # a transformer model with product quantization
8 |
9 | n_centroids:
10 | Linear:
11 | key: in_features
12 | value: {"*": 8}
13 | Embedding:
14 | key: embedding_dim
15 | value: {"*": 8}
16 |
17 | block_sizes:
18 | Linear:
19 | key: fuzzy_name
20 | value: {fc: 8, attn: 4, emb: 4}
21 | Embedding:
22 | key: fuzzy_name
23 | value: {emb: 8}
24 |
25 | layers_to_quantize:
26 | - decoder\\.layers\\.\d+\\.fc[12]
27 | - decoder\\.embed_tokens\\.embeddings\\.[012]\\.[01]
28 | - decoder\\.layers\\.\d+\\.self_attn\\.(k_proj|v_proj|q_proj|out_proj)
29 |
--------------------------------------------------------------------------------
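The config is plain YAML, so its structure is easy to inspect; a quick sketch using PyYAML, assuming it is read from the repository path above:

    import yaml

    with open("tests/gpu/transformer_quantization_config.yaml") as f:
        cfg = yaml.safe_load(f)

    print(cfg["n_centroids"]["Linear"])  # {'key': 'in_features', 'value': {'*': 8}}
    print(cfg["layers_to_quantize"][0])  # regex selecting decoder fc1/fc2 layers
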
/tests/speech/test_convtransformer_simul_trans.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import unittest
7 | from tests.speech import TestFairseqSpeech
8 |
9 | S3_BASE_URL = "https://dl.fbaipublicfiles.com/fairseq/"
10 |
11 |
12 | class TestConvtransformerSimulTrans(TestFairseqSpeech):
13 | def setUp(self):
14 | self._set_up(
15 | "simul",
16 | "speech_tests/simul",
17 | ["config_gcmvn_specaug.yaml", "dict.txt", "dev.tsv"],
18 | )
19 |
20 | def test_waitk_checkpoint(self):
21 | """Only test model loading since fairseq currently doesn't support inference of simultaneous models"""
22 | _, _, _, _ = self.download_and_load_checkpoint(
23 | "checkpoint_best.pt",
24 | arg_overrides={
25 | "config_yaml": "config_gcmvn_specaug.yaml",
26 | "load_pretrained_encoder_from": None,
27 | },
28 | )
29 | return
30 |
31 |
32 | if __name__ == "__main__":
33 | unittest.main()
34 |
--------------------------------------------------------------------------------
/tests/speech/test_s2t_conformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import unittest
7 | from tests.speech import TestFairseqSpeech
8 |
9 |
10 | class TestS2TConformer(TestFairseqSpeech):
11 | def setUp(self):
12 | self.set_up_librispeech()
13 |
14 | def test_librispeech_s2t_conformer_s_checkpoint(self):
15 | self.base_test(
16 | ckpt_name="librispeech_conformer_rel_pos_s.pt",
17 | reference_score=12,
18 | arg_overrides={"config_yaml": "cfg_librispeech.yaml"},
19 | )
20 |
21 |
22 | if __name__ == "__main__":
23 | unittest.main()
24 |
--------------------------------------------------------------------------------
/tests/speech/test_s2t_transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import unittest
7 | from tests.speech import TestFairseqSpeech
8 |
9 |
10 | class TestS2TTransformer(TestFairseqSpeech):
11 | def setUp(self):
12 | self.set_up_librispeech()
13 |
14 | def test_librispeech_s2t_transformer_s_checkpoint(self):
15 | self.base_test(
16 | ckpt_name="librispeech_transformer_s.pt",
17 | reference_score=9,
18 | arg_overrides={"config_yaml": "cfg_librispeech.yaml"},
19 | )
20 |
21 |
22 | if __name__ == "__main__":
23 | unittest.main()
24 |
--------------------------------------------------------------------------------
/tests/speech/test_xm_transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import unittest
7 | from tests.speech import TestFairseqSpeech
8 |
9 |
10 | class TestXMTransformer(TestFairseqSpeech):
11 | def setUp(self):
12 | self.set_up_sotasty_es_en()
13 |
14 | # TODO: investigate the BLEU score increase (30.42 -> 31.74)
15 | def test_sotasty_es_en_600m_checkpoint(self):
16 | self.base_test(
17 | ckpt_name="xm_transformer_600m_es_en_md.pt",
18 | reference_score=31.74,
19 | score_delta=0.2,
20 | max_tokens=3_000_000,
21 | max_positions=(1_000_000, 1_024),
22 | dataset="sotasty_es_en_test_ted",
23 | arg_overrides={"config_yaml": "cfg_es_en.yaml"},
24 | score_type="bleu",
25 | )
26 |
27 |
28 | if __name__ == "__main__":
29 | unittest.main()
30 |
--------------------------------------------------------------------------------
/tests/speech_recognition/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Speech-Lab-IITM/data2vec-aqc/c80cc996669e12c9c32201a9ff1515300e64b13a/tests/speech_recognition/__init__.py
--------------------------------------------------------------------------------
/tests/speech_recognition/test_cross_entropy.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from examples.speech_recognition.criterions.cross_entropy_acc import (
8 | CrossEntropyWithAccCriterion,
9 | )
10 |
11 | from .asr_test_base import CrossEntropyCriterionTestBase
12 |
13 |
14 | class CrossEntropyWithAccCriterionTest(CrossEntropyCriterionTestBase):
15 | def setUp(self):
16 | self.criterion_cls = CrossEntropyWithAccCriterion
17 | super().setUp()
18 |
19 | def test_cross_entropy_all_correct(self):
20 | sample = self.get_test_sample(correct=True, soft_target=False, aggregate=False)
21 | loss, sample_size, logging_output = self.criterion(
22 | self.model, sample, "sum", log_probs=True
23 | )
24 | assert logging_output["correct"] == 20
25 | assert logging_output["total"] == 20
26 | assert logging_output["sample_size"] == 20
27 | assert logging_output["ntokens"] == 20
28 |
29 | def test_cross_entropy_all_wrong(self):
30 | sample = self.get_test_sample(correct=False, soft_target=False, aggregate=False)
31 | loss, sample_size, logging_output = self.criterion(
32 | self.model, sample, "sum", log_probs=True
33 | )
34 | assert logging_output["correct"] == 0
35 | assert logging_output["total"] == 20
36 | assert logging_output["sample_size"] == 20
37 | assert logging_output["ntokens"] == 20
38 |
--------------------------------------------------------------------------------
/tests/test_hf_hub.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import unittest
8 |
9 | import torch
10 |
11 | try:
12 | import huggingface_hub
13 | except ImportError:
14 | huggingface_hub = None
15 |
16 | from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
17 |
18 |
19 | @unittest.skipIf(not huggingface_hub, "Requires huggingface_hub install")
20 | class TestHuggingFaceHub(unittest.TestCase):
21 | @torch.no_grad()
22 | def test_hf_fastspeech2(self):
23 | hf_model_id = "facebook/fastspeech2-en-ljspeech"
24 | models, cfg, task = load_model_ensemble_and_task_from_hf_hub(hf_model_id)
25 | self.assertTrue(len(models) > 0)
26 |
27 |
28 | if __name__ == "__main__":
29 | unittest.main()
30 |
--------------------------------------------------------------------------------
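The same loader works outside the test; a sketch pulling the checkpoint the test uses (requires huggingface_hub to be installed):

    from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub

    models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
        "facebook/fastspeech2-en-ljspeech"
    )
    model = models[0].eval()  # first model of the returned ensemble
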
/tests/test_iopath.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import unittest
7 | from unittest import mock
8 |
9 |
10 | class TestIOPath(unittest.TestCase):
11 | def test_no_iopath(self):
12 | from .test_reproducibility import TestReproducibility
13 |
14 | with mock.patch.dict("sys.modules", {"iopath": None}):
15 | # reuse reproducibility tests, which are e2e tests that should cover
16 | # most checkpoint related functionality
17 | TestReproducibility._test_reproducibility(self, "test_reproducibility")
18 |
19 | def test_no_supports_rename(self):
20 | from .test_reproducibility import TestReproducibility
21 |
22 | with mock.patch("fairseq.file_io.PathManager.supports_rename") as mock_fn:
23 | mock_fn.return_value = False
24 | TestReproducibility._test_reproducibility(self, "test_reproducibility")
25 |
26 |
27 | if __name__ == "__main__":
28 | unittest.main()
29 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -u
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """
7 | Legacy entry point. Use fairseq_cli/train.py or fairseq-train instead.
8 | """
9 |
10 | from fairseq_cli.train import cli_main
11 |
12 |
13 | if __name__ == "__main__":
14 | cli_main()
15 |
--------------------------------------------------------------------------------