├── LICENSE ├── README.md └── codes ├── .gitignore ├── build_faiss_index_knn_align.sh ├── build_faiss_index_knn_align_pruned.sh ├── config ├── config.yaml ├── config_eval_lm.yaml ├── criterion │ ├── adaptive_loss.yaml │ └── cross_entropy.yaml ├── lr_scheduler │ ├── cosine.yaml │ └── inverse_sqrt.yaml ├── model │ ├── transformer_lm.yaml │ ├── transformer_lm_baevski_gbw.yaml │ ├── transformer_lm_baevski_wiki103.yaml │ ├── transformer_lm_big.yaml │ ├── transformer_lm_gbw.yaml │ ├── transformer_lm_gpt.yaml │ ├── transformer_lm_gpt2_big.yaml │ ├── transformer_lm_gpt2_medium.yaml │ ├── transformer_lm_gpt2_small.yaml │ └── transformer_lm_wiki103.yaml ├── optimizer │ ├── adam.yaml │ └── nag.yaml ├── params │ ├── eval_lm_params.yaml │ └── training_params.yaml └── task │ └── language_modeling.yaml ├── create_datastore.sh ├── create_datastore_knn_align.sh ├── examples └── __init__.py ├── experimental_knn_align.py ├── fairseq ├── __init__.py ├── benchmark │ ├── __init__.py │ ├── dummy_lm.py │ ├── dummy_masked_lm.py │ ├── dummy_model.py │ └── dummy_mt.py ├── binarizer.py ├── checkpoint_utils.py ├── clib │ ├── libbleu │ │ ├── libbleu.cpp │ │ └── module.cpp │ ├── libnat │ │ └── edit_dist.cpp │ └── libnat_cuda │ │ ├── binding.cpp │ │ ├── edit_dist.cu │ │ └── edit_dist.h ├── criterions │ ├── __init__.py │ ├── adaptive_loss.py │ ├── composite_loss.py │ ├── cross_entropy.py │ ├── ctc.py │ ├── fairseq_criterion.py │ ├── label_smoothed_cross_entropy.py │ ├── label_smoothed_cross_entropy_with_alignment.py │ ├── legacy_masked_lm.py │ ├── masked_lm.py │ ├── nat_loss.py │ ├── sentence_prediction.py │ ├── sentence_ranking.py │ ├── triplet_ranking.py │ └── wav2vec_criterion.py ├── data │ ├── __init__.py │ ├── add_target_dataset.py │ ├── append_token_dataset.py │ ├── audio │ │ ├── __init__.py │ │ ├── audio_utils.py │ │ ├── feature_transforms │ │ │ ├── __init__.py │ │ │ ├── global_cmvn.py │ │ │ ├── specaugment.py │ │ │ └── utterance_cmvn.py │ │ ├── raw_audio_dataset.py │ │ └── 
speech_to_text_dataset.py │ ├── backtranslation_dataset.py │ ├── base_wrapper_dataset.py │ ├── bucket_pad_length_dataset.py │ ├── colorize_dataset.py │ ├── concat_dataset.py │ ├── concat_sentences_dataset.py │ ├── data_utils.py │ ├── data_utils_fast.cpp │ ├── data_utils_fast.pyx │ ├── denoising_dataset.py │ ├── dictionary.py │ ├── encoders │ │ ├── __init__.py │ │ ├── byte_bpe.py │ │ ├── byte_utils.py │ │ ├── bytes.py │ │ ├── characters.py │ │ ├── fastbpe.py │ │ ├── gpt2_bpe.py │ │ ├── gpt2_bpe_utils.py │ │ ├── hf_bert_bpe.py │ │ ├── hf_byte_bpe.py │ │ ├── moses_tokenizer.py │ │ ├── nltk_tokenizer.py │ │ ├── sentencepiece_bpe.py │ │ ├── space_tokenizer.py │ │ ├── subword_nmt_bpe.py │ │ └── utils.py │ ├── fairseq_dataset.py │ ├── fasta_dataset.py │ ├── id_dataset.py │ ├── indexed_dataset.py │ ├── iterators.py │ ├── language_pair_dataset.py │ ├── legacy │ │ ├── __init__.py │ │ ├── block_pair_dataset.py │ │ ├── masked_lm_dataset.py │ │ └── masked_lm_dictionary.py │ ├── list_dataset.py │ ├── lm_context_window_dataset.py │ ├── lru_cache_dataset.py │ ├── mask_tokens_dataset.py │ ├── monolingual_dataset.py │ ├── multi_corpus_dataset.py │ ├── multi_corpus_sampled_dataset.py │ ├── multilingual │ │ ├── __init__.py │ │ ├── multilingual_data_manager.py │ │ ├── multilingual_utils.py │ │ ├── sampled_multi_dataset.py │ │ ├── sampled_multi_epoch_dataset.py │ │ └── sampling_method.py │ ├── nested_dictionary_dataset.py │ ├── noising.py │ ├── num_samples_dataset.py │ ├── numel_dataset.py │ ├── offset_tokens_dataset.py │ ├── pad_dataset.py │ ├── plasma_utils.py │ ├── prepend_dataset.py │ ├── prepend_token_dataset.py │ ├── raw_label_dataset.py │ ├── replace_dataset.py │ ├── resampling_dataset.py │ ├── roll_dataset.py │ ├── round_robin_zip_datasets.py │ ├── shorten_dataset.py │ ├── sort_dataset.py │ ├── strip_token_dataset.py │ ├── subsample_dataset.py │ ├── token_block_dataset.py │ ├── token_block_utils_fast.cpp │ ├── token_block_utils_fast.pyx │ ├── transform_eos_dataset.py │ └── 
transform_eos_lang_pair_dataset.py ├── data_utils_fast.cpython-36m-darwin.so ├── data_utils_fast.cpython-36m-x86_64-linux-gnu.so ├── data_utils_fast.cpython-37m-x86_64-linux-gnu.so ├── dataclass │ ├── __init__.py │ ├── constants.py │ ├── data_class.py │ └── utils.py ├── distributed_utils.py ├── file_io.py ├── file_utils.py ├── hub_utils.py ├── incremental_decoding_utils.py ├── iterative_refinement_generator.py ├── legacy_distributed_data_parallel.py ├── logging │ ├── __init__.py │ ├── meters.py │ ├── metrics.py │ └── progress_bar.py ├── model_parallel │ ├── __init__.py │ ├── criterions │ │ ├── __init__.py │ │ └── vocab_parallel_cross_entropy.py │ ├── megatron_trainer.py │ ├── models │ │ ├── __init__.py │ │ ├── pipeline_parallel_transformer │ │ │ ├── __init__.py │ │ │ ├── layers.py │ │ │ └── model.py │ │ ├── roberta │ │ │ ├── __init__.py │ │ │ └── model.py │ │ ├── transformer.py │ │ └── transformer_lm.py │ └── modules │ │ ├── __init__.py │ │ ├── multihead_attention.py │ │ ├── transformer_layer.py │ │ ├── transformer_sentence_encoder.py │ │ └── transformer_sentence_encoder_layer.py ├── models │ ├── __init__.py │ ├── bart │ │ ├── __init__.py │ │ ├── hub_interface.py │ │ └── model.py │ ├── composite_encoder.py │ ├── distributed_fairseq_model.py │ ├── fairseq_decoder.py │ ├── fairseq_encoder.py │ ├── fairseq_incremental_decoder.py │ ├── fairseq_model.py │ ├── fconv.py │ ├── fconv_lm.py │ ├── fconv_self_att.py │ ├── huggingface │ │ ├── __init__.py │ │ └── hf_gpt2.py │ ├── lightconv.py │ ├── lightconv_lm.py │ ├── lstm.py │ ├── lstm_lm.py │ ├── masked_lm.py │ ├── model_utils.py │ ├── multilingual_transformer.py │ ├── nat │ │ ├── __init__.py │ │ ├── cmlm_transformer.py │ │ ├── fairseq_nat_model.py │ │ ├── insertion_transformer.py │ │ ├── iterative_nonautoregressive_transformer.py │ │ ├── levenshtein_transformer.py │ │ ├── levenshtein_utils.py │ │ ├── nat_crf_transformer.py │ │ ├── nonautoregressive_ensembles.py │ │ └── nonautoregressive_transformer.py │ ├── roberta │ │ ├── 
__init__.py │ │ ├── alignment_utils.py │ │ ├── hub_interface.py │ │ ├── model.py │ │ ├── model_camembert.py │ │ └── model_xlmr.py │ ├── speech_to_text │ │ ├── __init__.py │ │ ├── berard.py │ │ └── s2t_transformer.py │ ├── transformer.py │ ├── transformer_align.py │ ├── transformer_from_pretrained_xlm.py │ ├── transformer_knn_projection.py │ ├── transformer_lm.py │ └── wav2vec │ │ ├── __init__.py │ │ ├── wav2vec.py │ │ ├── wav2vec2.py │ │ └── wav2vec2_asr.py ├── modules │ ├── __init__.py │ ├── adaptive_input.py │ ├── adaptive_softmax.py │ ├── beamable_mm.py │ ├── character_token_embedder.py │ ├── conv_tbc.py │ ├── cross_entropy.py │ ├── cuda_utils.cu │ ├── downsampled_multihead_attention.py │ ├── dynamic_convolution.py │ ├── dynamic_crf_layer.py │ ├── dynamicconv_layer │ │ ├── __init__.py │ │ ├── cuda_function_gen.py │ │ ├── dynamicconv_cuda.cpp │ │ ├── dynamicconv_cuda.cuh │ │ ├── dynamicconv_cuda_kernel.cu │ │ ├── dynamicconv_layer.py │ │ ├── dynamiconv_cpu.cpp │ │ └── setup.py │ ├── fairseq_dropout.py │ ├── fp32_group_norm.py │ ├── gelu.py │ ├── grad_multiply.py │ ├── gumbel_vector_quantizer.py │ ├── kmeans_vector_quantizer.py │ ├── knn_datastore.py │ ├── layer_drop.py │ ├── layer_norm.py │ ├── learned_positional_embedding.py │ ├── lightconv_layer │ │ ├── __init__.py │ │ ├── cuda_function_gen.py │ │ ├── lightconv_cuda.cpp │ │ ├── lightconv_cuda.cuh │ │ ├── lightconv_cuda_kernel.cu │ │ ├── lightconv_layer.py │ │ └── setup.py │ ├── lightweight_convolution.py │ ├── linearized_convolution.py │ ├── multihead_attention.py │ ├── positional_embedding.py │ ├── quant_noise.py │ ├── quantization │ │ ├── __init__.py │ │ ├── pq │ │ │ ├── __init__.py │ │ │ ├── em.py │ │ │ ├── modules │ │ │ │ ├── __init__.py │ │ │ │ ├── qconv.py │ │ │ │ ├── qemb.py │ │ │ │ └── qlinear.py │ │ │ ├── pq.py │ │ │ └── utils.py │ │ ├── quantization_options.py │ │ └── scalar │ │ │ ├── __init__.py │ │ │ ├── modules │ │ │ ├── __init__.py │ │ │ ├── qact.py │ │ │ ├── qconv.py │ │ │ ├── qemb.py │ │ │ └── 
qlinear.py │ │ │ ├── ops.py │ │ │ └── utils.py │ ├── same_pad.py │ ├── scalar_bias.py │ ├── sinusoidal_positional_embedding.py │ ├── sparse_multihead_attention.py │ ├── sparse_transformer_sentence_encoder.py │ ├── sparse_transformer_sentence_encoder_layer.py │ ├── transformer_layer.py │ ├── transformer_sentence_encoder.py │ ├── transformer_sentence_encoder_layer.py │ ├── transpose_last.py │ ├── unfold.py │ └── vggblock.py ├── nan_detector.py ├── optim │ ├── __init__.py │ ├── adadelta.py │ ├── adafactor.py │ ├── adagrad.py │ ├── adam.py │ ├── adamax.py │ ├── bmuf.py │ ├── dynamic_loss_scaler.py │ ├── fairseq_optimizer.py │ ├── fp16_optimizer.py │ ├── fused_adam.py │ ├── fused_lamb.py │ ├── lr_scheduler │ │ ├── __init__.py │ │ ├── cosine_lr_scheduler.py │ │ ├── fairseq_lr_scheduler.py │ │ ├── fixed_schedule.py │ │ ├── inverse_square_root_schedule.py │ │ ├── polynomial_decay_schedule.py │ │ ├── reduce_lr_on_plateau.py │ │ ├── tri_stage_lr_scheduler.py │ │ └── triangular_lr_scheduler.py │ ├── nag.py │ ├── sgd.py │ └── shard.py ├── options.py ├── pdb.py ├── quantization_utils.py ├── registry.py ├── scoring │ ├── __init__.py │ ├── bleu.py │ ├── chrf.py │ ├── tokenizer.py │ └── wer.py ├── search.py ├── sequence_generator.py ├── sequence_scorer.py ├── tasks │ ├── __init__.py │ ├── audio_pretraining.py │ ├── cross_lingual_lm.py │ ├── denoising.py │ ├── fairseq_task.py │ ├── language_modeling.py │ ├── legacy_masked_lm.py │ ├── masked_lm.py │ ├── multilingual_denoising.py │ ├── multilingual_masked_lm.py │ ├── multilingual_translation.py │ ├── semisupervised_translation.py │ ├── sentence_prediction.py │ ├── sentence_ranking.py │ ├── speech_to_text.py │ ├── translation.py │ ├── translation_from_pretrained_bart.py │ ├── translation_from_pretrained_xlm.py │ ├── translation_lev.py │ └── translation_multi_simple_epoch.py ├── token_block_utils_fast.cpython-36m-darwin.so ├── token_block_utils_fast.cpython-36m-x86_64-linux-gnu.so ├── 
token_block_utils_fast.cpython-37m-x86_64-linux-gnu.so ├── token_generation_constraints.py ├── tokenizer.py ├── trainer.py └── utils.py ├── fairseq_cli ├── __init__.py ├── eval_lm.py ├── generate.py ├── interactive.py ├── knn_align.py ├── preprocess.py ├── score.py ├── train.py ├── train_ddp.py └── validate.py ├── knn_align.sh ├── prune_datastore.py ├── save_datastore.py ├── save_datastore_knn_align.py ├── setup.py ├── test_adaptive_knn_mt_knn_align.sh ├── test_adaptive_knn_mt_knn_align_pruned.sh ├── train_datastore_gpu.py ├── train_faiss_knn_align.sh ├── train_faiss_knn_align_ddp.sh └── train_faiss_knn_align_pruned.sh /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 WonderSeen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /codes/.gitignore: -------------------------------------------------------------------------------- 1 | save_datastore 2 | data-bin 3 | nohup* 4 | model_path* 5 | */__pycache__ 6 | *.pyc 7 | pretrain_model 8 | model_record_path 9 | build -------------------------------------------------------------------------------- /codes/build_faiss_index_knn_align.sh: -------------------------------------------------------------------------------- 1 | 2 | postfix=_nce 3 | DOMAIN=(law it koran medical subtitles) 4 | DSTORE_SIZE=(19061383 3602863 524375 6903142 153604142) 5 | gpu_ids=(0 1 2 3 4) 6 | 7 | PROJECT_PATH=. 8 | for idx in ${!gpu_ids[*]} 9 | do 10 | DSTORE_PATH=save_datastore/${DOMAIN[$idx]}/knn_transfered${postfix} 11 | CUDA_VISIBLE_DEVICES=${gpu_ids[$idx]} python ${PROJECT_PATH}/train_datastore_gpu.py \ 12 | --dstore_mmap ${DSTORE_PATH} \ 13 | --dstore_size ${DSTORE_SIZE[$idx]} \ 14 | --faiss_index ${DSTORE_PATH}/knn_index \ 15 | --ncentroids 4096 \ 16 | --probe 32 \ 17 | --dimension 64 \ 18 | --dstore_fp16 \ 19 | --use-gpu \ 20 | > nohup-${DOMAIN[$idx]}/build_faiss_index${postfix}.txt 2>&1 & 21 | echo nohup-${DOMAIN[$idx]}/build_faiss_index${postfix}.txt 22 | done 23 | -------------------------------------------------------------------------------- /codes/build_faiss_index_knn_align_pruned.sh: -------------------------------------------------------------------------------- 1 | PROJECT_PATH=. 
2 | DOMAINS=(subtitles) 3 | DSTORE_SIZES=(137803715) # depends on the size of the pruned datastore (various pruning rates) 4 | 5 | gpu_ids=(0) 6 | 7 | for idx in ${!gpu_ids[*]} 8 | do 9 | LOG= 10 | DSTORE_PATH= 11 | CUDA_VISIBLE_DEVICES=${gpu_ids[$idx]} python ${PROJECT_PATH}/train_datastore_gpu.py \ 12 | --dstore_mmap ${DSTORE_PATH} \ 13 | --dstore_size ${DSTORE_SIZES[$idx]} \ 14 | --faiss_index ${DSTORE_PATH}/knn_index \ 15 | --ncentroids 4096 \ 16 | --probe 32 \ 17 | --dimension 64 \ 18 | --dstore_fp16 \ 19 | > ${LOG} 2>&1 & 20 | echo ${LOG} 21 | # --use-gpu \ 22 | done -------------------------------------------------------------------------------- /codes/config/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - params: training_params 3 | - task: language_modeling 4 | - model: transformer_lm 5 | - criterion: cross_entropy 6 | - optimizer: adam 7 | - lr_scheduler: inverse_sqrt 8 | -------------------------------------------------------------------------------- /codes/config/config_eval_lm.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - params: eval_lm_params 3 | - task: language_modeling 4 | - model: transformer_lm 5 | - criterion: cross_entropy 6 | - optimizer: adam 7 | - lr_scheduler: inverse_sqrt 8 | -------------------------------------------------------------------------------- /codes/config/criterion/adaptive_loss.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | sentence_avg: ${params.optimization.sentence_avg} 3 | ddp_backend: ${params.distributed_training.ddp_backend} 4 | -------------------------------------------------------------------------------- /codes/config/criterion/cross_entropy.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | sentence_avg: ${params.optimization.sentence_avg} 3 | ddp_backend: 
${params.distributed_training.ddp_backend} 4 | -------------------------------------------------------------------------------- /codes/config/lr_scheduler/cosine.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | warmup_updates: 0 3 | warmup_init_lr: -1 4 | max_lr: 1.0 5 | t_mult: 1.0 6 | lr_period_updates: -1 7 | lr_shrink: 0.1 8 | -------------------------------------------------------------------------------- /codes/config/lr_scheduler/inverse_sqrt.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | warmup_updates: 4000 3 | warmup_init_lr: -1 4 | -------------------------------------------------------------------------------- /codes/config/model/transformer_lm.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | activation_fn: "relu" 3 | dropout: 0.1 4 | attention_dropout: 0.0 5 | activation_dropout: 0.0 6 | relu_dropout: 0.0 7 | decoder_embed_dim: 512 8 | decoder_output_dim: 512 9 | decoder_input_dim: 512 10 | decoder_ffn_embed_dim: 2048 11 | decoder_layers: 6 12 | decoder_attention_heads: 8 13 | decoder_normalize_before: true 14 | no_decoder_final_norm: false 15 | adaptive_softmax_cutoff: null 16 | adaptive_softmax_dropout: 0 17 | adaptive_softmax_factor: 4 18 | no_token_positional_embeddings: false 19 | share_decoder_input_output_embed: false 20 | character_embeddings: false 21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" 22 | character_embedding_dim: 4 23 | char_embedder_highway_layers: 2 24 | adaptive_input: false 25 | adaptive_input_factor: 4 26 | adaptive_input_cutoff: null 27 | tie_adaptive_weights: false 28 | tie_adaptive_proj: false 29 | decoder_learned_pos: false 30 | decoder_layerdrop: 0 31 | decoder_layers_to_keep: null 32 | layernorm_embedding: false 33 | no_scale_embedding: false 34 | quant_noise_pq: 0 35 | 
quant_noise_pq_block_size: 8 36 | quant_noise_scalar: 0 37 | -------------------------------------------------------------------------------- /codes/config/model/transformer_lm_baevski_gbw.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | activation_fn: "relu" 3 | dropout: 0.1 4 | attention_dropout: 0.1 5 | activation_dropout: 0.0 6 | relu_dropout: 0.0 7 | decoder_embed_dim: 512 8 | decoder_output_dim: 512 9 | decoder_input_dim: 512 10 | decoder_ffn_embed_dim: 4096 11 | decoder_layers: 12 12 | decoder_attention_heads: 16 13 | decoder_normalize_before: true 14 | no_decoder_final_norm: true 15 | adaptive_softmax_cutoff: null 16 | adaptive_softmax_dropout: 0 17 | adaptive_softmax_factor: 4 18 | no_token_positional_embeddings: false 19 | share_decoder_input_output_embed: false 20 | character_embeddings: false 21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" 22 | character_embedding_dim: 4 23 | char_embedder_highway_layers: 2 24 | adaptive_input: false 25 | adaptive_input_factor: 4 26 | adaptive_input_cutoff: null 27 | tie_adaptive_weights: false 28 | tie_adaptive_proj: false 29 | decoder_learned_pos: false 30 | decoder_layerdrop: 0 31 | decoder_layers_to_keep: null 32 | layernorm_embedding: false 33 | no_scale_embedding: false 34 | quant_noise_pq: 0 35 | quant_noise_pq_block_size: 8 36 | quant_noise_scalar: 0 37 | -------------------------------------------------------------------------------- /codes/config/model/transformer_lm_baevski_wiki103.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | activation_fn: "relu" 3 | dropout: 0.3 4 | attention_dropout: 0.1 5 | activation_dropout: 0.1 6 | relu_dropout: 0.1 7 | decoder_embed_dim: 1024 8 | decoder_output_dim: 1024 9 | decoder_input_dim: 1024 10 | decoder_ffn_embed_dim: 4096 11 | decoder_layers: 16 12 | decoder_attention_heads: 8 13 | 
decoder_normalize_before: true 14 | no_decoder_final_norm: true 15 | adaptive_softmax_cutoff: "20000,60000" 16 | adaptive_softmax_dropout: 0.2 17 | adaptive_softmax_factor: 4 18 | no_token_positional_embeddings: false 19 | share_decoder_input_output_embed: false 20 | character_embeddings: false 21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" 22 | character_embedding_dim: 4 23 | char_embedder_highway_layers: 2 24 | adaptive_input: true 25 | adaptive_input_factor: 4 26 | adaptive_input_cutoff: "20000,60000" 27 | tie_adaptive_weights: true 28 | tie_adaptive_proj: true 29 | decoder_learned_pos: false 30 | decoder_layerdrop: 0 31 | decoder_layers_to_keep: null 32 | layernorm_embedding: false 33 | no_scale_embedding: false 34 | quant_noise_pq: 0 35 | quant_noise_pq_block_size: 8 36 | quant_noise_scalar: 0 37 | -------------------------------------------------------------------------------- /codes/config/model/transformer_lm_big.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | activation_fn: "relu" 3 | dropout: 0.1 4 | attention_dropout: 0.0 5 | activation_dropout: 0.0 6 | relu_dropout: 0.0 7 | decoder_embed_dim: 1024 8 | decoder_output_dim: 1024 9 | decoder_input_dim: 1024 10 | decoder_ffn_embed_dim: 4096 11 | decoder_layers: 12 12 | decoder_attention_heads: 16 13 | decoder_normalize_before: true 14 | no_decoder_final_norm: false 15 | adaptive_softmax_cutoff: null 16 | adaptive_softmax_dropout: 0 17 | adaptive_softmax_factor: 4 18 | no_token_positional_embeddings: false 19 | share_decoder_input_output_embed: false 20 | character_embeddings: false 21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" 22 | character_embedding_dim: 4 23 | char_embedder_highway_layers: 2 24 | adaptive_input: false 25 | adaptive_input_factor: 4 26 | adaptive_input_cutoff: null 27 | tie_adaptive_weights: false 28 | tie_adaptive_proj: false 
29 | decoder_learned_pos: false 30 | decoder_layerdrop: 0 31 | decoder_layers_to_keep: null 32 | layernorm_embedding: false 33 | no_scale_embedding: false 34 | quant_noise_pq: 0 35 | quant_noise_pq_block_size: 8 36 | quant_noise_scalar: 0 37 | -------------------------------------------------------------------------------- /codes/config/model/transformer_lm_gbw.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | activation_fn: "relu" 3 | dropout: 0.1 4 | attention_dropout: 0.1 5 | activation_dropout: 0.0 6 | relu_dropout: 0.0 7 | decoder_embed_dim: 512 8 | decoder_output_dim: 512 9 | decoder_input_dim: 512 10 | decoder_ffn_embed_dim: 4096 11 | decoder_layers: 12 12 | decoder_attention_heads: 16 13 | decoder_normalize_before: true 14 | no_decoder_final_norm: true 15 | adaptive_softmax_cutoff: null 16 | adaptive_softmax_dropout: 0 17 | adaptive_softmax_factor: 4 18 | no_token_positional_embeddings: false 19 | share_decoder_input_output_embed: false 20 | character_embeddings: false 21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" 22 | character_embedding_dim: 4 23 | char_embedder_highway_layers: 2 24 | adaptive_input: false 25 | adaptive_input_factor: 4 26 | adaptive_input_cutoff: null 27 | tie_adaptive_weights: false 28 | tie_adaptive_proj: false 29 | decoder_learned_pos: false 30 | decoder_layerdrop: 0 31 | decoder_layers_to_keep: null 32 | layernorm_embedding: false 33 | no_scale_embedding: false 34 | quant_noise_pq: 0 35 | quant_noise_pq_block_size: 8 36 | quant_noise_scalar: 0 37 | -------------------------------------------------------------------------------- /codes/config/model/transformer_lm_gpt.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | activation_fn: "gelu" 3 | dropout: 0.1 4 | attention_dropout: 0.1 5 | activation_dropout: 0.0 6 | relu_dropout: 0.0 7 | decoder_embed_dim: 768 8 | 
decoder_output_dim: 768 9 | decoder_input_dim: 768 10 | decoder_ffn_embed_dim: 3072 11 | decoder_layers: 12 12 | decoder_attention_heads: 12 13 | decoder_normalize_before: true 14 | no_decoder_final_norm: false 15 | adaptive_softmax_cutoff: null 16 | adaptive_softmax_dropout: 0 17 | adaptive_softmax_factor: 4 18 | no_token_positional_embeddings: false 19 | share_decoder_input_output_embed: false 20 | character_embeddings: false 21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" 22 | character_embedding_dim: 4 23 | char_embedder_highway_layers: 2 24 | adaptive_input: false 25 | adaptive_input_factor: 4 26 | adaptive_input_cutoff: null 27 | tie_adaptive_weights: false 28 | tie_adaptive_proj: false 29 | decoder_learned_pos: false 30 | decoder_layerdrop: 0 31 | decoder_layers_to_keep: null 32 | layernorm_embedding: false 33 | no_scale_embedding: false 34 | quant_noise_pq: 0 35 | quant_noise_pq_block_size: 8 36 | quant_noise_scalar: 0 37 | -------------------------------------------------------------------------------- /codes/config/model/transformer_lm_gpt2_big.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | activation_fn: "gelu" 3 | dropout: 0.1 4 | attention_dropout: 0.1 5 | activation_dropout: 0.0 6 | relu_dropout: 0.0 7 | decoder_embed_dim: 1600 8 | decoder_output_dim: 1600 9 | decoder_input_dim: 1600 10 | decoder_ffn_embed_dim: 6400 11 | decoder_layers: 48 12 | decoder_attention_heads: 25 13 | decoder_normalize_before: true 14 | no_decoder_final_norm: false 15 | adaptive_softmax_cutoff: null 16 | adaptive_softmax_dropout: 0 17 | adaptive_softmax_factor: 4 18 | no_token_positional_embeddings: false 19 | share_decoder_input_output_embed: false 20 | character_embeddings: false 21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" 22 | character_embedding_dim: 4 23 | char_embedder_highway_layers: 2 24 | 
adaptive_input: false 25 | adaptive_input_factor: 4 26 | adaptive_input_cutoff: null 27 | tie_adaptive_weights: false 28 | tie_adaptive_proj: false 29 | decoder_learned_pos: false 30 | decoder_layerdrop: 0 31 | decoder_layers_to_keep: null 32 | layernorm_embedding: false 33 | no_scale_embedding: false 34 | quant_noise_pq: 0 35 | quant_noise_pq_block_size: 8 36 | quant_noise_scalar: 0 37 | -------------------------------------------------------------------------------- /codes/config/model/transformer_lm_gpt2_medium.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | activation_fn: "gelu" 3 | dropout: 0.1 4 | attention_dropout: 0.1 5 | activation_dropout: 0.0 6 | relu_dropout: 0.0 7 | decoder_embed_dim: 1280 8 | decoder_output_dim: 1280 9 | decoder_input_dim: 1280 10 | decoder_ffn_embed_dim: 5120 11 | decoder_layers: 36 12 | decoder_attention_heads: 20 13 | decoder_normalize_before: true 14 | no_decoder_final_norm: false 15 | adaptive_softmax_cutoff: null 16 | adaptive_softmax_dropout: 0 17 | adaptive_softmax_factor: 4 18 | no_token_positional_embeddings: false 19 | share_decoder_input_output_embed: false 20 | character_embeddings: false 21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" 22 | character_embedding_dim: 4 23 | char_embedder_highway_layers: 2 24 | adaptive_input: false 25 | adaptive_input_factor: 4 26 | adaptive_input_cutoff: null 27 | tie_adaptive_weights: false 28 | tie_adaptive_proj: false 29 | decoder_learned_pos: false 30 | decoder_layerdrop: 0 31 | decoder_layers_to_keep: null 32 | layernorm_embedding: false 33 | no_scale_embedding: false 34 | quant_noise_pq: 0 35 | quant_noise_pq_block_size: 8 36 | quant_noise_scalar: 0 37 | -------------------------------------------------------------------------------- /codes/config/model/transformer_lm_gpt2_small.yaml: -------------------------------------------------------------------------------- 1 | # 
@package _group_ 2 | activation_fn: "gelu" 3 | dropout: 0.1 4 | attention_dropout: 0.1 5 | activation_dropout: 0.0 6 | relu_dropout: 0.0 7 | decoder_embed_dim: 1024 8 | decoder_output_dim: 1024 9 | decoder_input_dim: 1024 10 | decoder_ffn_embed_dim: 4096 11 | decoder_layers: 24 12 | decoder_attention_heads: 16 13 | decoder_normalize_before: true 14 | no_decoder_final_norm: false 15 | adaptive_softmax_cutoff: null 16 | adaptive_softmax_dropout: 0 17 | adaptive_softmax_factor: 4 18 | no_token_positional_embeddings: false 19 | share_decoder_input_output_embed: false 20 | character_embeddings: false 21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" 22 | character_embedding_dim: 4 23 | char_embedder_highway_layers: 2 24 | adaptive_input: false 25 | adaptive_input_factor: 4 26 | adaptive_input_cutoff: null 27 | tie_adaptive_weights: false 28 | tie_adaptive_proj: false 29 | decoder_learned_pos: false 30 | decoder_layerdrop: 0 31 | decoder_layers_to_keep: null 32 | layernorm_embedding: false 33 | no_scale_embedding: false 34 | quant_noise_pq: 0 35 | quant_noise_pq_block_size: 8 36 | quant_noise_scalar: 0 37 | -------------------------------------------------------------------------------- /codes/config/model/transformer_lm_wiki103.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | activation_fn: "relu" 3 | dropout: 0.3 4 | attention_dropout: 0.1 5 | activation_dropout: 0.1 6 | relu_dropout: 0.1 7 | decoder_embed_dim: 1024 8 | decoder_output_dim: 1024 9 | decoder_input_dim: 1024 10 | decoder_ffn_embed_dim: 4096 11 | decoder_layers: 16 12 | decoder_attention_heads: 8 13 | decoder_normalize_before: true 14 | no_decoder_final_norm: true 15 | adaptive_softmax_cutoff: "20000,60000" 16 | adaptive_softmax_dropout: 0.2 17 | adaptive_softmax_factor: 4 18 | no_token_positional_embeddings: false 19 | share_decoder_input_output_embed: false 20 | character_embeddings: false 21 | 
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" 22 | character_embedding_dim: 4 23 | char_embedder_highway_layers: 2 24 | adaptive_input: true 25 | adaptive_input_factor: 4 26 | adaptive_input_cutoff: "20000,60000" 27 | tie_adaptive_weights: true 28 | tie_adaptive_proj: true 29 | decoder_learned_pos: false 30 | decoder_layerdrop: 0 31 | decoder_layers_to_keep: null 32 | layernorm_embedding: false 33 | no_scale_embedding: false 34 | quant_noise_pq: 0 35 | quant_noise_pq_block_size: 8 36 | quant_noise_scalar: 0 37 | -------------------------------------------------------------------------------- /codes/config/optimizer/adam.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | adam_betas: "(0.9, 0.999)" 3 | adam_eps: 1.0e-8 4 | weight_decay: 0 5 | use_old_adam: false 6 | -------------------------------------------------------------------------------- /codes/config/optimizer/nag.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | momentum: 0.99 3 | weight_decay: 0.0 4 | -------------------------------------------------------------------------------- /codes/config/params/eval_lm_params.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | common: 3 | no_progress_bar: false 4 | log_interval: 100 5 | log_format: null 6 | tensorboard_logdir: null 7 | seed: 1 8 | cpu: false 9 | fp16: false 10 | memory_efficient_fp16: false 11 | fp16_no_flatten_grads: false 12 | fp16_init_scale: 128 13 | fp16_scale_window: null 14 | fp16_scale_tolerance: 0.0 15 | min_loss_scale: 1.0e-4 16 | threshold_loss_scale: null 17 | user_dir: null 18 | empty_cache_freq: 0 19 | all_gather_list_size: 16384 20 | model_parallel_size: 1 21 | checkpoint_suffix: "" 22 | quantization_config_path: null 23 | distributed_training: 24 | distributed_rank: 0 25 | distributed_backend: "nccl" 26 
| distributed_init_method: null 27 | distributed_port: -1 28 | device_id: 0 29 | local_rank: 0 30 | distributed_no_spawn: false 31 | ddp_backend: "c10d" 32 | bucket_cap_mb: 25 33 | fix_batches_to_gpus: false 34 | find_unused_parameters: false 35 | fast_stat_sync: false 36 | broadcast_buffers: false 37 | distributed_wrapper: "DDP" 38 | slowmo_momentum: null 39 | slowmo_algorithm: "LocalSGD" 40 | localsgd_frequency: 3 41 | dataset: 42 | num_workers: 1 43 | skip_invalid_size_inputs_valid_test: false 44 | max_tokens: null 45 | batch_size: ${params.dataset.batch_size} 46 | required_batch_size_multiple: 8 47 | dataset_impl: null 48 | data_buffer_size: 10 49 | train_subset: "train" 50 | valid_subset: "valid" 51 | validate_interval: 1 52 | fixed_validation_seed: null 53 | disable_validation: false 54 | curriculum: 0 55 | gen_subset: "test" 56 | num_shards: 1 57 | shard_id: 0 58 | max_tokens_valid: ${params.dataset.max_tokens} 59 | batch_size_valid: ${params.dataset.batch_size} 60 | optimization: 61 | max_epoch: 0 62 | max_update: 0 63 | clip_norm: 25.0 64 | sentence_avg: false 65 | update_freq: [1] 66 | lr: [0.25] 67 | min_lr: -1.0 68 | use_bmuf: false 69 | checkpoint: 70 | save_dir: "checkpoints" 71 | restore_file: "checkpoint_last.pt" 72 | reset_dataloader: false 73 | reset_lr_scheduler: false 74 | reset_meters: false 75 | reset_optimizer: false 76 | optimizer_overrides: "{}" 77 | save_interval: 1 78 | save_interval_updates: 0 79 | keep_interval_updates: -1 80 | keep_last_epochs: -1 81 | keep_best_checkpoints: -1 82 | no_save: false 83 | no_epoch_checkpoints: false 84 | no_last_checkpoints: false 85 | no_save_optimizer_state: false 86 | best_checkpoint_metric: "loss" 87 | maximize_best_checkpoint_metric: false 88 | patience: -1 89 | common_eval: 90 | path: null 91 | remove_bpe: null 92 | quiet: false 93 | model_overrides: "{}" 94 | results_path: null 95 | eval_lm: 96 | output_word_probs: false 97 | output_word_stats: false 98 | context_window: 0 99 | bmuf: 100 | 
block_lr: 1 101 | block_momentum: 0.875 102 | global_sync_iter: 50 103 | warmup_iterations: 500 104 | use_nbm: false 105 | average_sync: false 106 | -------------------------------------------------------------------------------- /codes/config/params/training_params.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | common: 3 | no_progress_bar: false 4 | log_interval: 100 5 | log_format: null 6 | tensorboard_logdir: null 7 | seed: 1 8 | cpu: false 9 | fp16: false 10 | memory_efficient_fp16: false 11 | fp16_no_flatten_grads: false 12 | fp16_init_scale: 128 13 | fp16_scale_window: null 14 | fp16_scale_tolerance: 0.0 15 | min_loss_scale: 1.0e-4 16 | threshold_loss_scale: null 17 | user_dir: null 18 | empty_cache_freq: 0 19 | all_gather_list_size: 16384 20 | model_parallel_size: 1 21 | checkpoint_suffix: "" 22 | quantization_config_path: null 23 | distributed_training: 24 | distributed_rank: 0 25 | distributed_backend: "nccl" 26 | distributed_init_method: null 27 | distributed_port: -1 28 | device_id: 0 29 | local_rank: 0 30 | distributed_no_spawn: false 31 | ddp_backend: "c10d" 32 | bucket_cap_mb: 25 33 | fix_batches_to_gpus: false 34 | find_unused_parameters: false 35 | fast_stat_sync: false 36 | broadcast_buffers: false 37 | distributed_wrapper: "DDP" 38 | slowmo_momentum: null 39 | slowmo_algorithm: "LocalSGD" 40 | localsgd_frequency: 3 41 | dataset: 42 | num_workers: 1 43 | skip_invalid_size_inputs_valid_test: false 44 | max_tokens: null 45 | batch_size: ${params.dataset.batch_size} 46 | required_batch_size_multiple: 8 47 | dataset_impl: null 48 | data_buffer_size: 10 49 | train_subset: "train" 50 | valid_subset: "valid" 51 | validate_interval: 1 52 | fixed_validation_seed: null 53 | disable_validation: false 54 | curriculum: 0 55 | gen_subset: "test" 56 | num_shards: 1 57 | shard_id: 0 58 | max_tokens_valid: ${params.dataset.max_tokens} 59 | batch_size_valid: ${params.dataset.batch_size} 60 | 
optimization: 61 | max_epoch: 0 62 | max_update: 0 63 | clip_norm: 25.0 64 | sentence_avg: false 65 | update_freq: [1] 66 | lr: [0.25] 67 | min_lr: -1.0 68 | use_bmuf: false 69 | checkpoint: 70 | save_dir: "checkpoints" 71 | restore_file: "checkpoint_last.pt" 72 | reset_dataloader: false 73 | reset_lr_scheduler: false 74 | reset_meters: false 75 | reset_optimizer: false 76 | optimizer_overrides: "{}" 77 | save_interval: 1 78 | save_interval_updates: 0 79 | keep_interval_updates: -1 80 | keep_last_epochs: -1 81 | keep_best_checkpoints: -1 82 | no_save: false 83 | no_epoch_checkpoints: false 84 | no_last_checkpoints: false 85 | no_save_optimizer_state: false 86 | best_checkpoint_metric: "loss" 87 | maximize_best_checkpoint_metric: false 88 | patience: -1 89 | bmuf: 90 | block_lr: 1 91 | block_momentum: 0.875 92 | global_sync_iter: 50 93 | warmup_iterations: 500 94 | use_nbm: false 95 | average_sync: false 96 | -------------------------------------------------------------------------------- /codes/config/task/language_modeling.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | data: ??? 3 | sample_break_mode: "none" 4 | tokens_per_sample: 1024 5 | output_dictionary_size: -1 6 | self_target: false 7 | future_target: false 8 | past_target: false 9 | add_bos_token: false 10 | max_target_positions: null 11 | -------------------------------------------------------------------------------- /codes/create_datastore.sh: -------------------------------------------------------------------------------- 1 | MODEL_PATH=pretrain_model/wmt19.de-en.ffn8192.pt 2 | 3 | DSTORE_SIZE=3613350 4 | DATA_PATH=data-bin/it 5 | DATASTORE_PATH=save_datastore/it 6 | PROJECT_PATH=. 
7 | 8 | mkdir -p $DATASTORE_PATH 9 | 10 | CUDA_VISIBLE_DEVICES=6 python $PROJECT_PATH/save_datastore.py $DATA_PATH \ 11 | --dataset-impl mmap \ 12 | --task translation \ 13 | --valid-subset train \ 14 | --path $MODEL_PATH \ 15 | --max-tokens 4096 --skip-invalid-size-inputs-valid-test \ 16 | --decoder-embed-dim 1024 --dstore-fp16 \ 17 | --dstore-size $DSTORE_SIZE --dstore-mmap $DATASTORE_PATH 18 | 19 | # 4096 and 1024 depend on your device and model separately 20 | 21 | -------------------------------------------------------------------------------- /codes/create_datastore_knn_align.sh: -------------------------------------------------------------------------------- 1 | postfix=_nce 2 | gpu=2 3 | OUTDOMAIN=law 4 | DOMAIN=${OUTDOMAIN} 5 | COMPACT_DIM=64 6 | 7 | 8 | # corpus 9 | DATA_PATH=data-bin/${DOMAIN} 10 | 11 | 12 | # datastore 13 | declare -A DSTORE_SIZES_dict 14 | DSTORE_SIZES_dict=([it]="3613350" [medical]="6903320" [koran]="524400" [law]="19070000" [wiki]="47987250" [subtitles]="153604142") 15 | DSTORE_SIZE=${DSTORE_SIZES_dict[$OUTDOMAIN]} 16 | DATASTORE_PATH=save_datastore/${OUTDOMAIN}/knn_transfered${postfix} 17 | # rm -rf $DATASTORE_PATH 18 | mkdir -p $DATASTORE_PATH 19 | 20 | # model 21 | MODEL_PATH=model_record_path/${OUTDOMAIN}/knn_transfered${postfix}_${COMPACT_DIM}/checkpoint_best.pt 22 | 23 | 24 | # log 25 | NOHUP_DIR=nohup-${DOMAIN} 26 | mkdir -p ${NOHUP_DIR} 27 | 28 | 29 | # start 30 | PROJECT_PATH=. 
31 | CUDA_VISIBLE_DEVICES=$gpu python $PROJECT_PATH/save_datastore_knn_align.py $DATA_PATH \ 32 | --dataset-impl mmap \ 33 | --task translation \ 34 | --valid-subset train \ 35 | --path $MODEL_PATH \ 36 | --max-tokens 4096 --skip-invalid-size-inputs-valid-test \ 37 | --decoder-embed-dim 1024 --dstore-fp16 --dstore-size $DSTORE_SIZE --dstore-mmap $DATASTORE_PATH \ 38 | --decoder-knn-compact-dim ${COMPACT_DIM} --save-knn-compate-feature --not-train-knn-compact-projection \ 39 | --dstore-filename $DATASTORE_PATH --create-knn-compact-projection \ 40 | > ${NOHUP_DIR}/create_datastore${postfix}.txt 2>&1 & 41 | echo ${NOHUP_DIR}/create_datastore${postfix}.txt 42 | -------------------------------------------------------------------------------- /codes/examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonderseen/PCKMT/2a50a85f671ac46252b1038f62abdded367aeb46/codes/examples/__init__.py -------------------------------------------------------------------------------- /codes/fairseq/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | """isort:skip_file""" 6 | 7 | __all__ = ["pdb"] 8 | __version__ = "0.10.1" 9 | 10 | import sys 11 | 12 | # backwards compatibility to support `from fairseq.meters import AverageMeter` 13 | from fairseq.logging import meters, metrics, progress_bar # noqa 14 | 15 | sys.modules["fairseq.meters"] = meters 16 | sys.modules["fairseq.metrics"] = metrics 17 | sys.modules["fairseq.progress_bar"] = progress_bar 18 | 19 | import fairseq.criterions # noqa 20 | import fairseq.models # noqa 21 | import fairseq.modules # noqa 22 | import fairseq.optim # noqa 23 | import fairseq.optim.lr_scheduler # noqa 24 | import fairseq.pdb # noqa 25 | import fairseq.scoring # noqa 26 | import fairseq.tasks # noqa 27 | import fairseq.token_generation_constraints # noqa 28 | 29 | import fairseq.benchmark # noqa 30 | import fairseq.model_parallel # noqa 31 | -------------------------------------------------------------------------------- /codes/fairseq/benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # import models/tasks to register them 7 | from . import dummy_lm, dummy_masked_lm, dummy_model, dummy_mt # noqa 8 | -------------------------------------------------------------------------------- /codes/fairseq/clib/libbleu/module.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2017-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include 10 | 11 | 12 | static PyMethodDef method_def[] = { 13 | {NULL, NULL, 0, NULL} 14 | }; 15 | 16 | static struct PyModuleDef module_def = { 17 | PyModuleDef_HEAD_INIT, 18 | "libbleu", /* name of module */ 19 | NULL, /* module documentation, may be NULL */ 20 | -1, /* size of per-interpreter state of the module, 21 | or -1 if the module keeps state in global variables. */ 22 | method_def 23 | }; 24 | 25 | 26 | #if PY_MAJOR_VERSION == 2 27 | PyMODINIT_FUNC init_libbleu() 28 | #else 29 | PyMODINIT_FUNC PyInit_libbleu() 30 | #endif 31 | { 32 | PyObject *m = PyModule_Create(&module_def); 33 | if (!m) { 34 | return NULL; 35 | } 36 | return m; 37 | } 38 | -------------------------------------------------------------------------------- /codes/fairseq/clib/libnat_cuda/binding.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2017-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | /* 10 | This code is partially adpoted from https://github.com/1ytic/pytorch-edit-distance 11 | */ 12 | 13 | #include "edit_dist.h" 14 | #include 15 | 16 | #ifndef TORCH_CHECK 17 | #define TORCH_CHECK AT_CHECK 18 | #endif 19 | 20 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 21 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 22 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 23 | 24 | 25 | torch::Tensor LevenshteinDistance( 26 | torch::Tensor source, 27 | torch::Tensor target, 28 | torch::Tensor source_length, 29 | torch::Tensor target_length) { 30 | 31 | CHECK_INPUT(source); 32 | CHECK_INPUT(target); 33 | CHECK_INPUT(source_length); 34 | CHECK_INPUT(target_length); 35 | return LevenshteinDistanceCuda(source, target, source_length, target_length); 36 | } 37 | 38 | torch::Tensor GenerateDeletionLabel( 39 | torch::Tensor source, 40 | torch::Tensor operations) { 41 | 42 | CHECK_INPUT(source); 43 | CHECK_INPUT(operations); 44 | return GenerateDeletionLabelCuda(source, operations); 45 | } 46 | 47 | std::pair GenerateInsertionLabel( 48 | torch::Tensor target, 49 | torch::Tensor operations) { 50 | 51 | CHECK_INPUT(target); 52 | CHECK_INPUT(operations); 53 | return GenerateInsertionLabelCuda(target, operations); 54 | } 55 | 56 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 57 | m.def("levenshtein_distance", &LevenshteinDistance, "Levenshtein distance"); 58 | m.def("generate_deletion_labels", &GenerateDeletionLabel, "Generate Deletion Label"); 59 | m.def("generate_insertion_labels", &GenerateInsertionLabel, "Generate Insertion Label"); 60 | } 61 | -------------------------------------------------------------------------------- /codes/fairseq/clib/libnat_cuda/edit_dist.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2017-present, Facebook, Inc. 3 | * All rights reserved. 
4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | 13 | torch::Tensor LevenshteinDistanceCuda( 14 | torch::Tensor source, 15 | torch::Tensor target, 16 | torch::Tensor source_length, 17 | torch::Tensor target_length); 18 | 19 | torch::Tensor GenerateDeletionLabelCuda( 20 | torch::Tensor source, 21 | torch::Tensor operations); 22 | 23 | std::pair GenerateInsertionLabelCuda( 24 | torch::Tensor source, 25 | torch::Tensor operations); 26 | -------------------------------------------------------------------------------- /codes/fairseq/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | """isort:skip_file""" 6 | 7 | import importlib 8 | import os 9 | from argparse import Namespace 10 | from typing import Union 11 | 12 | from fairseq import registry 13 | from fairseq.criterions.fairseq_criterion import ( # noqa 14 | FairseqCriterion, 15 | LegacyFairseqCriterion, 16 | ) 17 | from omegaconf import DictConfig 18 | 19 | 20 | ( 21 | build_criterion_, 22 | register_criterion, 23 | CRITERION_REGISTRY, 24 | CRITERION_DATACLASS_REGISTRY, 25 | ) = registry.setup_registry( 26 | "--criterion", base_class=FairseqCriterion, default="cross_entropy" 27 | ) 28 | 29 | 30 | def build_criterion(criterion_cfg: Union[DictConfig, Namespace], task): 31 | return build_criterion_(criterion_cfg, task) 32 | 33 | 34 | # automatically import any Python files in the criterions/ directory 35 | for file in os.listdir(os.path.dirname(__file__)): 36 | if file.endswith(".py") and not file.startswith("_"): 37 | file_name = file[: file.find(".py")] 38 | importlib.import_module("fairseq.criterions." 
+ file_name) 39 | -------------------------------------------------------------------------------- /codes/fairseq/data/add_target_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | from . import BaseWrapperDataset, data_utils 9 | 10 | 11 | class AddTargetDataset(BaseWrapperDataset): 12 | def __init__( 13 | self, 14 | dataset, 15 | labels, 16 | pad, 17 | eos, 18 | batch_targets, 19 | process_label=None, 20 | add_to_input=False, 21 | ): 22 | super().__init__(dataset) 23 | self.labels = labels 24 | self.batch_targets = batch_targets 25 | self.pad = pad 26 | self.eos = eos 27 | self.process_label = process_label 28 | self.add_to_input = add_to_input 29 | 30 | def get_label(self, index): 31 | return ( 32 | self.labels[index] 33 | if self.process_label is None 34 | else self.process_label(self.labels[index]) 35 | ) 36 | 37 | def __getitem__(self, index): 38 | item = self.dataset[index] 39 | item["label"] = self.get_label(index) 40 | return item 41 | 42 | def size(self, index): 43 | sz = self.dataset.size(index) 44 | own_sz = len(self.get_label(index)) 45 | return (sz, own_sz) 46 | 47 | def collater(self, samples): 48 | collated = self.dataset.collater(samples) 49 | if len(collated) == 0: 50 | return collated 51 | indices = set(collated["id"].tolist()) 52 | target = [s["label"] for s in samples if s["id"] in indices] 53 | 54 | if self.batch_targets: 55 | collated["target_lengths"] = torch.LongTensor([len(t) for t in target]) 56 | target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False) 57 | collated["ntokens"] = collated["target_lengths"].sum().item() 58 | else: 59 | collated["ntokens"] = sum([len(t) for t in target]) 60 | 61 | collated["target"] = target 62 | 63 | if self.add_to_input: 64 | 
eos = target.new_full((target.size(0), 1), self.eos) 65 | collated["target"] = torch.cat([target, eos], dim=-1).long() 66 | collated["net_input"]["prev_output_tokens"] = torch.cat( 67 | [eos, target], dim=-1 68 | ).long() 69 | collated["ntokens"] += target.size(0) 70 | return collated 71 | -------------------------------------------------------------------------------- /codes/fairseq/data/append_token_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from . import BaseWrapperDataset 10 | 11 | 12 | class AppendTokenDataset(BaseWrapperDataset): 13 | def __init__(self, dataset, token=None): 14 | super().__init__(dataset) 15 | self.token = token 16 | if token is not None: 17 | self._sizes = np.array(dataset.sizes) + 1 18 | else: 19 | self._sizes = dataset.sizes 20 | 21 | def __getitem__(self, idx): 22 | item = self.dataset[idx] 23 | if self.token is not None: 24 | item = torch.cat([item, item.new([self.token])]) 25 | return item 26 | 27 | @property 28 | def sizes(self): 29 | return self._sizes 30 | 31 | def num_tokens(self, index): 32 | n = self.dataset.num_tokens(index) 33 | if self.token is not None: 34 | n += 1 35 | return n 36 | 37 | def size(self, index): 38 | n = self.dataset.size(index) 39 | if self.token is not None: 40 | n += 1 41 | return n 42 | -------------------------------------------------------------------------------- /codes/fairseq/data/audio/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonderseen/PCKMT/2a50a85f671ac46252b1038f62abdded367aeb46/codes/fairseq/data/audio/__init__.py -------------------------------------------------------------------------------- 
/codes/fairseq/data/audio/feature_transforms/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | from abc import ABC, abstractmethod 4 | from typing import Dict, Optional 5 | 6 | 7 | class AudioFeatureTransform(ABC): 8 | @classmethod 9 | @abstractmethod 10 | def from_config_dict(cls, config: Optional[Dict] = None): 11 | pass 12 | 13 | 14 | AUDIO_FEATURE_TRANSFORM_REGISTRY = {} 15 | AUDIO_FEATURE_TRANSFORM_CLASS_NAMES = set() 16 | 17 | 18 | def register_audio_feature_transform(name): 19 | def register_audio_feature_transform_cls(cls): 20 | if name in AUDIO_FEATURE_TRANSFORM_REGISTRY: 21 | raise ValueError(f"Cannot register duplicate transform ({name})") 22 | if not issubclass(cls, AudioFeatureTransform): 23 | raise ValueError( 24 | f"Transform ({name}: {cls.__name__}) must extend " 25 | "AudioFeatureTransform" 26 | ) 27 | if cls.__name__ in AUDIO_FEATURE_TRANSFORM_CLASS_NAMES: 28 | raise ValueError( 29 | f"Cannot register audio feature transform with duplicate " 30 | f"class name ({cls.__name__})" 31 | ) 32 | AUDIO_FEATURE_TRANSFORM_REGISTRY[name] = cls 33 | AUDIO_FEATURE_TRANSFORM_CLASS_NAMES.add(cls.__name__) 34 | return cls 35 | 36 | return register_audio_feature_transform_cls 37 | 38 | 39 | def get_audio_feature_transform(name): 40 | return AUDIO_FEATURE_TRANSFORM_REGISTRY[name] 41 | 42 | 43 | transforms_dir = os.path.dirname(__file__) 44 | for file in os.listdir(transforms_dir): 45 | path = os.path.join(transforms_dir, file) 46 | if ( 47 | not file.startswith("_") 48 | and not file.startswith(".") 49 | and (file.endswith(".py") or os.path.isdir(path)) 50 | ): 51 | name = file[: file.find(".py")] if file.endswith(".py") else file 52 | importlib.import_module("fairseq.data.audio.feature_transforms." 
+ name) 53 | 54 | 55 | class CompositeAudioFeatureTransform(AudioFeatureTransform): 56 | @classmethod 57 | def from_config_dict(cls, config=None): 58 | _config = {} if config is None else config 59 | _transforms = _config.get("transforms") 60 | if _transforms is None: 61 | return None 62 | transforms = [ 63 | get_audio_feature_transform(_t).from_config_dict(_config.get(_t)) 64 | for _t in _transforms 65 | ] 66 | return CompositeAudioFeatureTransform(transforms) 67 | 68 | def __init__(self, transforms): 69 | self.transforms = [t for t in transforms if t is not None] 70 | 71 | def __call__(self, x): 72 | for t in self.transforms: 73 | x = t(x) 74 | return x 75 | 76 | def __repr__(self): 77 | format_string = ( 78 | [self.__class__.__name__ + "("] 79 | + [f" {t.__repr__()}" for t in self.transforms] 80 | + [")"] 81 | ) 82 | return "\n".join(format_string) 83 | -------------------------------------------------------------------------------- /codes/fairseq/data/audio/feature_transforms/global_cmvn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from fairseq.data.audio.feature_transforms import ( 3 | AudioFeatureTransform, 4 | register_audio_feature_transform, 5 | ) 6 | 7 | 8 | @register_audio_feature_transform("global_cmvn") 9 | class GlobalCMVN(AudioFeatureTransform): 10 | """Global CMVN (cepstral mean and variance normalization). 
The global mean 11 | and variance need to be pre-computed and stored in NumPy format (.npz).""" 12 | 13 | @classmethod 14 | def from_config_dict(cls, config=None): 15 | _config = {} if config is None else config 16 | return GlobalCMVN(_config.get("stats_npz_path")) 17 | 18 | def __init__(self, stats_npz_path): 19 | stats = np.load(stats_npz_path) 20 | self.mean, self.std = stats["mean"], stats["std"] 21 | 22 | def __call__(self, x): 23 | x = np.subtract(x, self.mean) 24 | x = np.divide(x, self.std) 25 | return x 26 | -------------------------------------------------------------------------------- /codes/fairseq/data/audio/feature_transforms/utterance_cmvn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from fairseq.data.audio.feature_transforms import ( 3 | AudioFeatureTransform, 4 | register_audio_feature_transform, 5 | ) 6 | 7 | 8 | @register_audio_feature_transform("utterance_cmvn") 9 | class UtteranceCMVN(AudioFeatureTransform): 10 | """Utterance-level CMVN (cepstral mean and variance normalization)""" 11 | 12 | @classmethod 13 | def from_config_dict(cls, config=None): 14 | _config = {} if config is None else config 15 | return UtteranceCMVN( 16 | _config.get("norm_means", True), 17 | _config.get("norm_vars", True), 18 | ) 19 | 20 | def __init__(self, norm_means=True, norm_vars=True): 21 | self.norm_means, self.norm_vars = norm_means, norm_vars 22 | 23 | def __repr__(self): 24 | return ( 25 | self.__class__.__name__ 26 | + f"(norm_means={self.norm_means}, norm_vars={self.norm_vars})" 27 | ) 28 | 29 | def __call__(self, x): 30 | mean = x.mean(axis=0) 31 | square_sums = (x ** 2).sum(axis=0) 32 | 33 | if self.norm_means: 34 | x = np.subtract(x, mean) 35 | if self.norm_vars: 36 | var = square_sums / x.shape[0] - mean ** 2 37 | std = np.sqrt(np.maximum(var, 1e-10)) 38 | x = np.divide(x, std) 39 | 40 | return x 41 | -------------------------------------------------------------------------------- 
/codes/fairseq/data/base_wrapper_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from torch.utils.data.dataloader import default_collate 7 | 8 | from . import FairseqDataset 9 | 10 | 11 | class BaseWrapperDataset(FairseqDataset): 12 | def __init__(self, dataset): 13 | super().__init__() 14 | self.dataset = dataset 15 | 16 | def __getitem__(self, index): 17 | return self.dataset[index] 18 | 19 | def __len__(self): 20 | return len(self.dataset) 21 | 22 | def collater(self, samples): 23 | if hasattr(self.dataset, "collater"): 24 | return self.dataset.collater(samples) 25 | else: 26 | return default_collate(samples) 27 | 28 | @property 29 | def sizes(self): 30 | return self.dataset.sizes 31 | 32 | def num_tokens(self, index): 33 | return self.dataset.num_tokens(index) 34 | 35 | def size(self, index): 36 | return self.dataset.size(index) 37 | 38 | def ordered_indices(self): 39 | return self.dataset.ordered_indices() 40 | 41 | @property 42 | def supports_prefetch(self): 43 | return getattr(self.dataset, "supports_prefetch", False) 44 | 45 | def attr(self, attr: str, index: int): 46 | return self.dataset.attr(attr, index) 47 | 48 | def prefetch(self, indices): 49 | self.dataset.prefetch(indices) 50 | 51 | def get_batch_shapes(self): 52 | return self.dataset.get_batch_shapes() 53 | 54 | def batch_by_size( 55 | self, 56 | indices, 57 | max_tokens=None, 58 | max_sentences=None, 59 | required_batch_size_multiple=1, 60 | ): 61 | return self.dataset.batch_by_size( 62 | indices, 63 | max_tokens=max_tokens, 64 | max_sentences=max_sentences, 65 | required_batch_size_multiple=required_batch_size_multiple, 66 | ) 67 | 68 | def filter_indices_by_size(self, indices, max_sizes): 69 | return self.dataset.filter_indices_by_size(indices, 
max_sizes) 70 | 71 | @property 72 | def can_reuse_epoch_itr_across_epochs(self): 73 | return self.dataset.can_reuse_epoch_itr_across_epochs 74 | 75 | def set_epoch(self, epoch): 76 | super().set_epoch(epoch) 77 | if hasattr(self.dataset, "set_epoch"): 78 | self.dataset.set_epoch(epoch) 79 | -------------------------------------------------------------------------------- /codes/fairseq/data/bucket_pad_length_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | import torch.nn.functional as F 8 | from fairseq.data import BaseWrapperDataset 9 | 10 | 11 | class BucketPadLengthDataset(BaseWrapperDataset): 12 | """ 13 | Bucket and pad item lengths to the nearest bucket size. This can be used to 14 | reduce the number of unique batch shapes, which is important on TPUs since 15 | each new batch shape requires a recompilation. 
16 | 17 | Args: 18 | dataset (FairseqDatset): dataset to bucket 19 | sizes (List[int]): all item sizes 20 | num_buckets (int): number of buckets to create 21 | pad_idx (int): padding symbol 22 | left_pad (bool): if True, pad on the left; otherwise right pad 23 | """ 24 | 25 | def __init__( 26 | self, 27 | dataset, 28 | sizes, 29 | num_buckets, 30 | pad_idx, 31 | left_pad, 32 | ): 33 | super().__init__(dataset) 34 | self.pad_idx = pad_idx 35 | self.left_pad = left_pad 36 | 37 | assert num_buckets > 0 38 | self.buckets = np.unique( 39 | np.percentile( 40 | sizes, 41 | np.linspace(0, 100, num_buckets + 1), 42 | interpolation="lower", 43 | )[1:] 44 | ) 45 | 46 | def get_bucketed_sizes(orig_sizes, buckets): 47 | sizes = np.copy(orig_sizes) 48 | assert np.min(sizes) >= 0 49 | start_val = -1 50 | for end_val in buckets: 51 | mask = (sizes > start_val) & (sizes <= end_val) 52 | sizes[mask] = end_val 53 | start_val = end_val 54 | return sizes 55 | 56 | self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets) 57 | 58 | def __getitem__(self, index): 59 | item = self.dataset[index] 60 | bucket_size = self._bucketed_sizes[index] 61 | num_pad = bucket_size - item.size(-1) 62 | return F.pad( 63 | item, 64 | (num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad), 65 | value=self.pad_idx, 66 | ) 67 | 68 | @property 69 | def sizes(self): 70 | return self._bucketed_sizes 71 | 72 | def num_tokens(self, index): 73 | return self._bucketed_sizes[index] 74 | 75 | def size(self, index): 76 | return self._bucketed_sizes[index] 77 | -------------------------------------------------------------------------------- /codes/fairseq/data/colorize_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | from . 
import BaseWrapperDataset 9 | 10 | 11 | class ColorizeDataset(BaseWrapperDataset): 12 | """ Adds 'colors' property to net input that is obtained from the provided color getter for use by models """ 13 | 14 | def __init__(self, dataset, color_getter): 15 | super().__init__(dataset) 16 | self.color_getter = color_getter 17 | 18 | def collater(self, samples): 19 | base_collate = super().collater(samples) 20 | if len(base_collate) > 0: 21 | base_collate["net_input"]["colors"] = torch.tensor( 22 | list(self.color_getter(self.dataset, s["id"]) for s in samples), 23 | dtype=torch.long, 24 | ) 25 | return base_collate 26 | -------------------------------------------------------------------------------- /codes/fairseq/data/concat_sentences_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | from . 
import FairseqDataset 9 | 10 | 11 | class ConcatSentencesDataset(FairseqDataset): 12 | def __init__(self, *datasets): 13 | super().__init__() 14 | self.datasets = datasets 15 | assert all( 16 | len(ds) == len(datasets[0]) for ds in datasets 17 | ), "datasets must have the same length" 18 | 19 | def __getitem__(self, index): 20 | return torch.cat([ds[index] for ds in self.datasets]) 21 | 22 | def __len__(self): 23 | return len(self.datasets[0]) 24 | 25 | def collater(self, samples): 26 | return self.datasets[0].collater(samples) 27 | 28 | @property 29 | def sizes(self): 30 | return sum(ds.sizes for ds in self.datasets) 31 | 32 | def num_tokens(self, index): 33 | return sum(ds.num_tokens(index) for ds in self.datasets) 34 | 35 | def size(self, index): 36 | return sum(ds.size(index) for ds in self.datasets) 37 | 38 | def ordered_indices(self): 39 | return self.datasets[0].ordered_indices() 40 | 41 | @property 42 | def supports_prefetch(self): 43 | return any(getattr(ds, "supports_prefetch", False) for ds in self.datasets) 44 | 45 | def prefetch(self, indices): 46 | for ds in self.datasets: 47 | if getattr(ds, "supports_prefetch", False): 48 | ds.prefetch(indices) 49 | 50 | def set_epoch(self, epoch): 51 | super().set_epoch(epoch) 52 | for ds in self.datasets: 53 | if hasattr(ds, "set_epoch"): 54 | ds.set_epoch(epoch) 55 | -------------------------------------------------------------------------------- /codes/fairseq/data/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | import importlib 8 | import os 9 | 10 | from fairseq import registry 11 | 12 | 13 | build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY, _ = registry.setup_registry( 14 | "--tokenizer", 15 | default=None, 16 | ) 17 | 18 | 19 | build_bpe, register_bpe, BPE_REGISTRY, _ = registry.setup_registry( 20 | "--bpe", 21 | default=None, 22 | ) 23 | 24 | 25 | # automatically import any Python files in the encoders/ directory 26 | for file in os.listdir(os.path.dirname(__file__)): 27 | if file.endswith(".py") and not file.startswith("_"): 28 | module = file[: file.find(".py")] 29 | importlib.import_module("fairseq.data.encoders." + module) 30 | -------------------------------------------------------------------------------- /codes/fairseq/data/encoders/byte_bpe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | from fairseq import file_utils 8 | from fairseq.data.encoders import register_bpe 9 | from fairseq.data.encoders.byte_utils import ( 10 | SPACE, 11 | SPACE_ESCAPE, 12 | byte_encode, 13 | smart_byte_decode, 14 | ) 15 | 16 | 17 | @register_bpe("byte_bpe") 18 | class ByteBPE(object): 19 | @staticmethod 20 | def add_args(parser): 21 | # fmt: off 22 | parser.add_argument('--sentencepiece-model-path', type=str, 23 | help='path to sentencepiece model') 24 | # fmt: on 25 | 26 | def __init__(self, args): 27 | vocab = file_utils.cached_path(args.sentencepiece_model_path) 28 | try: 29 | import sentencepiece as spm 30 | 31 | self.sp = spm.SentencePieceProcessor() 32 | self.sp.Load(vocab) 33 | except ImportError: 34 | raise ImportError( 35 | "Please install sentencepiece with: pip install sentencepiece" 36 | ) 37 | 38 | def encode(self, x: str) -> str: 39 | byte_encoded = byte_encode(x) 40 | return SPACE.join(self.sp.EncodeAsPieces(byte_encoded)) 41 | 42 | @staticmethod 43 | def decode(x: str) -> str: 44 | unescaped = x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE) 45 | return smart_byte_decode(unescaped) 46 | -------------------------------------------------------------------------------- /codes/fairseq/data/encoders/byte_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import re 7 | 8 | 9 | WHITESPACE_NORMALIZER = re.compile(r"\s+") 10 | SPACE = chr(32) 11 | SPACE_ESCAPE = chr(9601) 12 | # excluding non-breaking space (160) here 13 | PRINTABLE_LATIN = set( 14 | list(range(32, 126 + 1)) + list(range(161, 172 + 1)) + list(range(174, 255 + 1)) 15 | ) 16 | BYTE_TO_BCHAR = { 17 | b: chr(b) if b in PRINTABLE_LATIN else chr(256 + b) for b in range(256) 18 | } 19 | BCHAR_TO_BYTE = {bc: b for b, bc in BYTE_TO_BCHAR.items()} 20 | 21 | 22 | def byte_encode(x: str) -> str: 23 | normalized = WHITESPACE_NORMALIZER.sub(SPACE, x) 24 | return "".join([BYTE_TO_BCHAR[b] for b in normalized.encode("utf-8")]) 25 | 26 | 27 | def byte_decode(x: str) -> str: 28 | try: 29 | return bytes([BCHAR_TO_BYTE[bc] for bc in x]).decode("utf-8") 30 | except ValueError: 31 | return "" 32 | 33 | 34 | def smart_byte_decode(x: str) -> str: 35 | output = byte_decode(x) 36 | if output == "": 37 | # DP the best recovery (max valid chars) if it's broken 38 | n_bytes = len(x) 39 | f = [0 for _ in range(n_bytes + 1)] 40 | pt = [0 for _ in range(n_bytes + 1)] 41 | for i in range(1, n_bytes + 1): 42 | f[i], pt[i] = f[i - 1], i - 1 43 | for j in range(1, min(4, i) + 1): 44 | if f[i - j] + 1 > f[i] and len(byte_decode(x[i - j : i])) > 0: 45 | f[i], pt[i] = f[i - j] + 1, i - j 46 | cur_pt = n_bytes 47 | while cur_pt > 0: 48 | if f[cur_pt] == f[pt[cur_pt]] + 1: 49 | output = byte_decode(x[pt[cur_pt] : cur_pt]) + output 50 | cur_pt = pt[cur_pt] 51 | return output 52 | -------------------------------------------------------------------------------- /codes/fairseq/data/encoders/bytes.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | from fairseq.data.encoders import register_bpe 8 | from fairseq.data.encoders.byte_utils import ( 9 | SPACE, 10 | SPACE_ESCAPE, 11 | byte_encode, 12 | smart_byte_decode, 13 | ) 14 | 15 | 16 | @register_bpe("bytes") 17 | class Bytes(object): 18 | def __init__(self, args): 19 | pass 20 | 21 | @staticmethod 22 | def add_args(parser): 23 | pass 24 | 25 | @staticmethod 26 | def encode(x: str) -> str: 27 | encoded = byte_encode(x) 28 | escaped = encoded.replace(SPACE, SPACE_ESCAPE) 29 | return SPACE.join(list(escaped)) 30 | 31 | @staticmethod 32 | def decode(x: str) -> str: 33 | unescaped = x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE) 34 | return smart_byte_decode(unescaped) 35 | -------------------------------------------------------------------------------- /codes/fairseq/data/encoders/characters.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | from fairseq.data.encoders import register_bpe 8 | 9 | 10 | SPACE = chr(32) 11 | SPACE_ESCAPE = chr(9601) 12 | 13 | 14 | @register_bpe("characters") 15 | class Characters(object): 16 | def __init__(self, args): 17 | pass 18 | 19 | @staticmethod 20 | def add_args(parser): 21 | pass 22 | 23 | @staticmethod 24 | def encode(x: str) -> str: 25 | escaped = x.replace(SPACE, SPACE_ESCAPE) 26 | return SPACE.join(list(escaped)) 27 | 28 | @staticmethod 29 | def decode(x: str) -> str: 30 | return x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE) 31 | -------------------------------------------------------------------------------- /codes/fairseq/data/encoders/fastbpe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from fairseq import file_utils 7 | from fairseq.data.encoders import register_bpe 8 | 9 | 10 | @register_bpe("fastbpe") 11 | class fastBPE(object): 12 | @staticmethod 13 | def add_args(parser): 14 | # fmt: off 15 | parser.add_argument('--bpe-codes', type=str, 16 | help='path to fastBPE BPE') 17 | # fmt: on 18 | 19 | def __init__(self, args): 20 | if args.bpe_codes is None: 21 | raise ValueError("--bpe-codes is required for --bpe=fastbpe") 22 | codes = file_utils.cached_path(args.bpe_codes) 23 | try: 24 | import fastBPE 25 | 26 | self.bpe = fastBPE.fastBPE(codes) 27 | self.bpe_symbol = "@@ " 28 | except ImportError: 29 | raise ImportError("Please install fastBPE with: pip install fastBPE") 30 | 31 | def encode(self, x: str) -> str: 32 | return self.bpe.apply([x])[0] 33 | 34 | def decode(self, x: str) -> str: 35 | return (x + " ").replace(self.bpe_symbol, "").rstrip() 36 | -------------------------------------------------------------------------------- /codes/fairseq/data/encoders/gpt2_bpe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from fairseq import file_utils 7 | from fairseq.data.encoders import register_bpe 8 | 9 | from .gpt2_bpe_utils import get_encoder 10 | 11 | 12 | DEFAULT_ENCODER_JSON = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json" 13 | DEFAULT_VOCAB_BPE = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe" 14 | 15 | 16 | @register_bpe("gpt2") 17 | class GPT2BPE(object): 18 | @staticmethod 19 | def add_args(parser): 20 | # fmt: off 21 | parser.add_argument('--gpt2-encoder-json', type=str, 22 | default=DEFAULT_ENCODER_JSON, 23 | help='path to encoder.json') 24 | parser.add_argument('--gpt2-vocab-bpe', type=str, 25 | default=DEFAULT_VOCAB_BPE, 26 | help='path to vocab.bpe') 27 | # fmt: on 28 | 29 | def __init__(self, args): 30 | encoder_json = file_utils.cached_path( 31 | getattr(args, "gpt2_encoder_json", DEFAULT_ENCODER_JSON) 32 | ) 33 | vocab_bpe = file_utils.cached_path( 34 | getattr(args, "gpt2_vocab_bpe", DEFAULT_VOCAB_BPE) 35 | ) 36 | self.bpe = get_encoder(encoder_json, vocab_bpe) 37 | 38 | def encode(self, x: str) -> str: 39 | return " ".join(map(str, self.bpe.encode(x))) 40 | 41 | def decode(self, x: str) -> str: 42 | return self.bpe.decode( 43 | [int(tok) if tok not in {"", ""} else tok for tok in x.split()] 44 | ) 45 | 46 | def is_beginning_of_word(self, x: str) -> bool: 47 | return self.decode(x).startswith(" ") 48 | -------------------------------------------------------------------------------- /codes/fairseq/data/encoders/hf_bert_bpe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from fairseq.data.encoders import register_bpe 7 | 8 | 9 | @register_bpe("bert") 10 | class BertBPE(object): 11 | @staticmethod 12 | def add_args(parser): 13 | # fmt: off 14 | parser.add_argument('--bpe-cased', action='store_true', 15 | help='set for cased BPE', 16 | default=False) 17 | parser.add_argument('--bpe-vocab-file', type=str, 18 | help='bpe vocab file.') 19 | # fmt: on 20 | 21 | def __init__(self, args): 22 | try: 23 | from transformers import BertTokenizer 24 | except ImportError: 25 | raise ImportError( 26 | "Please install transformers with: pip install transformers" 27 | ) 28 | 29 | if "bpe_vocab_file" in args: 30 | self.bert_tokenizer = BertTokenizer( 31 | args.bpe_vocab_file, do_lower_case=not args.bpe_cased 32 | ) 33 | else: 34 | vocab_file_name = ( 35 | "bert-base-cased" if args.bpe_cased else "bert-base-uncased" 36 | ) 37 | self.bert_tokenizer = BertTokenizer.from_pretrained(vocab_file_name) 38 | 39 | def encode(self, x: str) -> str: 40 | return " ".join(self.bert_tokenizer.tokenize(x)) 41 | 42 | def decode(self, x: str) -> str: 43 | return self.bert_tokenizer.clean_up_tokenization( 44 | self.bert_tokenizer.convert_tokens_to_string(x.split(" ")) 45 | ) 46 | 47 | def is_beginning_of_word(self, x: str) -> bool: 48 | return not x.startswith("##") 49 | -------------------------------------------------------------------------------- /codes/fairseq/data/encoders/hf_byte_bpe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from fairseq.data.encoders import register_bpe 7 | 8 | 9 | @register_bpe("hf_byte_bpe") 10 | class HuggingFaceByteLevelBPE(object): 11 | @staticmethod 12 | def add_args(parser): 13 | # fmt: off 14 | parser.add_argument('--bpe-merges', help='path to merges.txt') 15 | parser.add_argument('--bpe-vocab', help='path to vocab.json') 16 | parser.add_argument('--bpe-add-prefix-space', action='store_true', 17 | help='add prefix space before encoding') 18 | # fmt: on 19 | 20 | def __init__(self, args): 21 | try: 22 | from tokenizers import ByteLevelBPETokenizer 23 | except ImportError: 24 | raise ImportError( 25 | "Please install huggingface/tokenizers with: " "pip install tokenizers" 26 | ) 27 | 28 | self.bpe = ByteLevelBPETokenizer( 29 | args.bpe_vocab, 30 | args.bpe_merges, 31 | add_prefix_space=getattr(args, "bpe_add_prefix_space", False), 32 | ) 33 | 34 | def encode(self, x: str) -> str: 35 | return " ".join(map(str, self.bpe.encode(x).ids)) 36 | 37 | def decode(self, x: str) -> str: 38 | return self.bpe.decode( 39 | [int(tok) if tok not in {"", ""} else tok for tok in x.split()] 40 | ) 41 | 42 | def is_beginning_of_word(self, x: str) -> bool: 43 | return self.decode(x).startswith(" ") 44 | -------------------------------------------------------------------------------- /codes/fairseq/data/encoders/moses_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from fairseq.data.encoders import register_tokenizer 7 | 8 | 9 | @register_tokenizer("moses") 10 | class MosesTokenizer(object): 11 | @staticmethod 12 | def add_args(parser): 13 | # fmt: off 14 | parser.add_argument('--moses-source-lang', metavar='SRC', 15 | help='source language') 16 | parser.add_argument('--moses-target-lang', metavar='TARGET', 17 | help='target language') 18 | parser.add_argument('--moses-no-dash-splits', action='store_true', default=False, 19 | help='don\'t apply dash split rules') 20 | parser.add_argument('--moses-no-escape', action='store_true', default=False, 21 | help='don\'t perform HTML escaping on apostrophy, quotes, etc.') 22 | # fmt: on 23 | 24 | def __init__(self, args): 25 | self.args = args 26 | 27 | if getattr(args, "moses_source_lang", None) is None: 28 | args.moses_source_lang = getattr(args, "source_lang", "en") 29 | if getattr(args, "moses_target_lang", None) is None: 30 | args.moses_target_lang = getattr(args, "target_lang", "en") 31 | 32 | try: 33 | from sacremoses import MosesTokenizer, MosesDetokenizer 34 | 35 | self.tok = MosesTokenizer(args.moses_source_lang) 36 | self.detok = MosesDetokenizer(args.moses_target_lang) 37 | except ImportError: 38 | raise ImportError( 39 | "Please install Moses tokenizer with: pip install sacremoses" 40 | ) 41 | 42 | def encode(self, x: str) -> str: 43 | return self.tok.tokenize( 44 | x, 45 | aggressive_dash_splits=(not self.args.moses_no_dash_splits), 46 | return_str=True, 47 | escape=(not self.args.moses_no_escape), 48 | ) 49 | 50 | def decode(self, x: str) -> str: 51 | return self.detok.detokenize(x.split()) 52 | -------------------------------------------------------------------------------- /codes/fairseq/data/encoders/nltk_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from fairseq.data.encoders import register_tokenizer 7 | 8 | 9 | @register_tokenizer("nltk") 10 | class NLTKTokenizer(object): 11 | def __init__(self, source_lang=None, target_lang=None): 12 | try: 13 | from nltk.tokenize import word_tokenize 14 | 15 | self.word_tokenize = word_tokenize 16 | except ImportError: 17 | raise ImportError("Please install nltk with: pip install nltk") 18 | 19 | def encode(self, x: str) -> str: 20 | return " ".join(self.word_tokenize(x)) 21 | 22 | def decode(self, x: str) -> str: 23 | return x 24 | -------------------------------------------------------------------------------- /codes/fairseq/data/encoders/sentencepiece_bpe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from fairseq import file_utils 7 | from fairseq.data.encoders import register_bpe 8 | 9 | 10 | @register_bpe("sentencepiece") 11 | class SentencepieceBPE(object): 12 | @staticmethod 13 | def add_args(parser): 14 | # fmt: off 15 | parser.add_argument('--sentencepiece-model', type=str, 16 | help='path to sentencepiece model') 17 | # fmt: on 18 | 19 | def __init__(self, args): 20 | sentencepiece_model = file_utils.cached_path(args.sentencepiece_model) 21 | try: 22 | import sentencepiece as spm 23 | 24 | self.sp = spm.SentencePieceProcessor() 25 | self.sp.Load(sentencepiece_model) 26 | except ImportError: 27 | raise ImportError( 28 | "Please install sentencepiece with: pip install sentencepiece" 29 | ) 30 | 31 | def encode(self, x: str) -> str: 32 | return " ".join(self.sp.EncodeAsPieces(x)) 33 | 34 | def decode(self, x: str) -> str: 35 | return x.replace(" ", "").replace("\u2581", " ").strip() 36 | 37 | def is_beginning_of_word(self, x: str) -> bool: 38 | if x in ["", "", "", ""]: 39 | # special elements are always considered beginnings 40 | # HACK: this logic is already present in fairseq/tasks/masked_lm.py 41 | # but these special tokens are also contained in the sentencepiece 42 | # vocabulary which causes duplicate special tokens. This hack makes 43 | # sure that they are all taken into account. 44 | return True 45 | return x.startswith("\u2581") 46 | -------------------------------------------------------------------------------- /codes/fairseq/data/encoders/space_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import re 7 | 8 | from fairseq.data.encoders import register_tokenizer 9 | 10 | 11 | @register_tokenizer("space") 12 | class SpaceTokenizer(object): 13 | def __init__(self, source_lang=None, target_lang=None): 14 | self.space_tok = re.compile(r"\s+") 15 | 16 | def encode(self, x: str) -> str: 17 | return self.space_tok.sub(" ", x) 18 | 19 | def decode(self, x: str) -> str: 20 | return x 21 | -------------------------------------------------------------------------------- /codes/fairseq/data/encoders/subword_nmt_bpe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from fairseq import file_utils 7 | from fairseq.data.encoders import register_bpe 8 | 9 | 10 | @register_bpe("subword_nmt") 11 | class SubwordNMTBPE(object): 12 | @staticmethod 13 | def add_args(parser): 14 | # fmt: off 15 | parser.add_argument('--bpe-codes', type=str, 16 | help='path to subword NMT BPE') 17 | parser.add_argument('--bpe-separator', default='@@', 18 | help='BPE separator') 19 | # fmt: on 20 | 21 | def __init__(self, args): 22 | if args.bpe_codes is None: 23 | raise ValueError("--bpe-codes is required for --bpe=subword_nmt") 24 | codes = file_utils.cached_path(args.bpe_codes) 25 | try: 26 | from subword_nmt import apply_bpe 27 | 28 | bpe_parser = apply_bpe.create_parser() 29 | bpe_args = bpe_parser.parse_args( 30 | [ 31 | "--codes", 32 | codes, 33 | "--separator", 34 | args.bpe_separator, 35 | ] 36 | ) 37 | self.bpe = apply_bpe.BPE( 38 | bpe_args.codes, 39 | bpe_args.merges, 40 | bpe_args.separator, 41 | None, 42 | bpe_args.glossaries, 43 | ) 44 | self.bpe_symbol = bpe_args.separator + " " 45 | except ImportError: 46 | raise ImportError( 47 | "Please install subword_nmt with: pip install subword-nmt" 48 | ) 49 | 50 | def encode(self, x: str) -> str: 
51 | return self.bpe.process_line(x) 52 | 53 | def decode(self, x: str) -> str: 54 | return (x + " ").replace(self.bpe_symbol, "").rstrip() 55 | -------------------------------------------------------------------------------- /codes/fairseq/data/encoders/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | from fairseq.data import encoders 8 | 9 | 10 | def get_whole_word_mask(args, dictionary): 11 | bpe = encoders.build_bpe(args) 12 | if bpe is not None: 13 | 14 | def is_beginning_of_word(i): 15 | if i < dictionary.nspecial: 16 | # special elements are always considered beginnings 17 | return True 18 | tok = dictionary[i] 19 | if tok.startswith("madeupword"): 20 | return True 21 | try: 22 | return bpe.is_beginning_of_word(tok) 23 | except ValueError: 24 | return True 25 | 26 | mask_whole_words = torch.ByteTensor( 27 | list(map(is_beginning_of_word, range(len(dictionary)))) 28 | ) 29 | return mask_whole_words 30 | return None 31 | -------------------------------------------------------------------------------- /codes/fairseq/data/id_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | from . 
import FairseqDataset 9 | 10 | 11 | class IdDataset(FairseqDataset): 12 | def __getitem__(self, index): 13 | return index 14 | 15 | def __len__(self): 16 | return 0 17 | 18 | def collater(self, samples): 19 | return torch.tensor(samples) 20 | -------------------------------------------------------------------------------- /codes/fairseq/data/legacy/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .block_pair_dataset import BlockPairDataset 7 | from .masked_lm_dataset import MaskedLMDataset 8 | from .masked_lm_dictionary import BertDictionary, MaskedLMDictionary 9 | 10 | 11 | __all__ = [ 12 | "BertDictionary", 13 | "BlockPairDataset", 14 | "MaskedLMDataset", 15 | "MaskedLMDictionary", 16 | ] 17 | -------------------------------------------------------------------------------- /codes/fairseq/data/legacy/masked_lm_dictionary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from fairseq.data import Dictionary 7 | 8 | 9 | class MaskedLMDictionary(Dictionary): 10 | """ 11 | Dictionary for Masked Language Modelling tasks. This extends Dictionary by 12 | adding the mask symbol. 
13 | """ 14 | 15 | def __init__( 16 | self, 17 | pad="", 18 | eos="", 19 | unk="", 20 | mask="", 21 | ): 22 | super().__init__(pad=pad, eos=eos, unk=unk) 23 | self.mask_word = mask 24 | self.mask_index = self.add_symbol(mask) 25 | self.nspecial = len(self.symbols) 26 | 27 | def mask(self): 28 | """Helper to get index of mask symbol""" 29 | return self.mask_index 30 | 31 | 32 | class BertDictionary(MaskedLMDictionary): 33 | """ 34 | Dictionary for BERT task. This extends MaskedLMDictionary by adding support 35 | for cls and sep symbols. 36 | """ 37 | 38 | def __init__( 39 | self, 40 | pad="", 41 | eos="", 42 | unk="", 43 | mask="", 44 | cls="", 45 | sep="", 46 | ): 47 | super().__init__(pad=pad, eos=eos, unk=unk, mask=mask) 48 | self.cls_word = cls 49 | self.sep_word = sep 50 | self.cls_index = self.add_symbol(cls) 51 | self.sep_index = self.add_symbol(sep) 52 | self.nspecial = len(self.symbols) 53 | 54 | def cls(self): 55 | """Helper to get index of cls symbol""" 56 | return self.cls_index 57 | 58 | def sep(self): 59 | """Helper to get index of sep symbol""" 60 | return self.sep_index 61 | -------------------------------------------------------------------------------- /codes/fairseq/data/list_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . 
import BaseWrapperDataset 7 | 8 | 9 | class ListDataset(BaseWrapperDataset): 10 | def __init__(self, dataset, sizes=None): 11 | super().__init__(dataset) 12 | self._sizes = sizes 13 | 14 | def __iter__(self): 15 | for x in self.dataset: 16 | yield x 17 | 18 | def collater(self, samples): 19 | return samples 20 | 21 | @property 22 | def sizes(self): 23 | return self._sizes 24 | 25 | def num_tokens(self, index): 26 | return self.sizes[index] 27 | 28 | def size(self, index): 29 | return self.sizes[index] 30 | 31 | def set_epoch(self, epoch): 32 | pass 33 | -------------------------------------------------------------------------------- /codes/fairseq/data/lru_cache_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from functools import lru_cache 7 | 8 | from . import BaseWrapperDataset 9 | 10 | 11 | class LRUCacheDataset(BaseWrapperDataset): 12 | def __init__(self, dataset, token=None): 13 | super().__init__(dataset) 14 | 15 | @lru_cache(maxsize=8) 16 | def __getitem__(self, index): 17 | return self.dataset[index] 18 | 19 | @lru_cache(maxsize=8) 20 | def collater(self, samples): 21 | return self.dataset.collater(samples) 22 | -------------------------------------------------------------------------------- /codes/fairseq/data/multilingual/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | -------------------------------------------------------------------------------- /codes/fairseq/data/multilingual/multilingual_utils.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Dict, List, Optional, Sequence 3 | 4 | import torch 5 | from fairseq.data import Dictionary 6 | 7 | 8 | class EncoderLangtok(Enum): 9 | """ 10 | Prepend to the beginning of source sentence either the 11 | source or target language token. (src/tgt). 12 | """ 13 | 14 | src = "src" 15 | tgt = "tgt" 16 | 17 | 18 | class LangTokSpec(Enum): 19 | main = "main" 20 | mono_dae = "mono_dae" 21 | 22 | 23 | class LangTokStyle(Enum): 24 | multilingual = "multilingual" 25 | mbart = "mbart" 26 | 27 | 28 | @torch.jit.export 29 | def get_lang_tok( 30 | lang: str, lang_tok_style: str, spec: str = LangTokSpec.main.value 31 | ) -> str: 32 | # TOKEN_STYLES can't be defined outside this fn since it needs to be 33 | # TorchScriptable. 34 | TOKEN_STYLES: Dict[str, str] = { 35 | LangTokStyle.mbart.value: "[{}]", 36 | LangTokStyle.multilingual.value: "__{}__", 37 | } 38 | 39 | if spec.endswith("dae"): 40 | lang = f"{lang}_dae" 41 | elif spec.endswith("mined"): 42 | lang = f"{lang}_mined" 43 | style = TOKEN_STYLES[lang_tok_style] 44 | return style.format(lang) 45 | 46 | 47 | def augment_dictionary( 48 | dictionary: Dictionary, 49 | language_list: List[str], 50 | lang_tok_style: str, 51 | langtoks_specs: Sequence[str] = (LangTokSpec.main.value,), 52 | extra_data: Optional[Dict[str, str]] = None, 53 | ) -> None: 54 | for spec in langtoks_specs: 55 | for language in language_list: 56 | dictionary.add_symbol( 57 | get_lang_tok(lang=language, lang_tok_style=lang_tok_style, spec=spec) 58 | ) 59 | 60 | if lang_tok_style == LangTokStyle.mbart.value or ( 61 | extra_data is not None and LangTokSpec.mono_dae.value in extra_data 62 | ): 63 | dictionary.add_symbol("") 64 | 
-------------------------------------------------------------------------------- /codes/fairseq/data/multilingual/sampling_method.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import logging 7 | from typing import List 8 | 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def uniform(dataset_sizes: List[int]): 14 | return [1.0] * len(dataset_sizes) 15 | 16 | 17 | def temperature_sampling(dataset_sizes, temp): 18 | total_size = sum(dataset_sizes) 19 | return [(size / total_size) ** (1.0 / temp) for size in dataset_sizes] 20 | 21 | 22 | def make_temperature_sampling(temp=1.0): 23 | def sampling_func(dataset_sizes): 24 | return temperature_sampling(dataset_sizes, temp) 25 | 26 | return sampling_func 27 | 28 | 29 | def make_ratio_sampling(ratios): 30 | def sampling_func(dataset_sizes): 31 | return ratios 32 | 33 | return sampling_func 34 | 35 | 36 | class SamplingMethod: 37 | @staticmethod 38 | def add_arguments(parser): 39 | parser.add_argument( 40 | "--sampling-method", 41 | choices=[ 42 | "uniform", 43 | "temperature", 44 | "concat", 45 | "RoundRobin", 46 | ], 47 | type=str, 48 | default="concat", 49 | help="The method to sample data per language pairs", 50 | ) 51 | parser.add_argument( 52 | "--sampling-temperature", 53 | default=1.5, 54 | type=float, 55 | help="only work with --sampling-method temperature", 56 | ) 57 | 58 | @staticmethod 59 | def build_sampler(args, task): 60 | return SamplingMethod(args, task) 61 | 62 | def __init__(self, args, task): 63 | self.args = args 64 | self.task = task 65 | 66 | def is_adaptive(self): 67 | return False 68 | 69 | def sampling_method_selector(self): 70 | args = self.args 71 | logger.info(f"selected sampler: {args.sampling_method}") 72 | if args.sampling_method == "uniform": 73 
| return uniform 74 | elif args.sampling_method == "temperature" or self.is_adaptive(): 75 | return make_temperature_sampling(float(args.sampling_temperature)) 76 | else: 77 | # default to concating all data set together 78 | return None 79 | -------------------------------------------------------------------------------- /codes/fairseq/data/num_samples_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . import FairseqDataset 7 | 8 | 9 | class NumSamplesDataset(FairseqDataset): 10 | def __getitem__(self, index): 11 | return 1 12 | 13 | def __len__(self): 14 | return 0 15 | 16 | def collater(self, samples): 17 | return sum(samples) 18 | -------------------------------------------------------------------------------- /codes/fairseq/data/numel_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from . 
import BaseWrapperDataset 10 | 11 | 12 | class NumelDataset(BaseWrapperDataset): 13 | def __init__(self, dataset, reduce=False): 14 | super().__init__(dataset) 15 | self.reduce = reduce 16 | 17 | def __getitem__(self, index): 18 | item = self.dataset[index] 19 | if torch.is_tensor(item): 20 | return torch.numel(item) 21 | else: 22 | return np.size(item) 23 | 24 | def __len__(self): 25 | return len(self.dataset) 26 | 27 | def collater(self, samples): 28 | if self.reduce: 29 | return sum(samples) 30 | else: 31 | return torch.tensor(samples) 32 | -------------------------------------------------------------------------------- /codes/fairseq/data/offset_tokens_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . import BaseWrapperDataset 7 | 8 | 9 | class OffsetTokensDataset(BaseWrapperDataset): 10 | def __init__(self, dataset, offset): 11 | super().__init__(dataset) 12 | self.offset = offset 13 | 14 | def __getitem__(self, idx): 15 | return self.dataset[idx] + self.offset 16 | -------------------------------------------------------------------------------- /codes/fairseq/data/pad_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from fairseq.data import data_utils 7 | 8 | from . 
import BaseWrapperDataset 9 | 10 | 11 | class PadDataset(BaseWrapperDataset): 12 | def __init__(self, dataset, pad_idx, left_pad): 13 | super().__init__(dataset) 14 | self.pad_idx = pad_idx 15 | self.left_pad = left_pad 16 | 17 | def collater(self, samples): 18 | return data_utils.collate_tokens(samples, self.pad_idx, left_pad=self.left_pad) 19 | 20 | 21 | class LeftPadDataset(PadDataset): 22 | def __init__(self, dataset, pad_idx): 23 | super().__init__(dataset, pad_idx, left_pad=True) 24 | 25 | 26 | class RightPadDataset(PadDataset): 27 | def __init__(self, dataset, pad_idx): 28 | super().__init__(dataset, pad_idx, left_pad=False) 29 | -------------------------------------------------------------------------------- /codes/fairseq/data/plasma_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import subprocess 7 | import tempfile 8 | 9 | 10 | class PlasmaArray(object): 11 | """ 12 | Wrapper around numpy arrays that automatically moves the data to shared 13 | memory upon serialization. This is particularly helpful when passing numpy 14 | arrays through multiprocessing, so that data is not unnecessarily 15 | duplicated or pickled. 
16 | """ 17 | 18 | def __init__(self, array): 19 | super().__init__() 20 | self.array = array 21 | self.disable = array.nbytes < 134217728 # disable for arrays <128MB 22 | self.object_id = None 23 | self.path = None 24 | 25 | # variables with underscores shouldn't be pickled 26 | self._client = None 27 | self._server = None 28 | self._server_tmp = None 29 | self._plasma = None 30 | 31 | @property 32 | def plasma(self): 33 | if self._plasma is None and not self.disable: 34 | try: 35 | import pyarrow.plasma as plasma 36 | 37 | self._plasma = plasma 38 | except ImportError: 39 | self._plasma = None 40 | return self._plasma 41 | 42 | def start_server(self): 43 | if self.plasma is None or self._server is not None: 44 | return 45 | assert self.object_id is None 46 | assert self.path is None 47 | self._server_tmp = tempfile.NamedTemporaryFile() 48 | self.path = self._server_tmp.name 49 | self._server = subprocess.Popen( 50 | [ 51 | "plasma_store", 52 | "-m", 53 | str(int(1.05 * self.array.nbytes)), 54 | "-s", 55 | self.path, 56 | ] 57 | ) 58 | 59 | @property 60 | def client(self): 61 | if self._client is None: 62 | assert self.path is not None 63 | self._client = self.plasma.connect(self.path) 64 | return self._client 65 | 66 | def __getstate__(self): 67 | if self.plasma is None: 68 | return self.__dict__ 69 | if self.object_id is None: 70 | self.start_server() 71 | self.object_id = self.client.put(self.array) 72 | state = self.__dict__.copy() 73 | del state["array"] 74 | state["_client"] = None 75 | state["_server"] = None 76 | state["_server_tmp"] = None 77 | state["_plasma"] = None 78 | return state 79 | 80 | def __setstate__(self, state): 81 | self.__dict__.update(state) 82 | if self.plasma is None: 83 | return 84 | self.array = self.client.get(self.object_id) 85 | 86 | def __del__(self): 87 | if self._server is not None: 88 | self._server.kill() 89 | self._server = None 90 | self._server_tmp.close() 91 | self._server_tmp = None 92 | 
-------------------------------------------------------------------------------- /codes/fairseq/data/prepend_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from . import BaseWrapperDataset 10 | 11 | 12 | class PrependDataset(BaseWrapperDataset): 13 | def __init__(self, dataset, prepend_getter, ensure_first_token_is=None): 14 | super().__init__(dataset) 15 | self.prepend_getter = prepend_getter 16 | self.ensure_first_token = ensure_first_token_is 17 | 18 | def __getitem__(self, idx): 19 | item = self.dataset[idx] 20 | is_tuple = isinstance(item, tuple) 21 | src = item[0] if is_tuple else item 22 | 23 | assert self.ensure_first_token is None or src[0] == self.ensure_first_token 24 | prepend_idx = self.prepend_getter(self.dataset, idx) 25 | assert isinstance(prepend_idx, int) 26 | src[0] = prepend_idx 27 | item = tuple((src,) + item[1:]) if is_tuple else src 28 | return item 29 | -------------------------------------------------------------------------------- /codes/fairseq/data/prepend_token_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from . 
import BaseWrapperDataset 10 | 11 | 12 | class PrependTokenDataset(BaseWrapperDataset): 13 | def __init__(self, dataset, token=None): 14 | super().__init__(dataset) 15 | self.token = token 16 | if token is not None: 17 | self._sizes = np.array(dataset.sizes) + 1 18 | else: 19 | self._sizes = dataset.sizes 20 | 21 | def __getitem__(self, idx): 22 | item = self.dataset[idx] 23 | if self.token is not None: 24 | item = torch.cat([item.new([self.token]), item]) 25 | return item 26 | 27 | @property 28 | def sizes(self): 29 | return self._sizes 30 | 31 | def num_tokens(self, index): 32 | n = self.dataset.num_tokens(index) 33 | if self.token is not None: 34 | n += 1 35 | return n 36 | 37 | def size(self, index): 38 | n = self.dataset.size(index) 39 | if self.token is not None: 40 | n += 1 41 | return n 42 | -------------------------------------------------------------------------------- /codes/fairseq/data/raw_label_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | from . import FairseqDataset 9 | 10 | 11 | class RawLabelDataset(FairseqDataset): 12 | def __init__(self, labels): 13 | super().__init__() 14 | self.labels = labels 15 | 16 | def __getitem__(self, index): 17 | return self.labels[index] 18 | 19 | def __len__(self): 20 | return len(self.labels) 21 | 22 | def collater(self, samples): 23 | return torch.tensor(samples) 24 | -------------------------------------------------------------------------------- /codes/fairseq/data/replace_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . import BaseWrapperDataset 7 | 8 | 9 | class ReplaceDataset(BaseWrapperDataset): 10 | """Replaces tokens found in the dataset by a specified replacement token 11 | 12 | Args: 13 | dataset (~torch.utils.data.Dataset): dataset to replace tokens in 14 | replace_map(Dictionary[int,int]): map of token to replace -> replacement token 15 | offsets (List[int]): do not replace tokens before (from left if pos, right if neg) this offset. should be 16 | as many as the number of objects returned by the underlying dataset __getitem__ method. 17 | """ 18 | 19 | def __init__(self, dataset, replace_map, offsets): 20 | super().__init__(dataset) 21 | assert len(replace_map) > 0 22 | self.replace_map = replace_map 23 | self.offsets = offsets 24 | 25 | def __getitem__(self, index): 26 | item = self.dataset[index] 27 | is_tuple = isinstance(item, tuple) 28 | srcs = item if is_tuple else [item] 29 | 30 | for offset, src in zip(self.offsets, srcs): 31 | for k, v in self.replace_map.items(): 32 | src_off = src[offset:] if offset >= 0 else src[:offset] 33 | src_off.masked_fill_(src_off == k, v) 34 | 35 | item = srcs if is_tuple else srcs[0] 36 | return item 37 | -------------------------------------------------------------------------------- /codes/fairseq/data/roll_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | from . 
import BaseWrapperDataset 9 | 10 | 11 | class RollDataset(BaseWrapperDataset): 12 | def __init__(self, dataset, shifts): 13 | super().__init__(dataset) 14 | self.shifts = shifts 15 | 16 | def __getitem__(self, index): 17 | item = self.dataset[index] 18 | return torch.roll(item, self.shifts) 19 | -------------------------------------------------------------------------------- /codes/fairseq/data/shorten_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | from fairseq.data import data_utils 8 | 9 | from . import BaseWrapperDataset 10 | 11 | 12 | class TruncateDataset(BaseWrapperDataset): 13 | """Truncate a sequence by returning the first truncation_length tokens""" 14 | 15 | def __init__(self, dataset, truncation_length): 16 | super().__init__(dataset) 17 | assert truncation_length is not None 18 | self.truncation_length = truncation_length 19 | self.dataset = dataset 20 | 21 | def __getitem__(self, index): 22 | item = self.dataset[index] 23 | item_len = item.size(0) 24 | if item_len > self.truncation_length: 25 | item = item[: self.truncation_length] 26 | return item 27 | 28 | @property 29 | def sizes(self): 30 | return np.minimum(self.dataset.sizes, self.truncation_length) 31 | 32 | def __len__(self): 33 | return len(self.dataset) 34 | 35 | 36 | class RandomCropDataset(TruncateDataset): 37 | """Truncate a sequence by returning a random crop of truncation_length tokens""" 38 | 39 | def __init__(self, dataset, truncation_length, seed=1): 40 | super().__init__(dataset, truncation_length) 41 | self.seed = seed 42 | self.epoch = 0 43 | 44 | @property 45 | def can_reuse_epoch_itr_across_epochs(self): 46 | return True # only the crop changes, not item sizes 47 | 48 | def set_epoch(self, epoch, **unused): 49 | 
super().set_epoch(epoch) 50 | self.epoch = epoch 51 | 52 | def __getitem__(self, index): 53 | with data_utils.numpy_seed(self.seed, self.epoch, index): 54 | item = self.dataset[index] 55 | item_len = item.size(0) 56 | excess = item_len - self.truncation_length 57 | if excess > 0: 58 | start_idx = np.random.randint(0, excess) 59 | item = item[start_idx : start_idx + self.truncation_length] 60 | return item 61 | 62 | 63 | def maybe_shorten_dataset( 64 | dataset, 65 | split, 66 | shorten_data_split_list, 67 | shorten_method, 68 | tokens_per_sample, 69 | seed, 70 | ): 71 | truncate_split = ( 72 | split in shorten_data_split_list.split(",") or len(shorten_data_split_list) == 0 73 | ) 74 | if shorten_method == "truncate" and truncate_split: 75 | dataset = TruncateDataset(dataset, tokens_per_sample) 76 | elif shorten_method == "random_crop" and truncate_split: 77 | dataset = RandomCropDataset(dataset, tokens_per_sample, seed) 78 | return dataset 79 | -------------------------------------------------------------------------------- /codes/fairseq/data/sort_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | 8 | from . 
import BaseWrapperDataset 9 | 10 | 11 | class SortDataset(BaseWrapperDataset): 12 | def __init__(self, dataset, sort_order): 13 | super().__init__(dataset) 14 | if not isinstance(sort_order, (list, tuple)): 15 | sort_order = [sort_order] 16 | self.sort_order = sort_order 17 | 18 | assert all(len(so) == len(dataset) for so in sort_order) 19 | 20 | def ordered_indices(self): 21 | return np.lexsort(self.sort_order) 22 | -------------------------------------------------------------------------------- /codes/fairseq/data/strip_token_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . import BaseWrapperDataset 7 | 8 | 9 | class StripTokenDataset(BaseWrapperDataset): 10 | def __init__(self, dataset, id_to_strip): 11 | super().__init__(dataset) 12 | self.id_to_strip = id_to_strip 13 | 14 | def __getitem__(self, index): 15 | item = self.dataset[index] 16 | while len(item) > 0 and item[-1] == self.id_to_strip: 17 | item = item[:-1] 18 | while len(item) > 0 and item[0] == self.id_to_strip: 19 | item = item[1:] 20 | return item 21 | -------------------------------------------------------------------------------- /codes/fairseq/data/subsample_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import logging 7 | 8 | import numpy as np 9 | 10 | from . import BaseWrapperDataset 11 | 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class SubsampleDataset(BaseWrapperDataset): 17 | """Subsamples a given dataset by a specified ratio. 
Subsampling is done on the number of examples 18 | 19 | Args: 20 | dataset (~torch.utils.data.Dataset): dataset to subsample 21 | size_ratio(float): the ratio to subsample to. must be between 0 and 1 (exclusive) 22 | """ 23 | 24 | def __init__(self, dataset, size_ratio, shuffle=False): 25 | super().__init__(dataset) 26 | assert size_ratio < 1 27 | self.actual_size = np.ceil(len(dataset) * size_ratio).astype(int) 28 | self.indices = np.random.choice( 29 | list(range(len(self.dataset))), self.actual_size, replace=False 30 | ) 31 | self.shuffle = shuffle 32 | logger.info( 33 | "subsampled dataset from {} to {} (ratio={})".format( 34 | len(self.dataset), self.actual_size, size_ratio 35 | ) 36 | ) 37 | 38 | def __getitem__(self, index): 39 | return self.dataset[self.indices[index]] 40 | 41 | def __len__(self): 42 | return self.actual_size 43 | 44 | def collater(self, samples): 45 | return self.dataset.collater(samples) 46 | 47 | @property 48 | def sizes(self): 49 | return self.dataset.sizes[self.indices] 50 | 51 | @property 52 | def name(self): 53 | return self.dataset.name 54 | 55 | def num_tokens(self, index): 56 | return self.dataset.num_tokens(self.indices[index]) 57 | 58 | def size(self, index): 59 | return self.dataset.size(self.indices[index]) 60 | 61 | def ordered_indices(self): 62 | """Return an ordered list of indices. 
Batches will be constructed based 63 | on this order.""" 64 | if self.shuffle: 65 | order = [np.random.permutation(len(self))] 66 | else: 67 | order = [np.arange(len(self))] 68 | order.append(self.sizes) 69 | return np.lexsort(order) 70 | 71 | def prefetch(self, indices): 72 | self.dataset.prefetch(self.indices[indices]) 73 | -------------------------------------------------------------------------------- /codes/fairseq/data_utils_fast.cpython-36m-darwin.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonderseen/PCKMT/2a50a85f671ac46252b1038f62abdded367aeb46/codes/fairseq/data_utils_fast.cpython-36m-darwin.so -------------------------------------------------------------------------------- /codes/fairseq/data_utils_fast.cpython-36m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonderseen/PCKMT/2a50a85f671ac46252b1038f62abdded367aeb46/codes/fairseq/data_utils_fast.cpython-36m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /codes/fairseq/data_utils_fast.cpython-37m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonderseen/PCKMT/2a50a85f671ac46252b1038f62abdded367aeb46/codes/fairseq/data_utils_fast.cpython-37m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /codes/fairseq/dataclass/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from .utils import ChoiceEnum, FairseqDataclass 7 | 8 | 9 | __all__ = ["FairseqDataclass", "ChoiceEnum"] 10 | -------------------------------------------------------------------------------- /codes/fairseq/dataclass/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from fairseq.dataclass.utils import ChoiceEnum 7 | 8 | 9 | LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"]) 10 | DDP_BACKEND_CHOICES = ChoiceEnum(["c10d", "no_c10d"]) 11 | DISTRIBUTED_WRAPPER_CHOICES = ChoiceEnum(["DDP", "SlowMo"]) 12 | ZERO_SHARDING_CHOICES = ChoiceEnum(["none", "os"]) 13 | PIPELINE_CHECKPOINT_CHOICES = ChoiceEnum(["always", "never", "except_last"]) 14 | -------------------------------------------------------------------------------- /codes/fairseq/incremental_decoding_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import uuid 7 | from typing import Dict, Optional 8 | 9 | from torch import Tensor 10 | 11 | 12 | class FairseqIncrementalState(object): 13 | def __init__(self, *args, **kwargs): 14 | super().__init__(*args, **kwargs) 15 | self.init_incremental_state() 16 | 17 | def init_incremental_state(self): 18 | self._incremental_state_id = str(uuid.uuid4()) 19 | 20 | def _get_full_incremental_state_key(self, key: str) -> str: 21 | return "{}.{}".format(self._incremental_state_id, key) 22 | 23 | def get_incremental_state( 24 | self, 25 | incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], 26 | key: str, 27 | ) -> Optional[Dict[str, Optional[Tensor]]]: 28 | """Helper for getting incremental state for an nn.Module.""" 29 | full_key = self._get_full_incremental_state_key(key) 30 | if incremental_state is None or full_key not in incremental_state: 31 | return None 32 | return incremental_state[full_key] 33 | 34 | def set_incremental_state( 35 | self, 36 | incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], 37 | key: str, 38 | value: Dict[str, Optional[Tensor]], 39 | ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]: 40 | """Helper for setting incremental state for an nn.Module.""" 41 | if incremental_state is not None: 42 | full_key = self._get_full_incremental_state_key(key) 43 | incremental_state[full_key] = value 44 | return incremental_state 45 | 46 | 47 | def with_incremental_state(cls): 48 | cls.__bases__ = (FairseqIncrementalState,) + tuple( 49 | b for b in cls.__bases__ if b != FairseqIncrementalState 50 | ) 51 | return cls 52 | -------------------------------------------------------------------------------- /codes/fairseq/logging/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonderseen/PCKMT/2a50a85f671ac46252b1038f62abdded367aeb46/codes/fairseq/logging/__init__.py -------------------------------------------------------------------------------- 
/codes/fairseq/model_parallel/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . import criterions, models, modules # noqa 7 | -------------------------------------------------------------------------------- /codes/fairseq/model_parallel/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import importlib 7 | import os 8 | 9 | 10 | # automatically import any Python files in the criterions/ directory 11 | for file in os.listdir(os.path.dirname(__file__)): 12 | if file.endswith(".py") and not file.startswith("_"): 13 | module = file[: file.find(".py")] 14 | importlib.import_module("fairseq.model_parallel.criterions." + module) 15 | -------------------------------------------------------------------------------- /codes/fairseq/model_parallel/megatron_trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ 7 | Train a network across multiple GPUs. 
8 | """ 9 | 10 | from fairseq import distributed_utils 11 | from fairseq.trainer import Trainer 12 | 13 | 14 | try: 15 | from fairseq.model_parallel.megatron.mpu import ( 16 | get_data_parallel_group, 17 | get_data_parallel_rank, 18 | get_data_parallel_world_size, 19 | get_model_parallel_group, 20 | get_model_parallel_src_rank, 21 | ) 22 | 23 | has_megatron_submodule = True 24 | except (ImportError, ModuleNotFoundError): 25 | has_megatron_submodule = False 26 | 27 | 28 | class MegatronTrainer(Trainer): 29 | """Main class for model parallel with data parallel training.""" 30 | 31 | def __init__(self, args, task, model, criterion): 32 | if not has_megatron_submodule: 33 | raise ImportError( 34 | "\n\nPlease install the megatron submodule:" 35 | "\n\n git submodule update --init " 36 | "fairseq/model_parallel/megatron" 37 | ) 38 | super().__init__(args, task, model, criterion) 39 | 40 | @property 41 | def data_parallel_world_size(self): 42 | return get_data_parallel_world_size() 43 | 44 | @property 45 | def data_parallel_process_group(self): 46 | return get_data_parallel_group() 47 | 48 | @property 49 | def data_parallel_rank(self): 50 | return get_data_parallel_rank() 51 | 52 | @property 53 | def is_data_parallel_master(self): 54 | return get_model_parallel_src_rank() == 0 55 | 56 | def clip_grad_norm(self, clip_norm): 57 | def _aggregate_model_parallel_grad_norm(total_norm): 58 | total_norm = total_norm ** 2 59 | distributed_utils.all_reduce(total_norm, group=get_model_parallel_group()) 60 | total_norm = total_norm ** 0.5 61 | return total_norm 62 | 63 | return self.optimizer.clip_grad_norm( 64 | clip_norm, 65 | aggregate_norm_fn=_aggregate_model_parallel_grad_norm, 66 | ) 67 | -------------------------------------------------------------------------------- /codes/fairseq/model_parallel/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import importlib 7 | import os 8 | 9 | 10 | # automatically import any Python files in the models/ directory 11 | models_dir = os.path.dirname(__file__) 12 | for file in os.listdir(models_dir): 13 | path = os.path.join(models_dir, file) 14 | if ( 15 | not file.startswith("_") 16 | and not file.startswith(".") 17 | and (file.endswith(".py") or os.path.isdir(path)) 18 | ): 19 | model_name = file[: file.find(".py")] if file.endswith(".py") else file 20 | module = importlib.import_module("fairseq.model_parallel.models." + model_name) 21 | -------------------------------------------------------------------------------- /codes/fairseq/model_parallel/models/pipeline_parallel_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .model import * # noqa 7 | -------------------------------------------------------------------------------- /codes/fairseq/model_parallel/models/roberta/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .model import * # noqa 7 | -------------------------------------------------------------------------------- /codes/fairseq/model_parallel/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | """isort:skip_file""" 6 | 7 | from .multihead_attention import ModelParallelMultiheadAttention 8 | from .transformer_layer import ( 9 | ModelParallelTransformerEncoderLayer, 10 | ModelParallelTransformerDecoderLayer, 11 | ) 12 | from .transformer_sentence_encoder_layer import ( 13 | ModelParallelTransformerSentenceEncoderLayer, 14 | ) 15 | from .transformer_sentence_encoder import ModelParallelTransformerSentenceEncoder 16 | 17 | __all__ = [ 18 | "ModelParallelMultiheadAttention", 19 | "ModelParallelTransformerEncoderLayer", 20 | "ModelParallelTransformerDecoderLayer", 21 | "ModelParallelTransformerSentenceEncoder", 22 | "ModelParallelTransformerSentenceEncoderLayer", 23 | ] 24 | -------------------------------------------------------------------------------- /codes/fairseq/model_parallel/modules/transformer_sentence_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import random 7 | from typing import Optional, Tuple 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from fairseq.model_parallel.modules import ModelParallelTransformerSentenceEncoderLayer 13 | from fairseq.modules import ( 14 | LayerNorm, 15 | MultiheadAttention, 16 | PositionalEmbedding, 17 | TransformerSentenceEncoder, 18 | ) 19 | 20 | 21 | try: 22 | from fairseq.model_parallel.megatron.mpu import VocabParallelEmbedding 23 | 24 | has_megatron_submodule = True 25 | except (ImportError, ModuleNotFoundError): 26 | has_megatron_submodule = False 27 | 28 | 29 | class ModelParallelTransformerSentenceEncoder(TransformerSentenceEncoder): 30 | """ 31 | Implementation for a Model Parallel Bi-directional Transformer based 32 | Sentence Encoder used in BERT/XLM style pre-trained models. 33 | """ 34 | 35 | def build_embedding(self, vocab_size, embedding_dim, padding_idx): 36 | return VocabParallelEmbedding(vocab_size, embedding_dim, padding_idx) 37 | 38 | def build_transformer_sentence_encoder_layer( 39 | self, 40 | embedding_dim, 41 | ffn_embedding_dim, 42 | num_attention_heads, 43 | dropout, 44 | attention_dropout, 45 | activation_dropout, 46 | activation_fn, 47 | export, 48 | **unused, 49 | ): 50 | return ModelParallelTransformerSentenceEncoderLayer( 51 | embedding_dim=embedding_dim, 52 | ffn_embedding_dim=ffn_embedding_dim, 53 | num_attention_heads=num_attention_heads, 54 | dropout=dropout, 55 | attention_dropout=attention_dropout, 56 | activation_dropout=activation_dropout, 57 | activation_fn=activation_fn, 58 | export=export, 59 | ) 60 | -------------------------------------------------------------------------------- /codes/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from fairseq import utils 9 | from fairseq.model_parallel.modules import ModelParallelMultiheadAttention 10 | from fairseq.modules import TransformerSentenceEncoderLayer 11 | 12 | 13 | try: 14 | from fairseq.model_parallel.megatron.mpu import ( 15 | ColumnParallelLinear, 16 | RowParallelLinear, 17 | ) 18 | 19 | has_megatron_submodule = True 20 | except (ImportError, ModuleNotFoundError): 21 | has_megatron_submodule = False 22 | 23 | 24 | class ModelParallelTransformerSentenceEncoderLayer(TransformerSentenceEncoderLayer): 25 | """ 26 | Implements a Model Parallel Transformer Encoder Layer used in 27 | BERT/XLM style pre-trained models. 28 | """ 29 | 30 | def build_fc1(self, input_dim, output_dim, **unused): 31 | return ColumnParallelLinear(input_dim, output_dim, gather_output=False) 32 | 33 | def build_fc2(self, input_dim, output_dim, **unused): 34 | return RowParallelLinear(input_dim, output_dim, input_is_parallel=True) 35 | 36 | def build_self_attention( 37 | self, 38 | embed_dim, 39 | num_attention_heads, 40 | dropout, 41 | **kwargs, 42 | ): 43 | return ModelParallelMultiheadAttention( 44 | embed_dim, num_attention_heads, dropout=dropout, self_attention=True 45 | ) 46 | 47 | def forward( 48 | self, 49 | x: torch.Tensor, 50 | self_attn_mask: torch.Tensor = None, 51 | self_attn_padding_mask: torch.Tensor = None, 52 | ): 53 | """ 54 | LayerNorm is applied either before or after the self-attention/ffn 55 | modules similar to the original Transformer imlementation. 
56 | """ 57 | residual = x 58 | x = self.self_attn_layer_norm(x) 59 | x, attn = self.self_attn( 60 | query=x, 61 | key=x, 62 | value=x, 63 | key_padding_mask=self_attn_padding_mask, 64 | need_weights=False, 65 | attn_mask=self_attn_mask, 66 | ) 67 | x = self.dropout_module(x) 68 | x = residual + x 69 | 70 | residual = x 71 | x = self.final_layer_norm(x) 72 | x = self.activation_fn(self.fc1(x)) 73 | x = self.activation_dropout_module(x) 74 | x = self.fc2(x) 75 | x = self.dropout_module(x) 76 | x = residual + x 77 | return x, None 78 | -------------------------------------------------------------------------------- /codes/fairseq/models/bart/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .hub_interface import * # noqa 7 | from .model import * # noqa 8 | -------------------------------------------------------------------------------- /codes/fairseq/models/composite_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .fairseq_encoder import FairseqEncoder 7 | 8 | 9 | class CompositeEncoder(FairseqEncoder): 10 | """ 11 | A wrapper around a dictionary of :class:`FairseqEncoder` objects. 12 | 13 | We run forward on each encoder and return a dictionary of outputs. The first 14 | encoder's dictionary is used for initialization. 15 | 16 | Args: 17 | encoders (dict): a dictionary of :class:`FairseqEncoder` objects. 
18 | """ 19 | 20 | def __init__(self, encoders): 21 | super().__init__(next(iter(encoders.values())).dictionary) 22 | self.encoders = encoders 23 | for key in self.encoders: 24 | self.add_module(key, self.encoders[key]) 25 | 26 | def forward(self, src_tokens, src_lengths): 27 | """ 28 | Args: 29 | src_tokens (LongTensor): tokens in the source language of shape 30 | `(batch, src_len)` 31 | src_lengths (LongTensor): lengths of each source sentence of shape 32 | `(batch)` 33 | 34 | Returns: 35 | dict: 36 | the outputs from each Encoder 37 | """ 38 | encoder_out = {} 39 | for key in self.encoders: 40 | encoder_out[key] = self.encoders[key](src_tokens, src_lengths) 41 | return encoder_out 42 | 43 | def reorder_encoder_out(self, encoder_out, new_order): 44 | """Reorder encoder output according to new_order.""" 45 | for key in self.encoders: 46 | encoder_out[key] = self.encoders[key].reorder_encoder_out( 47 | encoder_out[key], new_order 48 | ) 49 | return encoder_out 50 | 51 | def max_positions(self): 52 | return min(self.encoders[key].max_positions() for key in self.encoders) 53 | 54 | def upgrade_state_dict(self, state_dict): 55 | for key in self.encoders: 56 | self.encoders[key].upgrade_state_dict(state_dict) 57 | return state_dict 58 | -------------------------------------------------------------------------------- /codes/fairseq/models/huggingface/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import importlib 7 | import os 8 | 9 | 10 | # automatically import any Python files in the models/huggingface/ directory 11 | models_dir = os.path.dirname(__file__) 12 | for file in os.listdir(models_dir): 13 | path = os.path.join(models_dir, file) 14 | if ( 15 | not file.startswith("_") 16 | and not file.startswith(".") 17 | and (file.endswith(".py") or os.path.isdir(path)) 18 | ): 19 | model_name = file[: file.find(".py")] if file.endswith(".py") else file 20 | module = importlib.import_module("fairseq.models.huggingface." + model_name) 21 | -------------------------------------------------------------------------------- /codes/fairseq/models/model_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from typing import List, Optional 7 | 8 | import torch 9 | from torch import Tensor 10 | 11 | 12 | @torch.jit.script 13 | def script_skip_tensor_list(x: List[Tensor], mask): 14 | res = [xi[mask] if xi.size(0) == mask.size(0) else xi[:, mask] for xi in x] 15 | outputs = [] 16 | for i, t in enumerate(res): 17 | if t.numel() != 0: 18 | outputs.append(t) 19 | else: 20 | outputs.append(x[i]) 21 | return outputs 22 | 23 | 24 | @torch.jit.script 25 | def script_skip_tensor(x: Tensor, mask): 26 | # None case 27 | if x.size(0) == 0: 28 | return x 29 | res = x[mask] if x.size(0) == mask.size(0) else x[:, mask] 30 | if res.numel() == 0: 31 | return x 32 | else: 33 | return res 34 | 35 | 36 | @torch.jit.script 37 | def expand_2d_or_3d_tensor(x, trg_dim: int, padding_idx: int): 38 | """ 39 | Expand 2D/3D tensor on dim=1 40 | """ 41 | if x is None: 42 | return None 43 | 44 | assert x.dim() == 2 or x.dim() == 3 45 | assert trg_dim >= x.size(1), (trg_dim, x.size()) 46 | if trg_dim == x.size(1): 47 | return x 48 | 49 | dims = [x.size(0), 
trg_dim - x.size(1)] 50 | if x.dim() == 3: 51 | dims.append(x.size(2)) 52 | x = torch.cat([x, torch.zeros(dims).to(x).fill_(padding_idx)], 1) 53 | 54 | return x 55 | 56 | 57 | @torch.jit.script 58 | def coalesce(x: Optional[Tensor], y: Tensor) -> Tensor: 59 | return x if x is not None else y 60 | 61 | 62 | @torch.jit.script 63 | def fill_tensors( 64 | x: Optional[Tensor], mask, y: Optional[Tensor], padding_idx: int 65 | ) -> Optional[Tensor]: 66 | """ 67 | Filling tensor x with y at masked positions (dim=0). 68 | """ 69 | if x is None or x.size()[0] == 0 or y is None: 70 | return x 71 | assert x.dim() == y.dim() and mask.size(0) == x.size(0) 72 | assert x.dim() == 2 or (x.dim() == 3 and x.size(2) == y.size(2)) 73 | 74 | n_selected = mask.sum() 75 | if n_selected == 0: 76 | return x 77 | assert n_selected == y.size(0) 78 | if n_selected == x.size(0): 79 | return y 80 | 81 | if x.size(1) < y.size(1): 82 | x = expand_2d_or_3d_tensor(x, y.size(1), padding_idx) 83 | x[mask] = y 84 | elif x.size(1) > y.size(1): 85 | x[mask] = torch.tensor(padding_idx).type_as(x) 86 | if x.dim() == 2: 87 | x[mask, : y.size(1)] = y 88 | else: 89 | x[mask, : y.size(1), :] = y 90 | else: 91 | x[mask] = y 92 | return x 93 | -------------------------------------------------------------------------------- /codes/fairseq/models/nat/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | """isort:skip_file""" 6 | 7 | from .fairseq_nat_model import * 8 | from .nonautoregressive_transformer import * 9 | from .nat_crf_transformer import * 10 | from .iterative_nonautoregressive_transformer import * 11 | from .cmlm_transformer import * 12 | from .levenshtein_transformer import * 13 | from .insertion_transformer import * 14 | -------------------------------------------------------------------------------- /codes/fairseq/models/roberta/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .hub_interface import * # noqa 7 | from .model import * # noqa 8 | from .model_camembert import * # noqa 9 | from .model_xlmr import * # noqa 10 | -------------------------------------------------------------------------------- /codes/fairseq/models/roberta/model_camembert.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | """ 6 | CamemBERT: a Tasty French Language Model 7 | """ 8 | 9 | from fairseq.models import register_model 10 | 11 | from .hub_interface import RobertaHubInterface 12 | from .model import RobertaModel 13 | 14 | 15 | @register_model("camembert") 16 | class CamembertModel(RobertaModel): 17 | @classmethod 18 | def hub_models(cls): 19 | return { 20 | "camembert": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz", 21 | "camembert.v0": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz", 22 | "camembert-base": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz", 23 | "camembert-large": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-large.tar.gz", 24 | "camembert-base-ccnet": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base-ccnet.tar.gz", 25 | "camembert-base-ccnet-4gb": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base-ccnet-4gb.tar.gz", 26 | "camembert-base-wikipedia-4gb": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base-wikipedia-4gb.tar.gz", 27 | "camembert-base-oscar-4gb": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base-oscar-4gb.tar.gz", 28 | } 29 | 30 | @classmethod 31 | def from_pretrained( 32 | cls, 33 | model_name_or_path, 34 | checkpoint_file="model.pt", 35 | data_name_or_path=".", 36 | bpe="sentencepiece", 37 | **kwargs 38 | ): 39 | from fairseq import hub_utils 40 | 41 | x = hub_utils.from_pretrained( 42 | model_name_or_path, 43 | checkpoint_file, 44 | data_name_or_path, 45 | archive_map=cls.hub_models(), 46 | bpe=bpe, 47 | load_checkpoint_heads=True, 48 | **kwargs, 49 | ) 50 | return RobertaHubInterface(x["args"], x["task"], x["models"][0]) 51 | -------------------------------------------------------------------------------- /codes/fairseq/models/roberta/model_xlmr.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | """ 6 | Unsupervised Cross-lingual Representation Learning at Scale 7 | """ 8 | 9 | from fairseq.models import register_model 10 | 11 | from .hub_interface import RobertaHubInterface 12 | from .model import RobertaModel 13 | 14 | 15 | @register_model("xlmr") 16 | class XLMRModel(RobertaModel): 17 | @classmethod 18 | def hub_models(cls): 19 | return { 20 | "xlmr.base": "http://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.tar.gz", 21 | "xlmr.large": "http://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.tar.gz", 22 | } 23 | 24 | @classmethod 25 | def from_pretrained( 26 | cls, 27 | model_name_or_path, 28 | checkpoint_file="model.pt", 29 | data_name_or_path=".", 30 | bpe="sentencepiece", 31 | **kwargs 32 | ): 33 | from fairseq import hub_utils 34 | 35 | x = hub_utils.from_pretrained( 36 | model_name_or_path, 37 | checkpoint_file, 38 | data_name_or_path, 39 | archive_map=cls.hub_models(), 40 | bpe=bpe, 41 | load_checkpoint_heads=True, 42 | **kwargs, 43 | ) 44 | return RobertaHubInterface(x["args"], x["task"], x["models"][0]) 45 | -------------------------------------------------------------------------------- /codes/fairseq/models/speech_to_text/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .berard import * # noqa 7 | from .s2t_transformer import * # noqa 8 | -------------------------------------------------------------------------------- /codes/fairseq/models/wav2vec/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .wav2vec import * # noqa 7 | from .wav2vec2 import * # noqa 8 | from .wav2vec2_asr import * # noqa 9 | -------------------------------------------------------------------------------- /codes/fairseq/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | """isort:skip_file""" 6 | 7 | from .adaptive_input import AdaptiveInput 8 | from .adaptive_softmax import AdaptiveSoftmax 9 | from .beamable_mm import BeamableMM 10 | from .character_token_embedder import CharacterTokenEmbedder 11 | from .conv_tbc import ConvTBC 12 | from .cross_entropy import cross_entropy 13 | from .downsampled_multihead_attention import DownsampledMultiHeadAttention 14 | from .dynamic_convolution import DynamicConv, DynamicConv1dTBC 15 | from .dynamic_crf_layer import DynamicCRF 16 | from .fairseq_dropout import FairseqDropout 17 | from .fp32_group_norm import Fp32GroupNorm 18 | from .gelu import gelu, gelu_accurate 19 | from .grad_multiply import GradMultiply 20 | from .gumbel_vector_quantizer import GumbelVectorQuantizer 21 | from .kmeans_vector_quantizer import KmeansVectorQuantizer 22 | from .layer_drop import LayerDropModuleList 23 | from .layer_norm import Fp32LayerNorm, LayerNorm 24 | from .learned_positional_embedding import LearnedPositionalEmbedding 25 | from .lightweight_convolution import LightweightConv, LightweightConv1dTBC 26 | from .linearized_convolution import LinearizedConvolution 27 | from .multihead_attention import MultiheadAttention 28 | from .positional_embedding import PositionalEmbedding 29 | from .same_pad import SamePad 30 | from .scalar_bias import ScalarBias 31 | from 
.sinusoidal_positional_embedding import SinusoidalPositionalEmbedding 32 | from .transformer_sentence_encoder_layer import TransformerSentenceEncoderLayer 33 | from .transformer_sentence_encoder import TransformerSentenceEncoder 34 | from .transpose_last import TransposeLast 35 | from .unfold import unfold1d 36 | from .transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer 37 | from .vggblock import VGGBlock 38 | 39 | __all__ = [ 40 | "AdaptiveInput", 41 | "AdaptiveSoftmax", 42 | "BeamableMM", 43 | "CharacterTokenEmbedder", 44 | "ConvTBC", 45 | "cross_entropy", 46 | "DownsampledMultiHeadAttention", 47 | "DynamicConv1dTBC", 48 | "DynamicConv", 49 | "DynamicCRF", 50 | "FairseqDropout", 51 | "Fp32GroupNorm", 52 | "Fp32LayerNorm", 53 | "gelu", 54 | "gelu_accurate", 55 | "GradMultiply", 56 | "GumbelVectorQuantizer", 57 | "KmeansVectorQuantizer", 58 | "LayerDropModuleList", 59 | "LayerNorm", 60 | "LearnedPositionalEmbedding", 61 | "LightweightConv1dTBC", 62 | "LightweightConv", 63 | "LinearizedConvolution", 64 | "MultiheadAttention", 65 | "PositionalEmbedding", 66 | "SamePad", 67 | "ScalarBias", 68 | "SinusoidalPositionalEmbedding", 69 | "TransformerSentenceEncoderLayer", 70 | "TransformerSentenceEncoder", 71 | "TransformerDecoderLayer", 72 | "TransformerEncoderLayer", 73 | "TransposeLast", 74 | "VGGBlock", 75 | "unfold1d", 76 | ] 77 | -------------------------------------------------------------------------------- /codes/fairseq/modules/adaptive_input.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | from typing import List 8 | 9 | import torch 10 | from fairseq.modules.quant_noise import quant_noise 11 | from torch import nn 12 | 13 | 14 | class AdaptiveInput(nn.Module): 15 | def __init__( 16 | self, 17 | vocab_size: int, 18 | padding_idx: int, 19 | initial_dim: int, 20 | factor: float, 21 | output_dim: int, 22 | cutoff: List[int], 23 | q_noise: float = 0, 24 | qn_block_size: int = 8, 25 | ): 26 | super().__init__() 27 | 28 | if vocab_size > cutoff[-1]: 29 | cutoff = cutoff + [vocab_size] 30 | else: 31 | assert ( 32 | vocab_size == cutoff[-1] 33 | ), "cannot specify cutoff larger than vocab size" 34 | 35 | self.cutoff = cutoff 36 | self.embedding_dim = output_dim 37 | self.padding_idx = padding_idx 38 | 39 | self.embeddings = nn.ModuleList() 40 | for i in range(len(self.cutoff)): 41 | prev = self.cutoff[i - 1] if i > 0 else 0 42 | size = self.cutoff[i] - prev 43 | dim = int(initial_dim // (factor ** i)) 44 | seq = nn.Sequential( 45 | nn.Embedding(size, dim, self.padding_idx), 46 | quant_noise( 47 | nn.Linear(dim, output_dim, bias=False), q_noise, qn_block_size 48 | ), 49 | ) 50 | 51 | self.embeddings.append(seq) 52 | self.padding_idx = None 53 | self.padding_idx = padding_idx 54 | 55 | def init_weights(m): 56 | if isinstance(m, nn.Embedding): 57 | nn.init.normal_(m.weight, mean=0, std=m.weight.shape[1] ** -0.5) 58 | nn.init.constant_(m.weight[padding_idx], 0) 59 | elif hasattr(m, "weight"): 60 | nn.init.xavier_uniform_(m.weight) 61 | 62 | self.apply(init_weights) 63 | 64 | self.register_buffer("_float_tensor", torch.FloatTensor(1)) 65 | 66 | def weights_for_band(self, band: int): 67 | return self.embeddings[band][0].weight, self.embeddings[band][1].weight 68 | 69 | def forward(self, input: torch.Tensor): 70 | result = self._float_tensor.new(input.shape + (self.embedding_dim,)) 71 | for i in range(len(self.cutoff)): 72 | mask = input.lt(self.cutoff[i]) 73 | if i > 0: 74 | mask.mul_(input.ge(self.cutoff[i - 1])) 75 | chunk_input = input[mask] - 
self.cutoff[i - 1] 76 | else: 77 | chunk_input = input[mask] 78 | if mask.any(): 79 | result[mask] = self.embeddings[i](chunk_input) 80 | return result 81 | -------------------------------------------------------------------------------- /codes/fairseq/modules/beamable_mm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class BeamableMM(nn.Module): 11 | """This module provides an optimized MM for beam decoding with attention. 12 | 13 | It leverage the fact that the source-side of the input is replicated beam 14 | times and the target-side of the input is of width one. This layer speeds up 15 | inference by replacing the inputs {(bsz x 1 x nhu), (bsz x sz2 x nhu)} 16 | with smaller inputs {(bsz/beam x beam x nhu), (bsz/beam x sz2 x nhu)}. 
17 | """ 18 | 19 | def __init__(self, beam_size=None): 20 | super(BeamableMM, self).__init__() 21 | self.beam_size = beam_size 22 | 23 | def forward(self, input1, input2): 24 | if ( 25 | not self.training 26 | and self.beam_size is not None # test mode 27 | and input1.dim() == 3 # beam size is set 28 | and input1.size(1) # only support batched input 29 | == 1 # single time step update 30 | ): 31 | bsz, beam = input1.size(0), self.beam_size 32 | 33 | # bsz x 1 x nhu --> bsz/beam x beam x nhu 34 | input1 = input1[:, 0, :].unfold(0, beam, beam).transpose(2, 1) 35 | 36 | # bsz x sz2 x nhu --> bsz/beam x sz2 x nhu 37 | input2 = input2.unfold(0, beam, beam)[:, :, :, 0] 38 | 39 | # use non batched operation if bsz = beam 40 | if input1.size(0) == 1: 41 | output = torch.mm(input1[0, :, :], input2[0, :, :]) 42 | else: 43 | output = input1.bmm(input2) 44 | return output.view(bsz, 1, -1) 45 | else: 46 | return input1.bmm(input2) 47 | 48 | def set_beam_size(self, beam_size): 49 | self.beam_size = beam_size 50 | -------------------------------------------------------------------------------- /codes/fairseq/modules/conv_tbc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | from torch.nn.modules.utils import _single 8 | 9 | 10 | class ConvTBC(torch.nn.Module): 11 | """1D convolution over an input of shape (time x batch x channel) 12 | 13 | The implementation uses gemm to perform the convolution. This implementation 14 | is faster than cuDNN for small kernel sizes. 
15 | """ 16 | 17 | def __init__(self, in_channels, out_channels, kernel_size, padding=0): 18 | super(ConvTBC, self).__init__() 19 | self.in_channels = in_channels 20 | self.out_channels = out_channels 21 | self.kernel_size = _single(kernel_size) 22 | self.padding = _single(padding) 23 | 24 | self.weight = torch.nn.Parameter( 25 | torch.Tensor(self.kernel_size[0], in_channels, out_channels) 26 | ) 27 | self.bias = torch.nn.Parameter(torch.Tensor(out_channels)) 28 | 29 | def forward(self, input): 30 | return torch.conv_tbc( 31 | input.contiguous(), self.weight, self.bias, self.padding[0] 32 | ) 33 | 34 | def __repr__(self): 35 | s = ( 36 | "{name}({in_channels}, {out_channels}, kernel_size={kernel_size}" 37 | ", padding={padding}" 38 | ) 39 | if self.bias is None: 40 | s += ", bias=False" 41 | s += ")" 42 | return s.format(name=self.__class__.__name__, **self.__dict__) 43 | -------------------------------------------------------------------------------- /codes/fairseq/modules/cross_entropy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import logging 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def _cross_entropy_pytorch(logits, target, ignore_index=None, reduction="mean"): 16 | lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) 17 | return F.nll_loss( 18 | lprobs, 19 | target, 20 | ignore_index=ignore_index, 21 | reduction=reduction, 22 | ) 23 | 24 | 25 | try: 26 | import xentropy_cuda 27 | from apex.contrib import xentropy 28 | 29 | logger.info("using fused cross entropy") 30 | 31 | def cross_entropy(logits, target, ignore_index=-100, reduction="mean"): 32 | if logits.device == torch.device("cpu"): 33 | return _cross_entropy_pytorch(logits, target, ignore_index, reduction) 34 | else: 35 | half_to_float = logits.dtype == torch.half 36 | losses = xentropy.SoftmaxCrossEntropyLoss.apply( 37 | logits, 38 | target, 39 | 0.0, 40 | ignore_index, 41 | half_to_float, 42 | ) 43 | if reduction == "sum": 44 | return losses.sum() 45 | elif reduction == "mean": 46 | if ignore_index >= 0: 47 | return losses.sum() / target.ne(ignore_index).sum() 48 | else: 49 | return losses.mean() 50 | elif reduction == "none": 51 | return losses 52 | else: 53 | raise NotImplementedError 54 | 55 | 56 | except ImportError: 57 | 58 | def cross_entropy(logits, target, ignore_index=-100, reduction="mean"): 59 | return _cross_entropy_pytorch(logits, target, ignore_index, reduction) 60 | -------------------------------------------------------------------------------- /codes/fairseq/modules/dynamicconv_layer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from .dynamicconv_layer import DynamicconvLayer # noqa 7 | -------------------------------------------------------------------------------- /codes/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | #include 9 | #include 10 | 11 | std::vector dynamicconv_cuda_forward( 12 | at::Tensor input, 13 | at::Tensor filters, 14 | int padding_l); 15 | 16 | std::vector dynamicconv_cuda_backward( 17 | at::Tensor gradOutput, 18 | int padding_l, 19 | at::Tensor input, 20 | at::Tensor filters); 21 | 22 | 23 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 24 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 25 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 26 | 27 | std::vector dynamicconv_forward( 28 | at::Tensor input, 29 | at::Tensor filters, 30 | int padding_l) { 31 | 32 | CHECK_INPUT(input); 33 | CHECK_INPUT(filters); 34 | 35 | return dynamicconv_cuda_forward(input, filters, 36 | padding_l); 37 | } 38 | 39 | std::vector dynamicconv_backward( 40 | at::Tensor gradOutput, 41 | int padding_l, 42 | at::Tensor input, 43 | at::Tensor filters) { 44 | 45 | CHECK_INPUT(gradOutput); 46 | CHECK_INPUT(input); 47 | CHECK_INPUT(filters); 48 | 49 | return dynamicconv_cuda_backward(gradOutput, padding_l, 50 | input, filters); 51 | } 52 | 53 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 54 | m.def("forward", &dynamicconv_forward, "dynamicconv forward (CUDA)"); 55 | m.def("backward", &dynamicconv_backward, "dynamicconv backward (CUDA)"); 56 | } 57 | -------------------------------------------------------------------------------- /codes/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cuh: 
-------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | #include 9 | #include 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | #include 23 | #include 24 | #include 25 | 26 | #define SHFL_MASK 0xffffffff 27 | 28 | template 29 | __global__ 30 | void dynamicconv_forward_kernel(const scalar_t* input, 31 | const scalar_t* weight, 32 | int minibatch, 33 | int sequenceLength, 34 | int numFeatures, 35 | int numFiltersInBlock, 36 | int numHeads, 37 | scalar_t* output); 38 | 39 | template 40 | __global__ 41 | void dynamicconv_backward_kernel( 42 | const scalar_t* gradOutput, // B * C * T 43 | const scalar_t* input, // B * C * T 44 | const scalar_t* weight, 45 | int minibatch, 46 | int sequenceLength, 47 | int numFeatures, 48 | int numFiltersInBlock, 49 | int numHeads, 50 | scalar_t* gradWeight, 51 | scalar_t* gradInput); // B * H * k * T 52 | -------------------------------------------------------------------------------- /codes/fairseq/modules/dynamicconv_layer/dynamiconv_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | std::vector dynamicconv_cpu_forward( 5 | float* input, 6 | float* filters, 7 | int padding_l); 8 | 9 | std::vector dynamicconv_cpu_backward( 10 | float* gradOutput, 11 | int padding_l, 12 | float* input, 13 | float* filters); 14 | 15 | std::vector dynamicconv_forward( 16 | float* input, 17 | float* filters, 18 | int padding_l) { 19 | 20 | return dynamicconv_cpu_forward(input, filters, padding_l); 21 | } 22 | 23 | std::vector dynamicconv_backward( 24 | float* gradOutput, 25 | int padding_l, 26 | float* input, 27 | float* filters) { 28 | 29 | return 
dynamicconv_cpu_backward(gradOutput, padding_l, input, filters); 30 | } 31 | 32 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 33 | m.def("forward", &dynamicconv_forward, "dynamicconv forward (CPU)"); 34 | m.def("backward", &dynamicconv_backward, "dynamicconv backward (CPU)"); 35 | } 36 | -------------------------------------------------------------------------------- /codes/fairseq/modules/dynamicconv_layer/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from setuptools import setup 8 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 9 | 10 | 11 | setup( 12 | name="dynamicconv_layer", 13 | ext_modules=[ 14 | CUDAExtension( 15 | name="dynamicconv_cuda", 16 | sources=[ 17 | "dynamicconv_cuda.cpp", 18 | "dynamicconv_cuda_kernel.cu", 19 | ], 20 | ), 21 | ], 22 | cmdclass={"build_ext": BuildExtension}, 23 | ) 24 | -------------------------------------------------------------------------------- /codes/fairseq/modules/fairseq_dropout.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import logging 7 | from typing import List, Optional 8 | 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class FairseqDropout(nn.Module): 17 | def __init__(self, p, module_name=None): 18 | super().__init__() 19 | self.p = p 20 | self.module_name = module_name 21 | self.apply_during_inference = False 22 | 23 | def forward(self, x, inplace: bool = False): 24 | if self.training or self.apply_during_inference: 25 | return F.dropout(x, p=self.p, training=True, inplace=inplace) 26 | else: 27 | return x 28 | 29 | def make_generation_fast_( 30 | self, 31 | name: str, 32 | retain_dropout: bool = False, 33 | retain_dropout_modules: Optional[List[str]] = None, 34 | **kwargs 35 | ): 36 | if retain_dropout: 37 | if retain_dropout_modules is not None and self.module_name is None: 38 | logger.warning( 39 | "Cannot enable dropout during inference for module {} " 40 | "because module_name was not set".format(name) 41 | ) 42 | elif ( 43 | retain_dropout_modules is None # if None, apply to all modules 44 | or self.module_name in retain_dropout_modules 45 | ): 46 | logger.info( 47 | "Enabling dropout during inference for module: {}".format(name) 48 | ) 49 | self.apply_during_inference = True 50 | else: 51 | logger.info("Disabling dropout for module: {}".format(name)) 52 | -------------------------------------------------------------------------------- /codes/fairseq/modules/fp32_group_norm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | """ 6 | Layer norm done in fp32 (for fp16 training) 7 | """ 8 | 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | 13 | class Fp32GroupNorm(nn.GroupNorm): 14 | def __init__(self, *args, **kwargs): 15 | super().__init__(*args, **kwargs) 16 | 17 | def forward(self, input): 18 | output = F.group_norm( 19 | input.float(), 20 | self.num_groups, 21 | self.weight.float() if self.weight is not None else None, 22 | self.bias.float() if self.bias is not None else None, 23 | self.eps, 24 | ) 25 | return output.type_as(input) 26 | -------------------------------------------------------------------------------- /codes/fairseq/modules/gelu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | """ 6 | See "Gaussian Error Linear Units (GELUs)" by Dan Hendrycks and Kevin Gimpel with 7 | the corresponding GitHub repo: https://github.com/hendrycks/GELUs 8 | """ 9 | 10 | import math 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | 16 | def gelu_accurate(x): 17 | if not hasattr(gelu_accurate, "_a"): 18 | gelu_accurate._a = math.sqrt(2 / math.pi) 19 | return ( 20 | 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) 21 | ) 22 | 23 | 24 | def gelu(x: torch.Tensor) -> torch.Tensor: 25 | return torch.nn.functional.gelu(x.float()).type_as(x) 26 | -------------------------------------------------------------------------------- /codes/fairseq/modules/grad_multiply.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
import torch


class GradMultiply(torch.autograd.Function):
    """Identity in the forward pass; multiplies the incoming gradient by
    ``scale`` in the backward pass. Use as ``GradMultiply.apply(x, scale)``.
    """

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        res = x.new(x)
        return res

    @staticmethod
    def backward(ctx, grad):
        # `scale` is a plain (non-tensor) argument, so it gets no gradient.
        return grad * ctx.scale, None
# ------------------------------------------------------------------------------
# /codes/fairseq/modules/layer_drop.py:
# ------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
LayerDrop as described in https://arxiv.org/abs/1909.11556.
"""

import torch
import torch.nn as nn


class LayerDropModuleList(nn.ModuleList):
    """
    A LayerDrop implementation based on :class:`torch.nn.ModuleList`.

    We refresh the choice of which layers to drop every time we iterate
    over the LayerDropModuleList instance. During evaluation we always
    iterate over all layers.

    Usage (in training mode)::

        layers = LayerDropModuleList(p=0.5, modules=[layer1, layer2, layer3])
        for layer in layers:  # this might iterate over layers 1 and 3
            x = layer(x)
        for layer in layers:  # this might iterate over all layers
            x = layer(x)
        for layer in layers:  # this might not iterate over any layers
            x = layer(x)

    Args:
        p (float): probability of dropping out each layer
        modules (iterable, optional): an iterable of modules to add
    """

    def __init__(self, p, modules=None):
        super().__init__(modules)
        self.p = p

    def __iter__(self):
        # One uniform draw in [0, 1) per layer; a layer is kept when its draw
        # exceeds p. In eval mode every layer is yielded.
        dropout_probs = torch.empty(len(self)).uniform_()
        for i, m in enumerate(super().__iter__()):
            if not self.training or (dropout_probs[i] > self.p):
                yield m
# ------------------------------------------------------------------------------
# /codes/fairseq/modules/layer_norm.py:
# ------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F


try:
    from apex.normalization import FusedLayerNorm as _FusedLayerNorm

    has_fused_layernorm = True

    class FusedLayerNorm(_FusedLayerNorm):
        """apex FusedLayerNorm that runs CUDA inputs under their own device."""

        @torch.jit.unused
        def forward(self, x):
            if x.is_cuda:
                # Make x's device current while calling the apex forward.
                with torch.cuda.device(x.device):
                    return super().forward(x)
            return super().forward(x)


except ImportError:
    has_fused_layernorm = False


def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
    """Build a layer-norm module, preferring apex's fused implementation.

    The fused module is returned only when apex imported successfully, CUDA
    is available, and neither ``export`` nor TorchScript scripting is active;
    otherwise a plain :class:`torch.nn.LayerNorm` is returned.
    """
    if torch.jit.is_scripting():
        export = True
    use_fused = has_fused_layernorm and torch.cuda.is_available() and not export
    if use_fused:
        return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)


class Fp32LayerNorm(nn.LayerNorm):
    """:class:`torch.nn.LayerNorm` that computes in float32 and casts the
    result back to the input's dtype. Takes the same constructor arguments."""

    def forward(self, input):
        weight = None if self.weight is None else self.weight.float()
        bias = None if self.bias is None else self.bias.float()
        out = F.layer_norm(input.float(), self.normalized_shape, weight, bias, self.eps)
        return out.type_as(input)
# ------------------------------------------------------------------------------
# /codes/fairseq/modules/learned_positional_embedding.py:
# ------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# NOTE(review): this collapsed span of the dump holds, in order:
#   - the rest of fairseq/modules/learned_positional_embedding.py:
#     LearnedPositionalEmbedding.forward. When `positions` is not given it
#     either fills a (1, 1) tensor with padding_idx + input.size(1) (during
#     incremental decoding) or calls utils.make_positions(input, padding_idx).
#     The result is looked up with F.embedding.
#   - fairseq/modules/lightconv_layer/__init__.py, which re-exports
#     LightconvLayer.
#   - lightconv_cuda.cpp, which holds the pybind11 "forward"/"backward"
#     wrappers. They CHECK_INPUT (CUDA + contiguous) every tensor, then call
#     lightconv_cuda_forward / lightconv_cuda_backward.
#   - the kernel declarations in lightconv_cuda.cuh.
# NOTE(review): the incremental branch computes self.padding_idx +
# input.size(1), so it would raise if padding_idx is None; callers presumably
# always set padding_idx when decoding incrementally. Confirm.
# NOTE(review): the dump appears to have stripped angle-bracket text. The
# C++/CUDA `#include` targets and template arguments (e.g. std::vector's
# element type, the `template` parameter lists) are missing below, so consult
# the upstream files rather than this dump. The pybind11 help strings read
# "lighconv" (typo for "lightconv").
5 | 6 | from typing import Dict, Optional 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from fairseq import utils 12 | from torch import Tensor 13 | 14 | 15 | class LearnedPositionalEmbedding(nn.Embedding): 16 | """ 17 | This module learns positional embeddings up to a fixed maximum size. 18 | Padding ids are ignored by either offsetting based on padding_idx 19 | or by setting padding_idx to None and ensuring that the appropriate 20 | position ids are passed to the forward function. 21 | """ 22 | 23 | def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int): 24 | super().__init__(num_embeddings, embedding_dim, padding_idx) 25 | self.onnx_trace = False 26 | if self.padding_idx is not None: 27 | self.max_positions = self.num_embeddings - self.padding_idx - 1 28 | else: 29 | self.max_positions = self.num_embeddings 30 | 31 | def forward( 32 | self, 33 | input: Tensor, 34 | incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, 35 | positions: Optional[Tensor] = None, 36 | ): 37 | """Input is expected to be of size [bsz x seqlen].""" 38 | assert (positions is None) or ( 39 | self.padding_idx is None 40 | ), "If positions is pre-computed then padding_idx should not be set."
41 | 42 | if positions is None: 43 | if incremental_state is not None: 44 | # positions is the same for every token when decoding a single step 45 | # Without the int() cast, it doesn't work in some cases when exporting to ONNX 46 | positions = torch.zeros( 47 | (1, 1), device=input.device, dtype=input.dtype 48 | ).fill_(int(self.padding_idx + input.size(1))) 49 | else: 50 | positions = utils.make_positions( 51 | input, self.padding_idx, onnx_trace=self.onnx_trace 52 | ) 53 | return F.embedding( 54 | positions, 55 | self.weight, 56 | self.padding_idx, 57 | self.max_norm, 58 | self.norm_type, 59 | self.scale_grad_by_freq, 60 | self.sparse, 61 | ) 62 | -------------------------------------------------------------------------------- /codes/fairseq/modules/lightconv_layer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .lightconv_layer import LightconvLayer # noqa 7 | -------------------------------------------------------------------------------- /codes/fairseq/modules/lightconv_layer/lightconv_cuda.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree.
6 | */ 7 | 8 | #include 9 | #include 10 | 11 | std::vector lightconv_cuda_forward( 12 | at::Tensor input, 13 | at::Tensor filters, 14 | int padding_l); 15 | 16 | std::vector lightconv_cuda_backward( 17 | at::Tensor gradOutput, 18 | int padding_l, 19 | at::Tensor input, 20 | at::Tensor filters); 21 | 22 | 23 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 24 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 25 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 26 | 27 | std::vector lightconv_forward( 28 | at::Tensor input, 29 | at::Tensor filters, 30 | int padding_l) { 31 | 32 | CHECK_INPUT(input); 33 | CHECK_INPUT(filters); 34 | 35 | return lightconv_cuda_forward(input, filters, padding_l); 36 | } 37 | 38 | std::vector lightconv_backward( 39 | at::Tensor gradOutput, 40 | int padding_l, 41 | at::Tensor input, 42 | at::Tensor filters) { 43 | 44 | CHECK_INPUT(gradOutput); 45 | CHECK_INPUT(input); 46 | CHECK_INPUT(filters); 47 | 48 | return lightconv_cuda_backward(gradOutput, padding_l, input, filters); 49 | } 50 | 51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 52 | m.def("forward", &lightconv_forward, "lighconv forward (CUDA)"); 53 | m.def("backward", &lightconv_backward, "lighconv backward (CUDA)"); 54 | } 55 | -------------------------------------------------------------------------------- /codes/fairseq/modules/lightconv_layer/lightconv_cuda.cuh: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree.
6 | */ 7 | 8 | #include 9 | #include 10 | 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | 21 | #include 22 | #include 23 | 24 | #define SHFL_MASK 0xffffffff 25 | 26 | template 27 | __global__ 28 | void lightconv_forward_kernel(const scalar_t* input, 29 | const scalar_t* filters, 30 | int minibatch, int sequenceLength, 31 | int numFeatures, int numFiltersInBlock, 32 | scalar_t* output); 33 | 34 | template 35 | __global__ 36 | void lightconv_grad_wrt_input_kernel( 37 | const scalar_t* input, 38 | const scalar_t* filters, 39 | int minibatch, 40 | int sequenceLength, 41 | int numFeatures, 42 | int numFiltersInBlock, 43 | scalar_t* output); 44 | 45 | template 46 | __global__ 47 | void lightconv_grad_wrt_weights_firstpass_short_kernel( 48 | const scalar_t* input, 49 | const scalar_t* gradInput, 50 | int minibatch, 51 | int sequenceLength, 52 | int numFeatures, 53 | int numFiltersInBlock, 54 | int numHeads, 55 | float* output); 56 | 57 | template 58 | __global__ 59 | void lightconv_grad_wrt_weights_secondpass_short_kernel( 60 | const float* input, 61 | const int minibatch, 62 | const int numFiltersInBlock, 63 | scalar_t* output); 64 | 65 | template 66 | __global__ 67 | void lightconv_grad_wrt_weights_firstpass_kernel( 68 | const scalar_t* input, 69 | const scalar_t* gradInput, 70 | int minibatch, 71 | int sequenceLength, 72 | int numFeatures, 73 | int numFiltersInBlock, 74 | float* output); 75 | 76 | template 77 | __global__ 78 | void lightconv_grad_wrt_weights_secondpass_kernel( 79 | const float* input, 80 | const int minibatch, 81 | const int numFiltersInBlock, 82 | scalar_t* output); 83 | 84 | -------------------------------------------------------------------------------- /codes/fairseq/modules/lightconv_layer/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension


setup(
    name="lightconv_layer",
    ext_modules=[
        CUDAExtension(
            "lightconv_cuda",
            [
                "lightconv_cuda.cpp",
                "lightconv_cuda_kernel.cu",
            ],
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)
# ------------------------------------------------------------------------------
# /codes/fairseq/modules/positional_embedding.py:
# ------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn

from .learned_positional_embedding import LearnedPositionalEmbedding
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding


def PositionalEmbedding(
    num_embeddings: int,
    embedding_dim: int,
    padding_idx: int,
    learned: bool = False,
):
    """Build a learned or sinusoidal positional-embedding module.

    Args:
        num_embeddings: number of positions to support.
        embedding_dim: size of each position vector.
        padding_idx: index reserved for padding (may be ``None``). When set,
            position ids are offset past it.
        learned: build a :class:`LearnedPositionalEmbedding` if True,
            otherwise a :class:`SinusoidalPositionalEmbedding`.

    Returns:
        The constructed embedding module.
    """
    if learned:
        # if padding_idx is specified then offset the embedding ids by
        # this index and adjust num_embeddings appropriately
        # TODO: The right place for this offset would be inside
        # LearnedPositionalEmbedding. Move this there for a cleaner implementation.
        if padding_idx is not None:
            num_embeddings = num_embeddings + padding_idx + 1
        m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)
        nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
        if padding_idx is not None:
            nn.init.constant_(m.weight[padding_idx], 0)
    else:
        # `num_embeddings + padding_idx + 1` used to raise TypeError when
        # padding_idx was None. A missing padding index contributes no offset.
        # NOTE(review): assumes SinusoidalPositionalEmbedding treats
        # padding_idx=None like 0; confirm in sinusoidal_positional_embedding.py.
        pad_offset = padding_idx if padding_idx is not None else 0
        m = SinusoidalPositionalEmbedding(
            embedding_dim,
            padding_idx,
            init_size=num_embeddings + pad_offset + 1,
        )
    return m
# ------------------------------------------------------------------------------
# /codes/fairseq/modules/quantization/__init__.py:
# ------------------------------------------------------------------------------
# https://raw.githubusercontent.com/wonderseen/PCKMT/2a50a85f671ac46252b1038f62abdded367aeb46/codes/fairseq/modules/quantization/__init__.py
# ------------------------------------------------------------------------------
# /codes/fairseq/modules/quantization/pq/__init__.py:
# ------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .utils import SizeTracker, quantize_model_  # NOQA
# ------------------------------------------------------------------------------
# /codes/fairseq/modules/quantization/pq/modules/__init__.py:
# ------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .qconv import PQConv2d  # NOQA
from .qemb import PQEmbedding  # NOQA
from .qlinear import PQLinear  # NOQA
# ------------------------------------------------------------------------------
# /codes/fairseq/modules/quantization/pq/modules/qlinear.py:
# ------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F


class PQLinear(nn.Module):
    """
    Quantized counterpart of nn.Linear module. Stores the centroid, the assignments
    and the non-quantized biases. The full weight is re-instantiated at each forward
    pass.

    Args:
        - centroids: centroids of size n_centroids x block_size
        - assignments: assignments of the centroids to the subvectors
          of size self.out_features x n_blocks
        - bias: the non-quantized bias

    Raises:
        - ValueError ("Wrong PQ sizes") if in_features is not a multiple of
          block_size or len(assignments) is not a multiple of out_features

    Remarks:
        - We refer the reader to the official documentation of the nn.Linear module
          for the other arguments and the behavior of the module
        - Performance tests on GPU show that this implementation is 15% slower than
          the non-quantized nn.Linear module for a standard training loop.
    """

    def __init__(self, centroids, assignments, bias, in_features, out_features):
        super(PQLinear, self).__init__()
        self.block_size = centroids.size(1)
        self.n_centroids = centroids.size(0)
        self.in_features = in_features
        self.out_features = out_features
        # check compatibility
        if self.in_features % self.block_size != 0:
            raise ValueError("Wrong PQ sizes")
        if len(assignments) % self.out_features != 0:
            raise ValueError("Wrong PQ sizes")
        # define parameters
        self.centroids = nn.Parameter(centroids, requires_grad=True)
        self.register_buffer("assignments", assignments)
        # Number of subvectors assigned to each centroid.
        self.register_buffer("counts", torch.bincount(assignments).type_as(centroids))
        if bias is not None:
            self.bias = nn.Parameter(bias)
        else:
            self.register_parameter("bias", None)

    @property
    def weight(self):
        # Rebuild the dense matrix. Look up each assigned centroid to get
        # (n_blocks * out_features, block_size), then regroup to
        # (n_blocks, out_features, block_size) and move rows first. Finally
        # concatenate the blocks to get (out_features, n_blocks * block_size).
        return (
            self.centroids[self.assignments]
            .reshape(-1, self.out_features, self.block_size)
            .permute(1, 0, 2)
            .flatten(1, 2)
        )

    def forward(self, x):
        return F.linear(
            x,
            self.weight,
            self.bias,
        )

    def extra_repr(self):
        # Implicit string concatenation: the previous backslash line
        # continuations inside the f-string leaked the source indentation
        # into the repr.
        return (
            f"in_features={self.in_features}, "
            f"out_features={self.out_features}, "
            f"n_centroids={self.n_centroids}, "
            f"block_size={self.block_size}, "
            f"bias={self.bias is not None}"
        )
# ------------------------------------------------------------------------------
# /codes/fairseq/modules/quantization/quantization_options.py:
# ------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


def parse_config_yaml(yaml_data):
    """Build quantization options from parsed YAML, falling back to defaults.

    ``yaml_data`` is a mapping that may contain:
      - ``n_centroids`` / ``block_sizes``: mapping of layer type to a
        ``{"key": ..., "value": ...}`` dict, converted to ``(key, value)``;
      - ``layers_to_quantize``: list of layer-name patterns.
    Any section present replaces the corresponding default entirely.
    """
    # Initialize to default options.
    quantization_options = {
        "n_centroids": {
            "Linear": ["in_features", {"*": 256}],
            "Embedding": ["embedding_dim", {"*": 256}],
        },
        "block_sizes": {
            "Linear": ["fuzzy_name", {"fc": 8, "attn": 4, "emb": 4}],
            "Embedding": ["fuzzy_name", {"emb": 8}],
        },
        "layers_to_quantize": [
            "decoder\\.layers\\.\\d+\\.fc[12]",
            "decoder\\.embed_tokens\\.embeddings\\.[012]\\.[01]",
            "decoder\\.layers\\.\\d+\\.self_attn\\.(k_proj|v_proj|q_proj|out_proj)",
        ],
    }

    if "n_centroids" in yaml_data:
        quantization_options["n_centroids"] = {
            layer: convert_yaml_to_tuple(layer_data)
            for layer, layer_data in yaml_data["n_centroids"].items()
        }
    if "block_sizes" in yaml_data:
        quantization_options["block_sizes"] = {
            layer: convert_yaml_to_tuple(layer_data)
            for layer, layer_data in yaml_data["block_sizes"].items()
        }
    if "layers_to_quantize" in yaml_data:
        quantization_options["layers_to_quantize"] = yaml_data["layers_to_quantize"]

    return quantization_options


def convert_yaml_to_tuple(yaml_dictionary):
    """Converts a yaml dictionary with two keys: `key` and `value` into a two
    argument tuple of those values."""
    return (yaml_dictionary["key"], yaml_dictionary["value"])
# ------------------------------------------------------------------------------
# /codes/fairseq/modules/quantization/scalar/__init__.py:
# ------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
5 | 6 | from .utils import quantize_model_ # NOQA 7 | -------------------------------------------------------------------------------- /codes/fairseq/modules/quantization/scalar/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .qact import ActivationQuantizer # NOQA 7 | from .qconv import IntConv2d # NOQA 8 | from .qemb import IntEmbedding # NOQA 9 | from .qlinear import IntLinear # NOQA 10 | -------------------------------------------------------------------------------- /codes/fairseq/modules/quantization/scalar/ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
import torch


def emulate_int(w, bits, method, scale=None, zero_point=None):
    """Fake-quantize ``w`` with the ``emulate_int{bits}_{method}`` emulator.

    Returns the emulator's ``(quantized_w, scale, zero_point)`` tuple. Raises
    KeyError when no emulator exists for the bits/method combination.
    """
    q = globals()[f"emulate_int{bits}_{method}"]
    return q(w, scale=scale, zero_point=zero_point)


def quantize(w, scale, zero_point):
    """Round ``w`` onto the integer grid [0, 255] and map it back (dequantize)."""
    return (
        torch.clamp(torch.round(w / scale + zero_point), 0, 255) - zero_point
    ) * scale


def emulate_int8_histogram(w, scale=None, zero_point=None):
    """int8 emulation with qparams from a HistogramObserver (computed if not given)."""
    if scale is None:
        obs = torch.quantization.observer.HistogramObserver()
        _ = obs(w.float())
        scale, zero_point = obs.calculate_qparams()
        # Match w's dtype AND device. The former `.cuda().type_as(w)` failed on
        # CPU-only hosts and ignored which GPU w lives on.
        scale = scale.to(w)
        zero_point = zero_point.to(w)
    return quantize(w, scale, zero_point), scale, zero_point


def emulate_int8_channel(w, scale=None, zero_point=None):
    """int8 emulation with per-channel (last axis) symmetric qparams."""
    if scale is None:
        obs = torch.quantization.observer.PerChannelMinMaxObserver(
            ch_axis=-1, qscheme=torch.per_channel_symmetric
        )
        _ = obs(w)
        # Observers expose calculate_qparams() -> (scales, zero_points);
        # the previous `obs.get_qparams()` call does not exist.
        scale, zero_point = obs.calculate_qparams()
        scale = scale.to(w)
        zero_point = zero_point.to(w)
    return quantize(w, scale, zero_point), scale, zero_point


def emulate_int8_tensor(w, scale=None, zero_point=None):
    """int8 emulation with per-tensor qparams from a MinMaxObserver."""
    if scale is None:
        obs = torch.quantization.observer.MinMaxObserver()
        _ = obs(w)
        scale, zero_point = obs.calculate_qparams()
        scale = scale.to(w)
        zero_point = zero_point.to(w)
    return quantize(w, scale, zero_point), scale, zero_point
# ------------------------------------------------------------------------------
# /codes/fairseq/modules/quantization/scalar/utils.py:
# ------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from operator import attrgetter

import torch.distributed as dist
import torch.nn as nn

from ..pq.utils import attrsetter, get_layers
from .modules import ActivationQuantizer, IntConv2d, IntEmbedding, IntLinear


MAPPING = {nn.Linear: IntLinear, nn.Embedding: IntEmbedding, nn.Conv2d: IntConv2d}


def quantize_model_(model, p=0.2, bits=8, update_step=3000):
    """
    Replaces all modules with their scalar quantized counterpart and
    registers hooks to quantize the post-activations of those modules.

    Args:
        - model: a nn.Module
        - p: amount of noise (0 for no noise, 1 to quantize all the weights/activations)
        - bits: number of bits
        - update_step: update quantization parameters every update_step steps

    Returns:
        The list of layer names matched by ``get_layers``. Unsupported
        layers appear in it too, even though they are left unchanged.
    """

    # quantize all layers
    quantized_layers = get_layers(model, "(.*?)")

    # book-keeping: only rank 0 (or a non-distributed run) logs. This does not
    # change between layers, so compute it once.
    is_master_process = (not dist.is_initialized()) or dist.get_rank() == 0

    for layer in quantized_layers:

        # recover module
        module = attrgetter(layer)(model)
        if is_master_process:
            logging.info(
                f"Quantizing layer {layer} with bits={bits} and QuantNoise={p}"
            )

        # quantization params
        q_params = {
            "p": p,
            "update_step": update_step,
            "bits": bits,
            "method": "histogram",
            "counter": 0,
        }

        # Resolve the quantized class with isinstance so subclasses of the
        # supported layers also map. `MAPPING[module.__class__]` raised
        # KeyError for them.
        QuantizedModule = next(
            (qcls for base, qcls in MAPPING.items() if isinstance(module, base)),
            None,
        )
        if QuantizedModule is None:
            if is_master_process:
                logging.info(f"Module {module} not yet supported for quantization")
            continue

        # instantiate the quantized counterpart without running __init__ and
        # give it the original module's state plus the quantization params.
        # A new dict is built so the original module is not mutated.
        quantized_module = QuantizedModule.__new__(QuantizedModule)
        quantized_module.__dict__.update({**module.__dict__, **q_params})

        # activation quantization. Constructed for its side effect on
        # quantized_module (presumably registering a hook; see .modules.qact).
        ActivationQuantizer(quantized_module, p=0, bits=bits, method="histogram")

        # replace layer by its quantized counterpart
        attrsetter(layer)(model, quantized_module)

    # return name of quantized layers
    return quantized_layers
# ------------------------------------------------------------------------------
# /codes/fairseq/modules/same_pad.py:
# ------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


from torch import nn


class SamePad(nn.Module):
    """Drop the last element along dim 2 when ``kernel_size`` is even.

    Presumably applied after a convolution padded with ``kernel_size // 2``
    on both sides, where an even kernel yields one extra output step.
    """

    def __init__(self, kernel_size):
        super().__init__()
        self.remove = kernel_size % 2 == 0

    def forward(self, x):
        if self.remove:
            x = x[:, :, :-1]
        return x
# ------------------------------------------------------------------------------
# /codes/fairseq/modules/scalar_bias.py:
# ------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#

import torch


class ScalarBias(torch.autograd.Function):
    """Prepend a constant slot along ``dim``.

    The output has one extra position at index 0 of ``dim``, filled with
    ``bias_init``, followed by the input. Self-attention uses this so that it
    can optionally attend to the extra slot instead of the past. Gradients
    flow only to the input part.
    """

    @staticmethod
    def forward(ctx, input, dim, bias_init):
        padded_shape = list(input.size())
        padded_shape[dim] += 1
        ctx.dim = dim
        # Slot 0 along `dim` keeps bias_init; slots 1.. receive the input.
        out = input.new(*padded_shape).fill_(bias_init)
        out.narrow(dim, 1, padded_shape[dim] - 1).copy_(input)
        return out

    @staticmethod
    def backward(ctx, grad):
        d = ctx.dim
        # Discard the gradient of the prepended slot; dim/bias_init get None.
        input_grad = grad.narrow(d, 1, grad.size(d) - 1)
        return input_grad, None, None


def scalar_bias(input, dim, bias_init=0):
    """Functional entry point for :class:`ScalarBias`."""
    return ScalarBias.apply(input, dim, bias_init)
# ------------------------------------------------------------------------------
# /codes/fairseq/modules/sparse_transformer_sentence_encoder_layer.py:
# ------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.modules import TransformerSentenceEncoderLayer
from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention


class SparseTransformerSentenceEncoderLayer(TransformerSentenceEncoderLayer):
    """
    Implements a Sparse Transformer Encoder Layer (see SparseMultiheadAttention).

    Builds a regular :class:`TransformerSentenceEncoderLayer`, then replaces
    its ``self_attn`` with a :class:`SparseMultiheadAttention` configured by
    ``is_bidirectional``, ``stride`` and ``expressivity``.
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        export: bool = False,
        is_bidirectional: bool = True,
        stride: int = 32,
        expressivity: int = 8,
    ) -> None:

        # Arguments are forwarded positionally, so this order must match the
        # parent constructor's signature.
        super().__init__(
            embedding_dim,
            ffn_embedding_dim,
            num_attention_heads,
            dropout,
            attention_dropout,
            activation_dropout,
            activation_fn,
            export,
        )

        # Swap in sparse self-attention (no bias_kv / zero_attn).
        self.self_attn = SparseMultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            add_bias_kv=False,
            add_zero_attn=False,
            self_attention=True,
            is_bidirectional=is_bidirectional,
            stride=stride,
            expressivity=expressivity,
        )
# ------------------------------------------------------------------------------
# /codes/fairseq/modules/transpose_last.py:
# ------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
transpose last 2 dimensions of the input
"""

import torch.nn as nn


class TransposeLast(nn.Module):
    """Swap the last two dimensions, optionally after indexing the input.

    Args:
        deconstruct_idx: when not None, ``x[deconstruct_idx]`` is selected
            before transposing (e.g. to pick one item out of a tuple).
    """

    def __init__(self, deconstruct_idx=None):
        super().__init__()
        self.deconstruct_idx = deconstruct_idx

    def forward(self, x):
        idx = self.deconstruct_idx
        source = x if idx is None else x[idx]
        return source.transpose(-2, -1)
# ------------------------------------------------------------------------------
# /codes/fairseq/modules/unfold.py:
# ------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn.functional as F


def unfold1d(x, kernel_size, padding_l, pad_value=0):
    """Unfold a T x B x C tensor into T x B x C x K sliding windows over time.

    The time axis is padded with ``padding_l`` steps before and
    ``kernel_size - 1 - padding_l`` steps after (filled with ``pad_value``).
    Window ``k`` of step ``t`` is padded step ``t + k``. For
    ``kernel_size <= 1`` the input just gains a trailing singleton axis.
    """
    if kernel_size <= 1:
        return x.unsqueeze(3)
    T, B, C = x.size()
    padded = F.pad(
        x, (0, 0, 0, 0, padding_l, kernel_size - 1 - padding_l), value=pad_value
    )
    # Stepping one window position is stepping one time row (B * C elements).
    return padded.as_strided((T, B, C, kernel_size), (B * C, C, 1, B * C))
# ------------------------------------------------------------------------------
# /codes/fairseq/optim/__init__.py:
# ------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""

import importlib
import os
from argparse import Namespace
from typing import Union

from fairseq import registry
from fairseq.optim.bmuf import FairseqBMUF  # noqa
from fairseq.optim.fairseq_optimizer import (  # noqa
    FairseqOptimizer,
    LegacyFairseqOptimizer,
)
from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer
from fairseq.optim.shard import shard_
from omegaconf import DictConfig


__all__ = [
    "FairseqOptimizer",
    "FP16Optimizer",
    "MemoryEfficientFP16Optimizer",
    "shard_",
]


(
    _build_optimizer,
    register_optimizer,
    OPTIMIZER_REGISTRY,
    OPTIMIZER_DATACLASS_REGISTRY,
) = registry.setup_registry("--optimizer", base_class=FairseqOptimizer, required=True)


def build_optimizer(
    optimizer_cfg: Union[DictConfig, Namespace], params, *extra_args, **extra_kwargs
):
    """Build the configured optimizer over the trainable tensors in ``params``.

    ``params`` may be any iterable of tensors, or of dicts whose values are
    tensors (flattened in order). Tensors with ``requires_grad=False`` are
    dropped before the registry builder is called.
    """
    # Materialize first. Otherwise the all() check below consumes part of a
    # one-shot iterator (e.g. a parameters() generator) and drops parameters.
    params = list(params)
    if all(isinstance(p, dict) for p in params):
        params = [t for p in params for t in p.values()]
    params = [p for p in params if p.requires_grad]
    return _build_optimizer(optimizer_cfg, params, *extra_args, **extra_kwargs)


# automatically import any Python files in the optim/ directory
for file in os.listdir(os.path.dirname(__file__)):
    if file.endswith(".py") and not file.startswith("_"):
        file_name = file[: file.find(".py")]
        importlib.import_module("fairseq.optim." + file_name)
# ------------------------------------------------------------------------------
# /codes/fairseq/optim/adadelta.py:
# ------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.optim

from . import LegacyFairseqOptimizer, register_optimizer


@register_optimizer("adadelta")
class Adadelta(LegacyFairseqOptimizer):
    """Wrapper around :class:`torch.optim.Adadelta` driven by legacy args."""

    def __init__(self, args, params):
        super().__init__(args)
        self._optimizer = torch.optim.Adadelta(params, **self.optimizer_config)

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        # NOTE(review): --anneal-eps is registered here but not read by
        # optimizer_config below.
        # fmt: off
        parser.add_argument('--adadelta-rho', type=float, default=0.9, metavar='RHO',
                            help='coefficient used for computing a running average of squared gradients')
        parser.add_argument('--adadelta-eps', type=float, default=1e-6, metavar='EPS',
                            help='term added to the denominator to improve numerical stability')
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')
        parser.add_argument('--anneal-eps', action='store_true', help='flag to anneal eps')
        # fmt: on

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        return {
            "lr": self.args.lr[0],
            "rho": self.args.adadelta_rho,
            "eps": self.args.adadelta_eps,
            "weight_decay": self.args.weight_decay,
        }

    @property
    def supports_flat_params(self):
        return True
# ------------------------------------------------------------------------------
# /codes/fairseq/optim/adagrad.py:
# ------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.optim

from . import LegacyFairseqOptimizer, register_optimizer


@register_optimizer("adagrad")
class Adagrad(LegacyFairseqOptimizer):
    """Wrapper around :class:`torch.optim.Adagrad` driven by legacy args."""

    def __init__(self, args, params):
        super().__init__(args)
        self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config)

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')
        # fmt: on

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        return {
            "lr": self.args.lr[0],
            "weight_decay": self.args.weight_decay,
        }

    @property
    def supports_flat_params(self):
        return True
# ------------------------------------------------------------------------------
# /codes/fairseq/optim/dynamic_loss_scaler.py:
# ------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
5 |
6 |
7 | class DynamicLossScaler(object):
8 |     def __init__(
9 |         self,
10 |         init_scale=2.0 ** 15,
11 |         scale_factor=2.0,
12 |         scale_window=2000,
13 |         tolerance=0.05,
14 |         threshold=None,
15 |         min_loss_scale=1e-4,
16 |     ):
17 |         self.loss_scale = init_scale
18 |         self.scale_factor = scale_factor
19 |         self.scale_window = scale_window
20 |         self.tolerance = tolerance
21 |         self.threshold = threshold
22 |         self._iter = 0
23 |         self._last_overflow_iter = -1
24 |         self._last_rescale_iter = -1
25 |         self._overflows_since_rescale = 0
26 |         self.min_loss_scale = min_loss_scale
27 |
28 |     def scale(self, outputs):
29 |         return self.loss_scale * outputs
30 |
31 |     def update(self):
32 |         if (self._iter - self._last_overflow_iter) % self.scale_window == 0:  # grow after every scale_window overflow-free iterations
33 |             self.loss_scale *= self.scale_factor
34 |             self._last_rescale_iter = self._iter
35 |         self._iter += 1
36 |
37 |     def _decrease_loss_scale(self):
38 |         self.loss_scale /= self.scale_factor
39 |         if self.threshold is not None:
40 |             self.loss_scale = max(self.loss_scale, self.threshold)
41 |
42 |     def check_overflow(self, grad_norm):
43 |         # detect inf and nan
44 |         if grad_norm == float("inf") or grad_norm != grad_norm:
45 |             # overflow has occurred
46 |             prev_scale = self.loss_scale
47 |             iter_since_rescale = self._iter - self._last_rescale_iter
48 |
49 |             self._last_overflow_iter = self._iter
50 |             self._overflows_since_rescale += 1
51 |             pct_overflow = self._overflows_since_rescale / float(iter_since_rescale)  # share of iterations since the last rescale that overflowed
52 |             if pct_overflow >= self.tolerance:
53 |                 self._decrease_loss_scale()
54 |                 self._last_rescale_iter = self._iter
55 |                 self._overflows_since_rescale = 0
56 |
57 |             if self.loss_scale <= self.min_loss_scale:
58 |                 # Use FloatingPointError as an uncommon error that parent
59 |                 # functions can safely catch to stop training.
60 |                 self.loss_scale = prev_scale
61 |                 raise FloatingPointError(
62 |                     (
63 |                         "Minimum loss scale reached ({}). Your loss is probably exploding. "
64 |                         "Try lowering the learning rate, using gradient clipping or "
65 |                         "increasing the batch size."
66 |                     ).format(self.min_loss_scale)
67 |                 )
68 |
69 |             self._iter += 1
70 |             raise OverflowError("setting loss scale to: " + str(self.loss_scale))
71 |
--------------------------------------------------------------------------------
/codes/fairseq/optim/fused_lamb.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from fairseq.optim import LegacyFairseqOptimizer, register_optimizer
7 |
8 |
9 | @register_optimizer("lamb")
10 | class FairseqLAMB(LegacyFairseqOptimizer):
11 |     """LAMB optimizer."""
12 |
13 |     def __init__(self, args, params):
14 |         super().__init__(args)
15 |         try:
16 |             from apex.optimizers import FusedLAMB
17 |
18 |             self._optimizer = FusedLAMB(params, **self.optimizer_config)
19 |         except ImportError:  # NOTE(review): also catches ImportErrors raised inside FusedLAMB(...) itself
20 |             raise ImportError("Please install apex to use LAMB optimizer")
21 |
22 |     @staticmethod
23 |     def add_args(parser):
24 |         """Add optimizer-specific arguments to the parser."""
25 |         # fmt: off
26 |         parser.add_argument('--lamb-betas', default='(0.9, 0.999)', metavar='B',
27 |                             help='betas for LAMB optimizer')
28 |         parser.add_argument('--lamb-eps', type=float, default=1e-8, metavar='D',
29 |                             help='epsilon for LAMB optimizer')
30 |         parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
31 |                             help='weight decay')
32 |         # fmt: on
33 |
34 |     @property
35 |     def optimizer_config(self):
36 |         """
37 |         Return a kwarg dictionary that will be used to override optimizer
38 |         args stored in checkpoints. This allows us to load a checkpoint and
39 |         resume training using a different set of optimizer args, e.g., with a
40 |         different learning rate.
41 |         """
42 |         return {
43 |             "lr": self.args.lr[0],
44 |             "betas": eval(self.args.lamb_betas),  # NOTE(review): eval() executes the CLI string; ast.literal_eval would parse '(0.9, 0.999)' safely
45 |             "eps": self.args.lamb_eps,
46 |             "weight_decay": self.args.weight_decay,
47 |         }
48 |
49 |     @property
50 |     def supports_flat_params(self):
51 |         return False
52 |
--------------------------------------------------------------------------------
/codes/fairseq/optim/lr_scheduler/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | """isort:skip_file"""
6 |
7 | import importlib
8 | import os
9 | from argparse import Namespace
10 | from typing import Union
11 |
12 | from fairseq import registry
13 | from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import (  # noqa
14 |     FairseqLRScheduler,
15 |     LegacyFairseqLRScheduler,
16 | )
17 | from omegaconf import DictConfig
18 |
19 |
20 | (
21 |     build_lr_scheduler_,
22 |     register_lr_scheduler,
23 |     LR_SCHEDULER_REGISTRY,
24 |     LR_SCHEDULER_DATACLASS_REGISTRY,
25 | ) = registry.setup_registry(
26 |     "--lr-scheduler", base_class=FairseqLRScheduler, default="fixed"
27 | )
28 |
29 |
30 | def build_lr_scheduler(lr_scheduler_cfg: Union[DictConfig, Namespace], optimizer):
31 |     return build_lr_scheduler_(lr_scheduler_cfg, optimizer)
32 |
33 |
34 | # automatically import any Python files in the optim/lr_scheduler/ directory
35 | for file in os.listdir(os.path.dirname(__file__)):
36 |     if file.endswith(".py") and not file.startswith("_"):
37 |         file_name = file[: file.find(".py")]
38 |         importlib.import_module("fairseq.optim.lr_scheduler." + file_name)
39 |
--------------------------------------------------------------------------------
/codes/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from argparse import Namespace
7 |
8 | from fairseq.dataclass.utils import gen_parser_from_dataclass
9 |
10 | from .. import FairseqOptimizer
11 |
12 |
13 | class FairseqLRScheduler(object):
14 |     def __init__(self, args, optimizer):
15 |         super().__init__()
16 |         if not isinstance(optimizer, FairseqOptimizer):
17 |             raise ValueError("optimizer must be an instance of FairseqOptimizer")
18 |         self.args = args
19 |         self.optimizer = optimizer
20 |         self.best = None
21 |
22 |     @classmethod
23 |     def add_args(cls, parser):
24 |         """Add arguments to the parser for this LR scheduler."""
25 |         dc = getattr(cls, "__dataclass", None)  # plain-string lookup, so no private-name mangling; presumably set by the registry, verify
26 |         if dc is not None:
27 |             gen_parser_from_dataclass(parser, dc())
28 |
29 |     def state_dict(self):
30 |         """Return the LR scheduler state dict."""
31 |         return {"best": self.best}
32 |
33 |     def load_state_dict(self, state_dict):
34 |         """Load an LR scheduler state dict."""
35 |         self.best = state_dict["best"]
36 |
37 |     def step_begin_epoch(self, epoch):
38 |         """Update the learning rate at the beginning of the given epoch."""
39 |         pass
40 |
41 |     def step(self, epoch, val_loss=None):
42 |         """Update the learning rate at the end of the given epoch."""
43 |         if val_loss is not None:
44 |             if self.best is None:
45 |                 self.best = val_loss
46 |             else:
47 |                 self.best = min(self.best, val_loss)
48 |
49 |     def step_update(self, num_updates):
50 |         """Update the learning rate after each update."""
51 |         return self.optimizer.get_lr()
52 |
53 |
54 | class LegacyFairseqLRScheduler(FairseqLRScheduler):
55 |     def __init__(self, args: Namespace, optimizer):  # NOTE: does not call super().__init__; repeats its check and assignments
56 |         if not isinstance(optimizer, FairseqOptimizer):
57 |             raise ValueError("optimizer must be an instance of FairseqOptimizer")
58 |         self.args = args
59 |         self.optimizer = optimizer
60 |         self.best = None
61 |
--------------------------------------------------------------------------------
/codes/fairseq/optim/lr_scheduler/fixed_schedule.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from . import LegacyFairseqLRScheduler, register_lr_scheduler
7 |
8 |
9 | @register_lr_scheduler("fixed")
10 | class FixedSchedule(LegacyFairseqLRScheduler):
11 |     """Decay the LR on a fixed schedule."""
12 |
13 |     def __init__(self, args, optimizer):
14 |         super().__init__(args, optimizer)
15 |
16 |         # set defaults
17 |         args.warmup_updates = getattr(args, "warmup_updates", 0) or 0
18 |
19 |         self.lr = args.lr[0]
20 |         if args.warmup_updates > 0:
21 |             self.warmup_factor = 1.0 / args.warmup_updates
22 |         else:
23 |             self.warmup_factor = 1
24 |
25 |     @staticmethod
26 |     def add_args(parser):
27 |         """Add arguments to the parser for this LR scheduler."""
28 |         # fmt: off
29 |         parser.add_argument('--force-anneal', '--fa', type=int, metavar='N',
30 |                             help='force annealing at specified epoch (epochs start at 1)')
31 |         parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
32 |                             help='shrink factor for annealing, lr_new = (lr * lr_shrink)')
33 |         parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
34 |                             help='warmup the learning rate linearly for the first N updates')
35 |         # fmt: on
36 |
37 |     def state_dict(self):
38 |         return {"lr": self.lr}
39 |
40 |     def load_state_dict(self, state_dict):
41 |         if "lr" in state_dict:
42 |             self.lr = state_dict["lr"]
43 |
44 |     def get_next_lr(self, epoch):
45 |         lrs = self.args.lr
46 |         if self.args.force_anneal is None or epoch < self.args.force_anneal:
47 |             # use fixed LR schedule
48 |             next_lr = lrs[min(epoch - 1, len(lrs) - 1)]
49 |         else:
50 |             # anneal based on lr_shrink (already shrunk once at epoch == force_anneal)
51 |             next_lr = lrs[-1] * self.args.lr_shrink ** (
52 |                 epoch + 1 - self.args.force_anneal
53 |             )
54 |         return next_lr
55 |
56 |     def step_begin_epoch(self, epoch):
57 |         """Update the learning rate at the beginning of the given epoch."""
58 |         self.lr = self.get_next_lr(epoch)
59 |         self.optimizer.set_lr(self.warmup_factor * self.lr)
60 |         return self.optimizer.get_lr()
61 |
62 |     def step_update(self, num_updates):
63 |         """Update the learning rate after each update."""
64 |         if self.args.warmup_updates > 0 and num_updates < self.args.warmup_updates:
65 |             self.warmup_factor = (num_updates + 1) / float(self.args.warmup_updates)
66 |             self.optimizer.set_lr(self.warmup_factor * self.lr)
67 |         else:
68 |             self.optimizer.set_lr(self.lr)
69 |         return self.optimizer.get_lr()
70 |
--------------------------------------------------------------------------------
/codes/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import math
7 |
8 | from . import LegacyFairseqLRScheduler, register_lr_scheduler
9 |
10 |
11 | @register_lr_scheduler("triangular")
12 | class TriangularSchedule(LegacyFairseqLRScheduler):
13 |     """Assign LR based on a triangular cyclical schedule.
14 |
15 |     See https://arxiv.org/pdf/1506.01186.pdf for details.
16 |     """
17 |
18 |     def __init__(self, args, optimizer):
19 |         super().__init__(args, optimizer)
20 |         if len(args.lr) > 1:
21 |             raise ValueError(
22 |                 "Cannot use a fixed learning rate schedule with triangular."
23 |                 " Consider --lr-scheduler=fixed instead."
24 |             )
25 |
26 |         lr = args.lr[0]
27 |
28 |         assert args.max_lr > lr, "max_lr must be more than lr"  # NOTE(review): skipped under `python -O`
29 |         self.min_lr = lr
30 |         self.max_lr = args.max_lr
31 |         self.stepsize = args.lr_period_updates // 2  # NOTE(review): a period < 2 gives 0 here and a ZeroDivisionError in step_update
32 |         self.lr_shrink = args.lr_shrink
33 |         self.shrink_min = args.shrink_min
34 |
35 |         # initial learning rate
36 |         self.lr = self.min_lr
37 |         self.optimizer.set_lr(self.lr)
38 |
39 |     @staticmethod
40 |     def add_args(parser):
41 |         """Add arguments to the parser for this LR scheduler."""
42 |         # fmt: off
43 |         parser.add_argument('--max-lr', required=True, type=float, metavar='LR',
44 |                             help='max learning rate, must be more than args.lr')
45 |         parser.add_argument('--lr-period-updates', default=5000, type=float, metavar='LR',
46 |                             help='initial number of updates per period (cycle length)')
47 |         parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
48 |                             help='shrink factor for annealing')
49 |         parser.add_argument('--shrink-min', action='store_true',
50 |                             help='if set, also shrinks min lr')
51 |         # fmt: on
52 |
53 |     def step(self, epoch, val_loss=None):
54 |         """Update the learning rate at the end of the given epoch."""
55 |         super().step(epoch, val_loss)
56 |         # we don't change the learning rate at epoch boundaries
57 |         return self.optimizer.get_lr()
58 |
59 |     def step_update(self, num_updates):
60 |         """Update the learning rate after each update."""
61 |         cycle = math.floor(num_updates / (2 * self.stepsize))
62 |
63 |         lr_shrink = self.lr_shrink ** cycle
64 |         max_lr = self.max_lr * lr_shrink
65 |         if self.shrink_min:
66 |             min_lr = self.min_lr * lr_shrink
67 |         else:
68 |             min_lr = self.min_lr
69 |
70 |         x = abs(num_updates / self.stepsize - 2 * (cycle + 1) + 1)  # position within the cycle: 1 at its ends, 0 at its peak
71 |         self.lr = min_lr + (max_lr - min_lr) * max(0, (1 - x))
72 |
73 |         self.optimizer.set_lr(self.lr)
74 |         return self.lr
75 |
--------------------------------------------------------------------------------
/codes/fairseq/optim/sgd.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch.optim
7 |
8 | from . import LegacyFairseqOptimizer, register_optimizer
9 |
10 |
11 | @register_optimizer("sgd")
12 | class SGD(LegacyFairseqOptimizer):
13 |     def __init__(self, args, params):
14 |         super().__init__(args)
15 |         self._optimizer = torch.optim.SGD(params, **self.optimizer_config)
16 |
17 |     @staticmethod
18 |     def add_args(parser):
19 |         """Add optimizer-specific arguments to the parser."""
20 |         # fmt: off
21 |         parser.add_argument('--momentum', default=0.0, type=float, metavar='M',
22 |                             help='momentum factor')
23 |         parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
24 |                             help='weight decay')
25 |         # fmt: on
26 |
27 |     @property
28 |     def optimizer_config(self):
29 |         """
30 |         Return a kwarg dictionary that will be used to override optimizer
31 |         args stored in checkpoints. This allows us to load a checkpoint and
32 |         resume training using a different set of optimizer args, e.g., with a
33 |         different learning rate.
34 |         """
35 |         return {
36 |             "lr": self.args.lr[0],
37 |             "momentum": self.args.momentum,
38 |             "weight_decay": self.args.weight_decay,
39 |         }
40 |
41 |     @property
42 |     def supports_flat_params(self):
43 |         return True
44 |
--------------------------------------------------------------------------------
/codes/fairseq/optim/shard.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | 6 | 7 | try: 8 | from fairscale.optim import OSS 9 | 10 | _has_fairscale = True 11 | except ImportError: 12 | _has_fairscale = False 13 | 14 | 15 | def shard_(args, optimizer, group): 16 | if not _has_fairscale: 17 | raise ImportError( 18 | "\n\nPlease install the fairscale package:" "\n\n pip install fairscale" 19 | ) 20 | 21 | class FairseqOSS(OSS): 22 | @property 23 | def disable_mem_eff_fp16_loading_hack(self): 24 | return True 25 | 26 | def __getattr__(self, name): 27 | if name.startswith("supports") and hasattr(self.optim, name): 28 | return getattr(self.optim, name) 29 | raise AttributeError( 30 | "'FairseqOSS' object has no attribute {0!r}".format(name) 31 | ) 32 | 33 | torch_optimizer = optimizer.optimizer 34 | optim_cls = type(torch_optimizer) 35 | 36 | optimizer.optimizer = FairseqOSS( 37 | torch_optimizer.param_groups, 38 | optim_cls, 39 | group=group, 40 | **optimizer.optimizer_config 41 | ) 42 | -------------------------------------------------------------------------------- /codes/fairseq/pdb.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import multiprocessing 7 | import os 8 | import pdb 9 | import sys 10 | 11 | 12 | __all__ = ["set_trace"] 13 | 14 | 15 | _stdin = [None] 16 | _stdin_lock = multiprocessing.Lock() 17 | try: 18 | _stdin_fd = sys.stdin.fileno() 19 | except Exception: 20 | _stdin_fd = None 21 | 22 | 23 | class MultiprocessingPdb(pdb.Pdb): 24 | """A Pdb wrapper that works in a multiprocessing environment. 
25 | 26 | Usage: `from fairseq import pdb; pdb.set_trace()` 27 | """ 28 | 29 | def __init__(self): 30 | pdb.Pdb.__init__(self, nosigint=True) 31 | 32 | def _cmdloop(self): 33 | stdin_bak = sys.stdin 34 | with _stdin_lock: 35 | try: 36 | if _stdin_fd is not None: 37 | if not _stdin[0]: 38 | _stdin[0] = os.fdopen(_stdin_fd) 39 | sys.stdin = _stdin[0] 40 | self.cmdloop() 41 | finally: 42 | sys.stdin = stdin_bak 43 | 44 | 45 | def set_trace(): 46 | pdb = MultiprocessingPdb() 47 | pdb.set_trace(sys._getframe().f_back) 48 | -------------------------------------------------------------------------------- /codes/fairseq/scoring/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | import importlib 8 | import os 9 | from abc import ABC, abstractmethod 10 | 11 | from fairseq import registry 12 | 13 | 14 | class BaseScorer(ABC): 15 | def __init__(self, args): 16 | self.args = args 17 | self.ref = [] 18 | self.pred = [] 19 | 20 | @staticmethod 21 | def add_args(parser): 22 | pass 23 | 24 | def add_string(self, ref, pred): 25 | self.ref.append(ref) 26 | self.pred.append(pred) 27 | 28 | @abstractmethod 29 | def score(self) -> float: 30 | pass 31 | 32 | @abstractmethod 33 | def result_string(self) -> str: 34 | pass 35 | 36 | 37 | _build_scorer, register_scorer, SCORER_REGISTRY, _ = registry.setup_registry( 38 | "--scoring", default="bleu" 39 | ) 40 | 41 | 42 | def build_scorer(args, tgt_dict): 43 | from fairseq import utils 44 | 45 | if args.sacrebleu: 46 | utils.deprecation_warning( 47 | "--sacrebleu is deprecated. Please use --scoring sacrebleu instead." 
48 | ) 49 | args.scoring = "sacrebleu" 50 | if args.scoring == "bleu": 51 | from fairseq.scoring import bleu 52 | 53 | return bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk()) 54 | return _build_scorer(args) 55 | 56 | 57 | # automatically import any Python files in the current directory 58 | for file in os.listdir(os.path.dirname(__file__)): 59 | if file.endswith(".py") and not file.startswith("_"): 60 | module = file[: file.find(".py")] 61 | importlib.import_module("fairseq.scoring." + module) 62 | -------------------------------------------------------------------------------- /codes/fairseq/scoring/chrf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from fairseq.scoring import BaseScorer, register_scorer 7 | 8 | 9 | @register_scorer("chrf") 10 | class ChrFScorer(BaseScorer): 11 | def __init__(self, args): 12 | super(ChrFScorer, self).__init__(args) 13 | import sacrebleu 14 | 15 | self.sacrebleu = sacrebleu 16 | 17 | def add_string(self, ref, pred): 18 | self.ref.append(ref) 19 | self.pred.append(pred) 20 | 21 | def score(self, order=4): 22 | return self.result_string(order).score 23 | 24 | def result_string(self, order=4): 25 | if order != 4: 26 | raise NotImplementedError 27 | return self.sacrebleu.corpus_chrf(self.pred, [self.ref]).format() 28 | -------------------------------------------------------------------------------- /codes/fairseq/scoring/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import unicodedata 7 | 8 | 9 | class EvaluationTokenizer(object): 10 | """A generic evaluation-time tokenizer, which leverages built-in tokenizers 11 | in sacreBLEU (https://github.com/mjpost/sacrebleu). It additionally provides 12 | lowercasing, punctuation removal and character tokenization, which are 13 | applied after sacreBLEU tokenization. 14 | 15 | Args: 16 | tokenizer_type (str): the type of sacreBLEU tokenizer to apply. 17 | lowercase (bool): lowercase the text. 18 | punctuation_removal (bool): remove punctuation (based on unicode 19 | category) from text. 20 | character_tokenization (bool): tokenize the text to characters. 21 | """ 22 | 23 | SPACE = chr(32) 24 | SPACE_ESCAPE = chr(9601) 25 | ALL_TOKENIZER_TYPES = ["none", "13a", "intl", "zh", "ja-mecab"] 26 | 27 | def __init__( 28 | self, 29 | tokenizer_type: str = "13a", 30 | lowercase: bool = False, 31 | punctuation_removal: bool = False, 32 | character_tokenization: bool = False, 33 | ): 34 | from sacrebleu.tokenizers import TOKENIZERS 35 | 36 | assert tokenizer_type in self.ALL_TOKENIZER_TYPES 37 | self.lowercase = lowercase 38 | self.punctuation_removal = punctuation_removal 39 | self.character_tokenization = character_tokenization 40 | self.tokenizer = TOKENIZERS[tokenizer_type] 41 | 42 | @classmethod 43 | def remove_punctuation(cls, sent: str): 44 | """Remove punctuation based on Unicode category.""" 45 | return cls.SPACE.join( 46 | t 47 | for t in sent.split(cls.SPACE) 48 | if not all(unicodedata.category(c)[0] == "P" for c in t) 49 | ) 50 | 51 | def tokenize(self, sent: str): 52 | tokenized = self.tokenizer()(sent) 53 | 54 | if self.punctuation_removal: 55 | tokenized = self.remove_punctuation(tokenized) 56 | 57 | if self.character_tokenization: 58 | tokenized = self.SPACE.join( 59 | list(tokenized.replace(self.SPACE, self.SPACE_ESCAPE)) 60 | ) 61 | 62 | if self.lowercase: 63 | tokenized = tokenized.lower() 64 | 65 | return tokenized 66 | 
--------------------------------------------------------------------------------
/codes/fairseq/scoring/wer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from fairseq.scoring import BaseScorer, register_scorer
7 | from fairseq.scoring.tokenizer import EvaluationTokenizer
8 |
9 |
10 | @register_scorer("wer")
11 | class WerScorer(BaseScorer):
12 |     def __init__(self, args):
13 |         super().__init__(args)
14 |         self.reset()
15 |         try:
16 |             import editdistance as ed
17 |         except ImportError:
18 |             raise ImportError("Please install editdistance to use WER scorer")
19 |         self.ed = ed
20 |         self.tokenizer = EvaluationTokenizer(
21 |             tokenizer_type=self.args.wer_tokenizer,
22 |             lowercase=self.args.wer_lowercase,
23 |             punctuation_removal=self.args.wer_remove_punct,
24 |             character_tokenization=self.args.wer_char_level,
25 |         )
26 |
27 |     @staticmethod
28 |     def add_args(parser):
29 |         # fmt: off
30 |         parser.add_argument('--wer-tokenizer', type=str, default='none',
31 |                             choices=EvaluationTokenizer.ALL_TOKENIZER_TYPES,
32 |                             help='sacreBLEU tokenizer to use for evaluation')
33 |         parser.add_argument('--wer-remove-punct', action='store_true',
34 |                             help='remove punctuation')
35 |         parser.add_argument('--wer-char-level', action='store_true',
36 |                             help='evaluate at character level')
37 |         parser.add_argument('--wer-lowercase', action='store_true',
38 |                             help='lowercasing')
39 |         # fmt: on
40 |
41 |     def reset(self):
42 |         self.distance = 0
43 |         self.ref_length = 0
44 |
45 |     def add_string(self, ref, pred):
46 |         ref_items = self.tokenizer.tokenize(ref).split()
47 |         pred_items = self.tokenizer.tokenize(pred).split()
48 |         self.distance += self.ed.eval(ref_items, pred_items)  # edit distance over the two token lists
49 |         self.ref_length += len(ref_items)
50 |
51 |     def result_string(self):
52 |         return f"WER: {self.score():.2f}"
53 |
54 |     def score(self):
55 |         return 100.0 * self.distance / self.ref_length if self.ref_length > 0 else 0  # int 0 (not 0.0) when no reference tokens were seen
56 |
--------------------------------------------------------------------------------
/codes/fairseq/tasks/translation_from_pretrained_xlm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary
7 | from fairseq.tasks.translation import TranslationTask
8 |
9 | from . import register_task
10 |
11 |
12 | @register_task("translation_from_pretrained_xlm")
13 | class TranslationFromPretrainedXLMTask(TranslationTask):
14 |     """
15 |     Same as TranslationTask except use the MaskedLMDictionary class so that
16 |     we can load data that was binarized with the MaskedLMDictionary class.
17 |
18 |     This task should be used for the entire training pipeline when we want to
19 |     train an NMT model from a pretrained XLM checkpoint: binarizing NMT data,
20 |     training NMT with the pretrained XLM checkpoint, and subsequent evaluation
21 |     of that trained model.
22 |     """
23 |
24 |     @classmethod
25 |     def load_dictionary(cls, filename):
26 |         """Load the masked LM dictionary from the filename
27 |
28 |         Args:
29 |             filename (str): the filename
30 |         """
31 |         return MaskedLMDictionary.load(filename)
32 |
--------------------------------------------------------------------------------
/codes/fairseq/token_block_utils_fast.cpython-36m-darwin.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wonderseen/PCKMT/2a50a85f671ac46252b1038f62abdded367aeb46/codes/fairseq/token_block_utils_fast.cpython-36m-darwin.so
--------------------------------------------------------------------------------
/codes/fairseq/token_block_utils_fast.cpython-36m-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wonderseen/PCKMT/2a50a85f671ac46252b1038f62abdded367aeb46/codes/fairseq/token_block_utils_fast.cpython-36m-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/codes/fairseq/token_block_utils_fast.cpython-37m-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wonderseen/PCKMT/2a50a85f671ac46252b1038f62abdded367aeb46/codes/fairseq/token_block_utils_fast.cpython-37m-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/codes/fairseq/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import re
7 |
8 |
9 | SPACE_NORMALIZER = re.compile(r"\s+")
10 |
11 |
12 | def tokenize_line(line):
13 |     line = SPACE_NORMALIZER.sub(" ", line)
14 |     line = line.strip()
15 |     return line.split()
16 |
--------------------------------------------------------------------------------
/codes/fairseq_cli/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wonderseen/PCKMT/2a50a85f671ac46252b1038f62abdded367aeb46/codes/fairseq_cli/__init__.py
--------------------------------------------------------------------------------
/codes/test_adaptive_knn_mt_knn_align.sh:
--------------------------------------------------------------------------------
1 |
2 |
3 | DOMAIN=it # it medical koran law subtitles
4 | k=4
5 | loss=1.82
6 | GPU=0
7 | BATCH_SIZE=64
8 | postfix=${postfix:-}  # optional suffix for the datastore/model paths below; may be set from the environment
9 |
10 | # datastore
11 | declare -A DSTORE_SIZES_dict
12 | DSTORE_SIZES_dict=([law]="19061383" [subtitles]="153604142" [it]="3602863" [medical]="6903142" [koran]="524375")
13 | DSTORE_SIZE=${DSTORE_SIZES_dict[$DOMAIN]}
14 | DATASTORE_PATH=save_datastore/${DOMAIN}/knn_transfered$postfix
15 |
16 | # USE_GPU_dict must be associative: in a plain array every [domain] key evaluates to index 0, so the last value wins for all domains.
17 | # Our codes support multi-gpu faiss, but if your single-/multi- gpu memories in total are less than 20GB,
18 | # please set USE_GPU=False for datastores (>20M) to build the cpu-mode faiss instance to avoid OOM.
19 | declare -A USE_GPU_dict=([it]="True" [medical]="True" [koran]="True" [law]="True" [merge]="True" [subtitles]="True")
20 | USE_GPU=${USE_GPU_dict[$DOMAIN]}
21 |
22 |
23 | # model
24 | MODEL_PATH_SUFFIX=model_record_path/${DOMAIN}/train-hid32-maxk$k${postfix}
25 | MODEL_PATH=${MODEL_PATH_SUFFIX}/checkpoint.best_loss_${loss}.pt
26 |
27 |
28 | # corpus
29 | DATA_PATH=data-bin/${DOMAIN}
30 |
31 |
32 | # translation output
33 | OUTPUT_PATH=${MODEL_PATH_SUFFIX}/results
34 | mkdir -p "$OUTPUT_PATH"
35 |
36 | PROJECT_PATH=.
37 | CUDA_VISIBLE_DEVICES=$GPU python $PROJECT_PATH/experimental_knn_align.py $DATA_PATH \
38 |   --gen-subset test \
39 |   --path "$MODEL_PATH" --arch transformer_knn_de_en_transfered_by_distance \
40 |   --beam 4 --lenpen 0.6 --max-len-a 1.2 --max-len-b 10 --source-lang de --target-lang en \
41 |   --scoring sacrebleu \
42 |   --batch-size $BATCH_SIZE \
43 |   --tokenizer moses --remove-bpe --not-train-knn-compact-projection \
44 |   --model-overrides "{
45 |   'batch_size': $BATCH_SIZE,
46 |   'load_knn_datastore': True,
47 |   'use_knn_datastore': True,
48 |   'dstore_filename': '$DATASTORE_PATH',
49 |   'dstore_size': $DSTORE_SIZE,
50 |   'dstore_fp16': True, 'probe': 32,
51 |   'knn_sim_func': 'do_not_recomp_l2',
52 |   'use_gpu_to_search': ${USE_GPU},
53 |   'move_dstore_to_mem': True, 'no_load_keys': True,
54 |   'knn_temperature_type': 'fix',
55 |   'knn_temperature_value': 10,
56 |   'k_lambda_net_dropout_rate': 0.0,
57 |   'only_train_knn_parameter': False
58 |   }" \
59 |   | tee "$OUTPUT_PATH"/generate.txt
60 |
61 | grep ^S "$OUTPUT_PATH"/generate.txt | cut -f2- > "$OUTPUT_PATH"/src
62 | grep ^T "$OUTPUT_PATH"/generate.txt | cut -f2- > "$OUTPUT_PATH"/ref
63 | grep ^H "$OUTPUT_PATH"/generate.txt | cut -f3- > "$OUTPUT_PATH"/hyp
64 | grep ^D "$OUTPUT_PATH"/generate.txt | cut -f3- > "$OUTPUT_PATH"/hyp.detok
65 |
--------------------------------------------------------------------------------
/codes/test_adaptive_knn_mt_knn_align_pruned.sh:
--------------------------------------------------------------------------------
1 |
2 |
3 | DOMAIN=it # it medical koran law subtitles
4 | DATASTORE_DOMAIN=${DOMAIN}
5 | k=4
6 | loss=1.84
7 | GPU=0
8 | BATCH_SIZE=64
9 | PRUNE_METHOD=/prune_similar_ppl-0.9  # NOTE(review): not referenced by the rest of this script
10 |
11 | declare -A DSTORE_SIZES_dict
12 | DSTORE_SIZES_dict=([subtitles]="" [medical]="" [koran]="" [it]="" [law]="" [merge]="" )
13 | MODEL_PATH_SUFFIX=model_record_path/...
14 | DATASTORE_PATH=save_datastore/...
15 | DSTORE_SIZE=${DSTORE_SIZES_dict[${DATASTORE_DOMAIN}]:?set the pruned datastore size for $DATASTORE_DOMAIN in DSTORE_SIZES_dict}
16 | MODEL_PATH=${MODEL_PATH_SUFFIX}/checkpoint.best_loss_${loss}.pt
17 | OUTPUT_PATH=${MODEL_PATH_SUFFIX}/results
18 | PROJECT_PATH=.
19 | DATA_PATH=data-bin/${DOMAIN}
20 | 
21 | mkdir -p "$OUTPUT_PATH"
22 | 
23 | CUDA_VISIBLE_DEVICES=$GPU python "$PROJECT_PATH"/experimental_knn_align.py "$DATA_PATH" \
24 |     --gen-subset test \
25 |     --path "$MODEL_PATH" --arch transformer_knn_de_en_transfered_by_distance \
26 |     --beam 4 --lenpen 0.6 --max-len-a 1.2 \
27 |     --max-len-b 10 --source-lang de --target-lang en \
28 |     --scoring sacrebleu \
29 |     --batch-size "$BATCH_SIZE" \
30 |     --tokenizer moses --remove-bpe --not-train-knn-compact-projection \
31 |     --model-overrides "{
32 |     'batch_size': $BATCH_SIZE,
33 |     'load_knn_datastore': True,
34 |     'use_knn_datastore': True,
35 |     'dstore_filename': '$DATASTORE_PATH',
36 |     'dstore_size': $DSTORE_SIZE,
37 |     'dstore_fp16': True,
38 |     'probe':32,
39 |     'knn_sim_func': 'do_not_recomp_l2',
40 |     'use_gpu_to_search': True,
41 |     'move_dstore_to_mem': True, 'no_load_keys': True,
42 |     'knn_temperature_type': 'fix',
43 |     'knn_temperature_value': 10,
44 |     'k_lambda_net_dropout_rate': 0.0,
45 |     'only_train_knn_parameter': False
46 |     }" \
47 |     | tee "$OUTPUT_PATH"/generate.txt
48 | # Split the generation log by line prefix into source, reference, hypothesis and detokenized hypothesis files.
49 | grep '^S' "$OUTPUT_PATH"/generate.txt | cut -f2- > "$OUTPUT_PATH"/src
50 | grep '^T' "$OUTPUT_PATH"/generate.txt | cut -f2- > "$OUTPUT_PATH"/ref
51 | grep '^H' "$OUTPUT_PATH"/generate.txt | cut -f3- > "$OUTPUT_PATH"/hyp
52 | grep '^D' "$OUTPUT_PATH"/generate.txt | cut -f3- > "$OUTPUT_PATH"/hyp.detok
53 | 
--------------------------------------------------------------------------------
/codes/train_faiss_knn_align_pruned.sh:
--------------------------------------------------------------------------------
1 | DOMAINS=(subtitles law koran medical it)
2 | DSTORE_SIZES=(...) # pruned datastore sizes, one per entry of DOMAINS and in the same order
3 | 
4 | BASE_DATASTORE_PATH=save_datastore
5 | PROJECT_PATH=.
6 | 7 | max_k_grid=(4 4 4 4) 8 | batch_size_grid=(16 16 16 16 16) # according to your device memory 9 | update_freq_grid=(1 1 1 1) 10 | valid_batch_size_grid=(16 16 16 16 16) 11 | gpu_ids=(0) 12 | 13 | USE_GPU_dict=([it]="True" [medical]="True" [koran]="True" [law]="True" [merge]="True" [subtitles]="False") 14 | 15 | for idx in ${!gpu_ids[*]} 16 | do 17 | 18 | USE_GPU=${USE_GPU_dict[${DOMAINS[$idx]}]} 19 | 20 | DATASTORE_PATH=${BASE_DATASTORE_PATH}/... 21 | MODEL_PATH=model_record_path/${DOMAINS[$idx]}/... 22 | 23 | DATA_PATH=data-bin/${DOMAINS[$idx]} 24 | MODEL_RECORD_PATH=model_record_path/... 25 | TRAINING_RECORD_PATH=model_record_tensorboard_path/... 26 | 27 | 28 | rm -rf "$MODEL_RECORD_PATH" 29 | rm -rf "$TRAINING_RECORD_PATH" 30 | mkdir -p "$MODEL_RECORD_PATH" 31 | mkdir -p "$TRAINING_RECORD_PATH" 32 | 33 | CUDA_VISIBLE_DEVICES=${gpu_ids[$idx]} python \ 34 | $PROJECT_PATH/fairseq_cli/train.py \ 35 | $DATA_PATH \ 36 | --use-gpu-to-search \ 37 | --log-interval 100 --log-format simple \ 38 | --arch transformer_knn_de_en_transfered_by_distance \ 39 | --tensorboard-logdir "$TRAINING_RECORD_PATH" \ 40 | --save-dir "$MODEL_RECORD_PATH" --restore-file "$MODEL_PATH" \ 41 | --reset-dataloader --reset-lr-scheduler --reset-meters --reset-optimizer \ 42 | --validate-interval-updates 100 --save-interval-updates 100 --keep-interval-updates 1 \ 43 | --max-update 50000 --validate-after-updates 1000 \ 44 | --save-interval 2000 --validate-interval 100 \ 45 | --keep-best-checkpoints 1 --no-epoch-checkpoints --no-last-checkpoints --no-save-optimizer-state \ 46 | --train-subset valid --valid-subset valid --source-lang de --target-lang en \ 47 | --criterion label_smoothed_cross_entropy --label-smoothing 0.001 \ 48 | --max-source-positions 1024 --max-target-positions 1024 \ 49 | --batch-size "${batch_size_grid[$idx]}" --update-freq "${update_freq_grid[$idx]}" \ 50 | --batch-size-valid "${valid_batch_size_grid[$idx]}" \ 51 | --task translation \ 52 | --optimizer adam --adam-betas 
"(0.9, 0.98)" --adam-eps 1e-08 \ 53 | --min-lr 3e-05 --lr 0.0003 --clip-norm 1.0 \ 54 | --lr-scheduler reduce_lr_on_plateau --lr-patience 5 --lr-shrink 0.5 \ 55 | --patience 300 \ 56 | --max-epoch 5000 \ 57 | --load-knn-datastore --dstore-filename $DATASTORE_PATH --use-knn-datastore \ 58 | --dstore-size "${DSTORE_SIZES[$idx]}" --probe 32 \ 59 | --knn-sim-func do_not_recomp_l2 --no-load-keys \ 60 | --move-dstore-to-mem \ 61 | --knn-lambda-type trainable --knn-temperature-type fix \ 62 | --only-train-knn-parameter --knn-k-type trainable \ 63 | --k-lambda-net-hid-size 32 --k-lambda-net-dropout-rate 0.0 \ 64 | --max-k "${max_k_grid[$idx]}" --k "${max_k_grid[$idx]}" \ 65 | --label-count-as-feature --not-train-knn-compact-projection \ 66 | --dstore-fp16 \ 67 | --knn-temperature-value 10 68 | done 69 | --------------------------------------------------------------------------------