├── README.md ├── SimileEMNLP.csv ├── autoeval.py ├── comet-commonsense ├── README.md ├── config │ ├── atomic │ │ ├── changes.json │ │ ├── default.json │ │ └── eval_changes.json │ ├── conceptnet │ │ ├── changes.json │ │ ├── default.json │ │ └── eval_changes.json │ └── default.json ├── directory.md ├── parameters_names.json ├── scripts │ ├── classify │ │ ├── classify.sh │ │ ├── classify_conceptnet_generations.py │ │ ├── convert_conceptnet_generations_to_text.py │ │ └── demo_bilinear.py │ ├── data │ │ ├── make_atomic_data_loader.py │ │ └── make_conceptnet_data_loader.py │ ├── evaluate │ │ ├── bleu_atomic.py │ │ └── evaluate_atomic_generation_model.py │ ├── generate │ │ ├── generate_atomic_beam_search.py │ │ ├── generate_atomic_greedy.py │ │ ├── generate_atomic_topk.py │ │ ├── generate_conceptnet_arbitrary.py │ │ └── generate_conceptnet_beam_search.py │ ├── interactive │ │ ├── atomic_single_example.py │ │ └── conceptnet_single_example.py │ ├── novelty │ │ └── compute_conceptnet_novelty.py │ └── setup │ │ ├── get_atomic_data.sh │ │ ├── get_conceptnet_data.sh │ │ └── get_model_files.sh ├── src │ ├── data │ │ ├── atomic.py │ │ ├── conceptnet.py │ │ ├── config.py │ │ ├── data.py │ │ └── utils.py │ ├── evaluate │ │ ├── atomic_evaluate.py │ │ ├── conceptnet_evaluate.py │ │ ├── conceptnet_generate.py │ │ ├── evaluate.py │ │ ├── generate.py │ │ ├── sampler.py │ │ └── utils.py │ ├── interactive │ │ └── functions.py │ ├── main.py │ ├── main_atomic.py │ ├── main_conceptnet.py │ ├── models │ │ ├── gpt.py │ │ ├── models.py │ │ └── utils.py │ └── train │ │ ├── atomic_train.py │ │ ├── batch.py │ │ ├── conceptnet_train.py │ │ ├── opt.py │ │ ├── train.py │ │ └── utils.py └── utils │ └── utils.py ├── convert_to_literal.py ├── create_bpe.sh ├── dict.txt ├── encoder.json ├── fairseq.egg-info ├── PKG-INFO ├── SOURCES.txt ├── dependency_links.txt ├── entry_points.txt ├── not-zip-safe ├── requires.txt └── top_level.txt ├── fairseq ├── __init__.py ├── binarizer.py ├── bleu.py ├── checkpoint_utils.py ├── clib │ ├── libbleu │ │ ├── libbleu.cpp │ │ └── module.cpp │ ├── libnat │ │ └── edit_dist.cpp │ └── libnat_cuda │ │ ├── binding.cpp │ │ ├── edit_dist.cu │ │ └── edit_dist.h ├── criterions │ ├── __init__.py │ ├── adaptive_loss.py │ ├── binary_cross_entropy.py │ ├── composite_loss.py │ ├── cross_entropy.py │ ├── fairseq_criterion.py │ ├── label_smoothed_cross_entropy.py │ ├── label_smoothed_cross_entropy_with_alignment.py │ ├── legacy_masked_lm.py │ ├── masked_lm.py │ ├── nat_loss.py │ ├── sentence_prediction.py │ └── sentence_ranking.py ├── data │ ├── __init__.py │ ├── append_token_dataset.py │ ├── audio │ │ ├── __init__.py │ │ └── raw_audio_dataset.py │ ├── backtranslation_dataset.py │ ├── base_wrapper_dataset.py │ ├── colorize_dataset.py │ ├── concat_dataset.py │ ├── concat_sentences_dataset.py │ ├── data_utils.p │ ├── data_utils.py │ ├── data_utils_fast.cpp │ ├── data_utils_fast.cpython-37m-x86_64-linux-gnu.so │ ├── data_utils_fast.pyx │ ├── denoising_dataset.py │ ├── dictionary.py │ ├── encoders │ │ ├── __init__.py │ │ ├── fastbpe.py │ │ ├── gpt2_bpe.py │ │ ├── gpt2_bpe_utils.py │ │ ├── hf_bert_bpe.py │ │ ├── moses_tokenizer.py │ │ ├── nltk_tokenizer.py │ │ ├── sentencepiece_bpe.py │ │ ├── space_tokenizer.py │ │ ├── subword_nmt_bpe.py │ │ └── utils.py │ ├── fairseq_dataset.py │ ├── id_dataset.py │ ├── indexed_dataset.py │ ├── iterators.py │ ├── language_pair_dataset.py │ ├── legacy │ │ ├── __init__.py │ │ ├── block_pair_dataset.py │ │ ├── masked_lm_dataset.py │ │ └── masked_lm_dictionary.py │ ├── list_dataset.py │ 
├── lm_context_window_dataset.py │ ├── lru_cache_dataset.py │ ├── mask_tokens_dataset.py │ ├── monolingual_dataset.py │ ├── multi_corpus_sampled_dataset.py │ ├── nested_dictionary_dataset.py │ ├── noising.py │ ├── num_samples_dataset.py │ ├── numel_dataset.py │ ├── offset_tokens_dataset.py │ ├── pad_dataset.py │ ├── plasma_utils.py │ ├── prepend_dataset.py │ ├── prepend_token_dataset.py │ ├── raw_label_dataset.py │ ├── replace_dataset.py │ ├── resampling_dataset.py │ ├── roll_dataset.py │ ├── round_robin_zip_datasets.py │ ├── sharded_dataset.py │ ├── sort_dataset.py │ ├── strip_token_dataset.py │ ├── subsample_dataset.py │ ├── token_block_dataset.py │ ├── token_block_utils_fast.cpp │ ├── token_block_utils_fast.cpython-37m-x86_64-linux-gnu.so │ ├── token_block_utils_fast.pyx │ ├── transform_eos_dataset.py │ ├── transform_eos_lang_pair_dataset.py │ └── truncate_dataset.py ├── dict.txt ├── distributed_utils.py ├── encoder.json ├── file_io.py ├── file_utils.py ├── hub_utils.py ├── iterative_refinement_generator.py ├── legacy_distributed_data_parallel.py ├── libbleu.cpython-37m-x86_64-linux-gnu.so ├── libnat.cpython-37m-x86_64-linux-gnu.so ├── meters.py ├── metrics.py ├── models │ ├── __init__.py │ ├── bart │ │ ├── __init__.py │ │ ├── hub_interface.py │ │ ├── hub_interface1.py │ │ └── model.py │ ├── composite_encoder.py │ ├── distributed_fairseq_model.py │ ├── fairseq_decoder.py │ ├── fairseq_encoder.py │ ├── fairseq_incremental_decoder.py │ ├── fairseq_model.py │ ├── fconv.py │ ├── fconv_lm.py │ ├── fconv_self_att.py │ ├── lightconv.py │ ├── lightconv_lm.py │ ├── lstm.py │ ├── masked_lm.py │ ├── model_utils.py │ ├── multilingual_transformer.py │ ├── nat │ │ ├── __init__.py │ │ ├── cmlm_transformer.py │ │ ├── fairseq_nat_model.py │ │ ├── insertion_transformer.py │ │ ├── iterative_nonautoregressive_transformer.py │ │ ├── levenshtein_transformer.py │ │ ├── levenshtein_utils.py │ │ ├── nat_crf_transformer.py │ │ ├── nonautoregressive_ensembles.py │ │ └── nonautoregressive_transformer.py │ ├── roberta │ │ ├── __init__.py │ │ ├── alignment_utils.py │ │ ├── hub_interface.py │ │ └── model.py │ ├── transformer.py │ ├── transformer_from_pretrained_xlm.py │ ├── transformer_lm.py │ └── wav2vec.py ├── modules │ ├── __init__.py │ ├── adaptive_input.py │ ├── adaptive_softmax.py │ ├── beamable_mm.py │ ├── character_token_embedder.py │ ├── conv_tbc.py │ ├── cuda_utils.cu │ ├── downsampled_multihead_attention.py │ ├── dynamic_convolution.py │ ├── dynamic_crf_layer.py │ ├── dynamicconv_layer │ │ ├── __init__.py │ │ ├── cuda_function_gen.py │ │ ├── dynamicconv_cuda.cpp │ │ ├── dynamicconv_cuda.cuh │ │ ├── dynamicconv_cuda_kernel.cu │ │ ├── dynamicconv_layer.py │ │ ├── dynamiconv_cpu.cpp │ │ └── setup.py │ ├── gelu.py │ ├── grad_multiply.py │ ├── highway.py │ ├── layer_norm.py │ ├── learned_positional_embedding.py │ ├── lightconv_layer │ │ ├── __init__.py │ │ ├── cuda_function_gen.py │ │ ├── lightconv_cuda.cpp │ │ ├── lightconv_cuda.cuh │ │ ├── lightconv_cuda_kernel.cu │ │ ├── lightconv_layer.py │ │ └── setup.py │ ├── lightweight_convolution.py │ ├── linearized_convolution.py │ ├── logsumexp_moe.py │ ├── mean_pool_gating_network.py │ ├── multihead_attention.py │ ├── positional_embedding.py │ ├── scalar_bias.py │ ├── sinusoidal_positional_embedding.py │ ├── sparse_multihead_attention.py │ ├── sparse_transformer_sentence_encoder.py │ ├── sparse_transformer_sentence_encoder_layer.py │ ├── transformer_layer.py │ ├── transformer_sentence_encoder.py │ ├── transformer_sentence_encoder_layer.py │ ├── unfold.py │ └── 
vggblock.py ├── optim │ ├── __init__.py │ ├── adadelta.py │ ├── adafactor.py │ ├── adagrad.py │ ├── adam.py │ ├── adamax.py │ ├── bmuf.py │ ├── fairseq_optimizer.py │ ├── fp16_optimizer.py │ ├── fused_adam.py │ ├── fused_lamb.py │ ├── lr_scheduler │ │ ├── __init__.py │ │ ├── cosine_lr_scheduler.py │ │ ├── fairseq_lr_scheduler.py │ │ ├── fixed_schedule.py │ │ ├── inverse_square_root_schedule.py │ │ ├── polynomial_decay_schedule.py │ │ ├── reduce_lr_on_plateau.py │ │ ├── tri_stage_lr_scheduler.py │ │ └── triangular_lr_scheduler.py │ ├── nag.py │ └── sgd.py ├── options.py ├── pdb.py ├── progress_bar.py ├── registry.py ├── search.py ├── sequence_generator.py ├── sequence_scorer.py ├── tasks │ ├── __init__.py │ ├── audio_pretraining.py │ ├── cross_lingual_lm.py │ ├── denoising.py │ ├── fairseq_task.py │ ├── language_modeling.py │ ├── legacy_masked_lm.py │ ├── masked_lm.py │ ├── multilingual_masked_lm.py │ ├── multilingual_translation.py │ ├── semisupervised_translation.py │ ├── sentence_prediction.py │ ├── sentence_ranking.py │ ├── translation.py │ ├── translation_from_pretrained_xlm.py │ ├── translation_lev.py │ └── translation_moe.py ├── tokenizer.py ├── trainer.py └── utils1.py ├── fairseq_cli ├── __init__.py ├── generate.py ├── interactive.py ├── preprocess.py └── setup.py ├── finetune.sh ├── gen1.source ├── gen2.source ├── generate.py ├── hubconf.py ├── human_labels.csv ├── interactive.py ├── preprocess.py ├── preprocess.sh ├── scrape_reddit_for_similes.py ├── scripts ├── __init__.py ├── average_checkpoints.py ├── build_sym_alignment.py ├── compare_namespaces.py ├── compound_split_bleu.sh ├── convert_dictionary.lua ├── convert_model.lua ├── count_docs.py ├── read_binarized.py ├── rm_pt.py ├── sacrebleu_pregen.sh ├── shard_docs.py ├── split_train_valid_docs.py ├── spm_decode.py ├── spm_encode.py ├── spm_train.py ├── wav2vec_featurize.py └── wav2vec_manifest.py ├── setup.py ├── tests ├── __init__.py ├── speech_recognition │ ├── __init__.py │ ├── asr_test_base.py │ ├── test_collaters.py │ ├── test_cross_entropy.py │ └── test_vggtransformer.py ├── test_average_checkpoints.py ├── test_backtranslation_dataset.py ├── test_binaries.py ├── test_bmuf.py ├── test_character_token_embedder.py ├── test_concat_dataset.py ├── test_convtbc.py ├── test_dictionary.py ├── test_file_io.py ├── test_iterators.py ├── test_label_smoothing.py ├── test_memory_efficient_fp16.py ├── test_multi_corpus_sampled_dataset.py ├── test_multihead_attention.py ├── test_noising.py ├── test_reproducibility.py ├── test_resampling_dataset.py ├── test_sequence_generator.py ├── test_sequence_scorer.py ├── test_sparse_multihead_attention.py ├── test_token_block_dataset.py ├── test_train.py ├── test_utils.py └── utils.py ├── train.py └── vocab.bpe /autoeval.py: -------------------------------------------------------------------------------- 1 | import csv 2 | from nltk.translate.bleu_score import corpus_bleu 3 | from bert_score import score 4 | import os 5 | os.environ['CUDA_VISIBLE_DEVICES']='1' 6 | 7 | bleu1 = 0.0 8 | bs = 0.0 9 | r = [] 10 | c = [] 11 | with open('./human_labels.csv') as csv_file: 12 | csv_reader = csv.reader(csv_file, delimiter=',') 13 | line_count = 0 14 | for row in csv_reader: 15 | if line_count == 0: 16 | print(f'Column names are {", ".join(row)}') 17 | line_count += 1 18 | else: 19 | reference = [row[0],row[1]] 20 | candidate = [row[2]] 21 | r.append([row[0].split(),row[1].split()]) 22 | c.append([row[2].split()]) 23 | P_mul, R_mul, F_mul = score(candidate, reference, lang="en", rescale_with_baseline=True) 24 
| F_mul = F_mul.tolist()[0] 25 | bs = bs+F_mul 26 | print("BLEU1",corpus_bleu(r, c,weights=(1, 0, 0, 0))*100) 27 | print("BLEU2",corpus_bleu(r, c,weights=(0, 1, 0, 0))*100) 28 | print("BERTSCORE",float(bs)/150.0) 29 | -------------------------------------------------------------------------------- /comet-commonsense/config/atomic/changes.json: -------------------------------------------------------------------------------- 1 | { 2 | "base": { 3 | "0": { 4 | "gpu_index": 0 5 | }, 6 | "1": { 7 | "gpu_index": 1 8 | }, 9 | "2": { 10 | "gpu_index": 2 11 | }, 12 | "3": { 13 | "gpu_index": 3 14 | } 15 | } 16 | } -------------------------------------------------------------------------------- /comet-commonsense/config/atomic/default.json: -------------------------------------------------------------------------------- 1 | { 2 | 3 | "dataset": "atomic", 4 | "categories": ["oReact", "oEffect", "oWant", "xAttr", "xEffect", "xIntent", "xNeed", "xReact", "xWant"], 5 | "eval_categories": ["oReact", "oEffect", "oWant", "xAttr", "xEffect", "xIntent", "xNeed", "xReact", "xWant"], 6 | "exp": "generation", 7 | "labels": "individual", 8 | "encoder_path": "model/encoder_bpe_40000.json", 9 | "bpe_path": "model/vocab_40000.bpe", 10 | "batch_size": 64, 11 | "learning_rate_schedule": "warmup_linear", 12 | "learning_rate_warmup": 0.002, 13 | "l2": 0.01, 14 | "vector_l2": "T", 15 | "evaluate_sequences": 10000 16 | } -------------------------------------------------------------------------------- /comet-commonsense/config/atomic/eval_changes.json: -------------------------------------------------------------------------------- 1 | { 2 | 3 | "base": { 4 | "0": { 5 | "gpu_index": 0, 6 | "generate_sequences": "full", 7 | "evaluate_sequences": "full" 8 | }, 9 | "1": { 10 | "gpu_index": 1, 11 | "generate_sequences": "full", 12 | "evaluate_sequences": "full" 13 | }, 14 | "2": { 15 | "gpu_index": 2, 16 | "generate_sequences": "full", 17 | "evaluate_sequences": "full" 18 | }, 19 | "3": { 20 | "gpu_index": 3, 21 | "generate_sequences": "full", 22 | "evaluate_sequences": "full" 23 | } 24 | } 25 | } -------------------------------------------------------------------------------- /comet-commonsense/config/conceptnet/changes.json: -------------------------------------------------------------------------------- 1 | { 2 | "base": { 3 | "0": { 4 | "gpu_index": 0 5 | }, 6 | "1": { 7 | "gpu_index": 1 8 | }, 9 | "2": { 10 | "gpu_index": 2 11 | }, 12 | "3": { 13 | "gpu_index": 3 14 | } 15 | } 16 | } -------------------------------------------------------------------------------- /comet-commonsense/config/conceptnet/default.json: -------------------------------------------------------------------------------- 1 | { 2 | 3 | "dataset": "conceptnet", 4 | "exp": "generation", 5 | "do_gen": "T", 6 | "encoder_path": "model/encoder_bpe_40000.json", 7 | "bpe_path": "model/vocab_40000.bpe", 8 | "batch_size": 64, 9 | "learning_rate_schedule": "warmup_linear", 10 | "learning_rate_warmup": 0.002, 11 | "l2": 0.01, 12 | "vector_l2": "T", 13 | "generate_sequences": "full", 14 | "evaluate_sequences": "full", 15 | "relation_format": "language", 16 | "training_set_size": 100, 17 | "development_set_versions_to_use": "12", 18 | "max_event_1_size": 10, 19 | "max_event_2_size": 15, 20 | "eval_sampler": "greedy", 21 | "iterations": 100000, 22 | "learning_rate": 1e-5 23 | } -------------------------------------------------------------------------------- /comet-commonsense/config/conceptnet/eval_changes.json: 
-------------------------------------------------------------------------------- 1 | { 2 | 3 | "base": { 4 | "0": { 5 | "gpu_index": 0 6 | }, 7 | "1": { 8 | "gpu_index": 1 9 | }, 10 | "2": { 11 | "gpu_index": 2 12 | }, 13 | "3": { 14 | "gpu_index": 3 15 | } 16 | } 17 | } -------------------------------------------------------------------------------- /comet-commonsense/config/default.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu_mode": "T", 3 | "gpu_index": 0, 4 | "gpu_indices": [0, 1], 5 | "multigpu": "F", 6 | 7 | "topk_size": 10, 8 | "beam_size": 1, 9 | "gen_seqlength": 40, 10 | "eval_sampler": "greedy", 11 | "num_sequences": 1, 12 | "generate_sequences": 1000, 13 | "evaluate_sequences": 10000, 14 | 15 | "random_seed": 123, 16 | "optimizer": "adam", 17 | "batch_size": 64, 18 | "learning_rate": 6.25e-5, 19 | 20 | "clip": 1, 21 | "loss": "nll", 22 | "weight_decay": 0, 23 | 24 | "adam": { 25 | "b2": 0.999, 26 | "b1": 0.9, 27 | "e": 1e-8 28 | }, 29 | 30 | "model": "transformer", 31 | "pretrain": "gpt", 32 | "hidden_dim": 768, 33 | "num_layers": 12, 34 | "num_heads": 12, 35 | "embedding_dropout": 0.1, 36 | "attention_dropout": 0.1, 37 | "residual_dropout": 0.1, 38 | "output_dropout": 0.1, 39 | "activation": "gelu", 40 | "init": "pt", 41 | 42 | "trainer": "iteration", 43 | 44 | "iterations": 50000, 45 | "cycle": 500, 46 | 47 | "save_strategy": "best", 48 | 49 | "epochs": 20, 50 | "toy": "F", 51 | "do_gen": "F", 52 | "save": "T", 53 | "test_save": "F" 54 | } -------------------------------------------------------------------------------- /comet-commonsense/scripts/classify/classify.sh: -------------------------------------------------------------------------------- 1 | gens=${1-} 2 | 3 | python scripts/classify/convert_conceptnet_generations_to_text.py --gens_file ${gens}.pickle 4 | python2.7 scripts/classify/classify_conceptnet_generations.py --gens_name ${gens}.txt -------------------------------------------------------------------------------- /comet-commonsense/scripts/classify/classify_conceptnet_generations.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import demo_bilinear 5 | 6 | train_file = "data/conceptnet/train100k.txt.gz" 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--gens_name", type=str, default="results/gens/conceptnet-generation/iteration-500-100000/transformer/rel_language-trainsize_100-devversion_12-maxe1_10-maxe2_15/model_transformer-nL_12-nH_12-hSize_768-edpt_0.1-adpt_0.1-rdpt_0.1-odpt_0.1-pt_gpt-afn_gelu-init_pt-vSize_40545/exp_generation-seed_123-l2_0.01-vl2_T-lrsched_warmup_linear-lrwarm_0.002-clip_1-loss_nll-b2_0.999-b1_0.9-e_1e-08/bs_1-smax_40-sample_greedy-numseq_1-gs_full-es_full/1e-05_adam_64_15500/test.txt") 10 | parser.add_argument("--thresh", type=float, default=0.5) 11 | 12 | args = parser.parse_args() 13 | 14 | # print(gens_file[0]) 15 | results = demo_bilinear.run(args.gens_name, flip_r_e1=False) 16 | new_results = {"0": [j for (i, j) in results if i[3] == "0"], 17 | "1": [j for (i, j) in results if i[3] == "1"]} 18 | 19 | print("Total") 20 | num_examples = 1.0 * len(results) 21 | accuracy = (len([i for i in new_results["1"] if i >= args.thresh]) + 22 | len([i for i in new_results["0"] if i < args.thresh])) / num_examples 23 | print("Accuracy @ {}: {}".format(args.thresh, accuracy)) 24 | 25 | -------------------------------------------------------------------------------- 
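The accuracy reported by classify_conceptnet_generations.py above thresholds the bilinear classifier's score for each generated tuple and checks it against the gold 0/1 label. A minimal standalone sketch of that computation, using hypothetical (label, score) pairs rather than the actual output of demo_bilinear.run():

# Hypothetical (gold_label, classifier_score) pairs, for illustration only.
scored = [("1", 0.91), ("1", 0.42), ("0", 0.18), ("0", 0.77)]
thresh = 0.5
# A tuple counts as correct if a positive example scores at or above the
# threshold, or a negative example scores below it.
correct = sum(1 for label, score in scored
              if (label == "1" and score >= thresh) or (label == "0" and score < thresh))
print("Accuracy @ {}: {}".format(thresh, correct / len(scored)))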
/comet-commonsense/scripts/classify/convert_conceptnet_generations_to_text.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import argparse 4 | 5 | sys.path.append(os.getcwd()) 6 | 7 | import pickle 8 | 9 | import torch 10 | 11 | 12 | combine_into_words = { 13 | 'at location': 'AtLocation', 14 | 'capable of': 'CapableOf', 15 | 'causes': 'Causes', 16 | 'causes desire': 'CausesDesire', 17 | 'created by': 'CreatedBy', 18 | 'defined as': 'DefinedAs', 19 | 'desire of': 'DesireOf', 20 | 'desires': 'Desires', 21 | 'has a': 'HasA', 22 | 'has first subevent': 'HasFirstSubevent', 23 | 'has last subevent': 'HasLastSubevent', 24 | 'has pain character': 'HasPainCharacter', 25 | 'has pain intensity': 'HasPainIntensity', 26 | 'has prequisite': 'HasPrerequisite', 27 | 'has property': 'HasProperty', 28 | 'has subevent': 'HasSubevent', 29 | 'inherits from': 'InheritsFrom', 30 | 'instance of': 'InstanceOf', 31 | 'is a': 'IsA', 32 | 'located near': 'LocatedNear', 33 | 'location of action': 'LocationOfAction', 34 | 'made of': 'MadeOf', 35 | 'motivated by goal': 'MotivatedByGoal', 36 | 'not capable of': 'NotCapableOf', 37 | 'not desires': 'NotDesires', 38 | 'not has a': 'NotHasA', 39 | 'not has property': 'NotHasProperty', 40 | 'not is a': 'NotIsA', 41 | 'not made of': 'NotMadeOf', 42 | 'part of': 'PartOf', 43 | 'receives action': 'ReceivesAction', 44 | 'related to': 'RelatedTo', 45 | 'symbol of': 'SymbolOf', 46 | 'used for': 'UsedFor' 47 | } 48 | 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument("--gens_file", type=str, default="results/gens/conceptnet-generation/iteration-500-100000/transformer/rel_language-trainsize_100-devversion_12-maxe1_10-maxe2_15/model_transformer-nL_12-nH_12-hSize_768-edpt_0.1-adpt_0.1-rdpt_0.1-odpt_0.1-pt_gpt-afn_gelu-init_pt-vSize_40545/exp_generation-seed_123-l2_0.01-vl2_T-lrsched_warmup_linear-lrwarm_0.002-clip_1-loss_nll-b2_0.999-b1_0.9-e_1e-08/bs_1-smax_40-sample_greedy-numseq_1-gs_full-es_full/1e-05_adam_64_15500/test.pickle") 51 | 52 | args = parser.parse_args() 53 | 54 | gens = pickle.load(open(args.gens_file, "rb")) 55 | 56 | final_sequences = [] 57 | 58 | do_beams = False 59 | 60 | for idx, gen in enumerate(gens): 61 | e1 = gen["e1"].strip() 62 | r = gen["r"] 63 | 64 | if "rel_language" in args.gens_file or r.split(" ")[0] != r: 65 | r = combine_into_words[r.strip()] 66 | else: 67 | r = r.strip("<>") 68 | 69 | if "sequence" in gen: 70 | sequences = [gen['sequence']] 71 | else: 72 | sequences = gen['beams'] 73 | 74 | for seq in sequences: 75 | final_sequences.append("{}\t{}\t{}\t1".format(r, e1, seq)) 76 | 77 | final_sequences.append("") 78 | 79 | print("Saving to: {}".format(args.gens_file.replace("pickle", "txt"))) 80 | 81 | open(args.gens_file.replace("pickle", "txt"), "w").write("\n".join(final_sequences)) 82 | -------------------------------------------------------------------------------- /comet-commonsense/scripts/data/make_atomic_data_loader.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.getcwd()) 5 | 6 | import src.data.data as data 7 | from utils.utils import DD 8 | import utils.utils as utils 9 | import random 10 | from src.data.utils import TextEncoder 11 | from tqdm import tqdm 12 | import torch 13 | 14 | # Manually change the set of categories you don't want to include 15 | # if you want to be able to train on a separate set of categories 16 | categories = [] 17 | categories += 
["oEffect"] 18 | categories += ["oReact"] 19 | categories += ["oWant"] 20 | categories += ["xAttr"] 21 | categories += ["xEffect"] 22 | categories += ["xIntent"] 23 | categories += ["xNeed"] 24 | categories += ["xReact"] 25 | categories += ["xWant"] 26 | 27 | 28 | opt = DD() 29 | opt.dataset = "atomic" 30 | opt.exp = "generation" 31 | opt.data = DD() 32 | opt.data.categories = sorted(categories) 33 | 34 | encoder_path = "model/encoder_bpe_40000.json" 35 | bpe_path = "model/vocab_40000.bpe" 36 | 37 | text_encoder = TextEncoder(encoder_path, bpe_path) 38 | 39 | encoder = text_encoder.encoder 40 | n_vocab = len(text_encoder.encoder) 41 | 42 | special = [data.start_token, data.end_token] 43 | special += ["<{}>".format(cat) for cat in categories] 44 | special += [data.blank_token] 45 | 46 | for special_token in special: 47 | text_encoder.decoder[len(encoder)] = special_token 48 | encoder[special_token] = len(encoder) 49 | 50 | save_path = "data/atomic/processed/{}".format(opt.exp) 51 | utils.mkpath(save_path) 52 | 53 | save_name = os.path.join( 54 | save_path, "{}.pickle".format(utils.make_name_string(opt.data))) 55 | 56 | data_loader = data.make_data_loader(opt, categories) 57 | data_loader.load_data("data/atomic/") 58 | random.shuffle(data_loader.data["dev"]["total"]) 59 | 60 | data_loader.make_tensors(text_encoder, special, test=False) 61 | data_loader.reset_offsets() 62 | 63 | 64 | opt.data.maxe1 = data_loader.max_event 65 | opt.data.maxe2 = data_loader.max_effect 66 | opt.data.maxr = 1 67 | 68 | save_name = os.path.join( 69 | save_path, "{}.pickle".format(utils.make_name_string(opt.data))) 70 | 71 | print("Data Loader will be saved at: {}".format(save_name)) 72 | 73 | torch.save(data_loader, save_name) 74 | -------------------------------------------------------------------------------- /comet-commonsense/scripts/data/make_conceptnet_data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.getcwd()) 5 | 6 | import torch 7 | import src.data.conceptnet as cdata 8 | import src.data.data as data 9 | 10 | from utils.utils import DD 11 | import utils.utils as utils 12 | import random 13 | from src.data.utils import TextEncoder 14 | from tqdm import tqdm 15 | 16 | opt = DD() 17 | opt.dataset = "conceptnet" 18 | opt.exp = "generation" 19 | 20 | opt.data = DD() 21 | 22 | # Use relation embeddings rather than 23 | # splitting relations into its component words 24 | # Set to "language" for using component words 25 | # Set to "relation" to use unlearned relation embeddings 26 | opt.data.rel = "language" 27 | 28 | # Use 100k training set 29 | opt.data.trainsize = 100 30 | 31 | # Use both dev sets (v1 an v2) 32 | opt.data.devversion = "12" 33 | 34 | # Maximum token length of e1 35 | opt.data.maxe1 = 10 36 | 37 | # Maximum token length of e2 38 | opt.data.maxe2 = 15 39 | 40 | relations = [ 41 | 'AtLocation', 'CapableOf', 'Causes', 'CausesDesire', 'CreatedBy', 42 | 'DefinedAs', 'DesireOf', 'Desires', 'HasA', 'HasFirstSubevent', 43 | 'HasLastSubevent', 'HasPainCharacter', 'HasPainIntensity', 44 | 'HasPrerequisite', 'HasProperty', 'HasSubevent', 'InheritsFrom', 45 | 'InstanceOf', 'IsA', 'LocatedNear', 'LocationOfAction', 'MadeOf', 46 | 'MotivatedByGoal', 'NotCapableOf', 'NotDesires', 'NotHasA', 47 | 'NotHasProperty', 'NotIsA', 'NotMadeOf', 'PartOf', 'ReceivesAction', 48 | 'RelatedTo', 'SymbolOf', 'UsedFor' 49 | ] 50 | 51 | special = [data.start_token, data.end_token] 52 | special += ["<{}>".format(relation) for 
relation in relations] 53 | 54 | encoder_path = "model/encoder_bpe_40000.json" 55 | bpe_path = "model/vocab_40000.bpe" 56 | 57 | text_encoder = TextEncoder(encoder_path, bpe_path) 58 | 59 | for special_token in special: 60 | text_encoder.decoder[len(text_encoder.encoder)] = special_token 61 | text_encoder.encoder[special_token] = len(text_encoder.encoder) 62 | 63 | data_loader = cdata.GenerationDataLoader(opt) 64 | data_loader.load_data("data/conceptnet/") 65 | 66 | data_loader.make_tensors(text_encoder, special, test=False) 67 | 68 | opt.data.maxr = data_loader.max_r 69 | 70 | save_path = "data/conceptnet/processed/generation" 71 | save_name = os.path.join(save_path, "{}.pickle".format( 72 | utils.make_name_string(opt.data))) 73 | 74 | utils.mkpath(save_path) 75 | 76 | print("Data Loader will be saved to {}".format(save_name)) 77 | 78 | torch.save(data_loader, save_name) 79 | -------------------------------------------------------------------------------- /comet-commonsense/scripts/novelty/compute_conceptnet_novelty.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pickle 4 | import argparse 5 | 6 | sys.path.append(os.getcwd()) 7 | 8 | import src.data.conceptnet as cdata 9 | 10 | combine_into_words = {j:i for i, j in cdata.split_into_words.items()} 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--gens_name", type=str, default="results/gens/conceptnet-generation/iteration-500-100000/transformer/rel_language-trainsize_100-devversion_12-maxe1_10-maxe2_15/model_transformer-nL_12-nH_12-hSize_768-edpt_0.1-adpt_0.1-rdpt_0.1-odpt_0.1-pt_gpt-afn_gelu-init_pt-vSize_40545/exp_generation-seed_123-l2_0.01-vl2_T-lrsched_warmup_linear-lrwarm_0.002-clip_1-loss_nll-b2_0.999-b1_0.9-e_1e-08/bs_1-smax_40-sample_greedy-numseq_1-gs_full-es_full/1e-05_adam_64_15500/test.txt") 14 | parser.add_argument("--training_set_file", type=str, default="data/concepnet/train100k.txt") 15 | 16 | args = parser.parse_args() 17 | 18 | gens = pickle.load(open(args.gens_name, "rb")) 19 | training_gens = [i.split("\t")[:3] for i in open(args.training_set_file, "r").read().split("\n") if i] 20 | 21 | evaluation_rels = [] 22 | evaluation_e2s = [] 23 | 24 | do_beams = len(gens[0]['beams']) > 1 25 | 26 | for idx, example in enumerate(gens): 27 | e1 = example["e1"] 28 | 29 | if example["r"].strip("<>") not in cdata.split_into_words: 30 | r = combine_into_words[example["r"]] 31 | else: 32 | r = example["r"].strip("<>") 33 | # 34 | if do_beams: 35 | examples_sequences = example["beams"] 36 | else: 37 | examples_sequences = [example['sequence']] 38 | 39 | for seq in examples_sequences: 40 | evaluation_rels.append((r.strip(), e1.strip(), seq.strip())) 41 | evaluation_e2s.append(seq.strip()) 42 | 43 | train_rels = set([tuple([j.strip() for j in i]) for i in training_gens]) 44 | train_e2s = set([i[2].strip() for i in training_gens]) 45 | 46 | print("% new o: {}".format(len([i for i in evaluation_e2s if i not in train_e2s]) / len(evaluation_rels))) 47 | print("% new sro: {}".format(len([i for i in evaluation_rels if i not in train_rels]) / len(evaluation_rels))) 48 | -------------------------------------------------------------------------------- /comet-commonsense/scripts/setup/get_atomic_data.sh: -------------------------------------------------------------------------------- 1 | wget https://homes.cs.washington.edu/~msap/atomic/data/atomic_data.tgz 2 | mkdir -p data/atomic 3 | mv atomic_data.tgz data/atomic 4 | 5 | tar -xvzf 
data/atomic/atomic_data.tgz -C data/atomic 6 | rm data/atomic/atomic_data.tgz 7 | -------------------------------------------------------------------------------- /comet-commonsense/scripts/setup/get_conceptnet_data.sh: -------------------------------------------------------------------------------- 1 | mkdir data/conceptnet 2 | 3 | cd data/conceptnet 4 | 5 | wget https://ttic.uchicago.edu/~kgimpel/comsense_resources/train100k.txt.gz 6 | wget https://ttic.uchicago.edu/~kgimpel/comsense_resources/dev1.txt.gz 7 | wget https://ttic.uchicago.edu/~kgimpel/comsense_resources/dev2.txt.gz 8 | wget https://ttic.uchicago.edu/~kgimpel/comsense_resources/test.txt.gz 9 | 10 | gunzip train100k.txt.gz 11 | gunzip dev1.txt.gz 12 | gunzip dev2.txt.gz 13 | gunzip test.txt.gz 14 | 15 | cd .. -------------------------------------------------------------------------------- /comet-commonsense/scripts/setup/get_model_files.sh: -------------------------------------------------------------------------------- 1 | git clone https://github.com/openai/finetune-transformer-lm.git 2 | mv finetune-transformer-lm/model . 3 | rm -rf finetune-transformer-lm -------------------------------------------------------------------------------- /comet-commonsense/src/data/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import src.data.atomic as atomic_data 3 | import src.data.conceptnet as conceptnet_data 4 | import src.data.config as cfg 5 | 6 | import utils.utils as utils 7 | 8 | import pickle 9 | import torch 10 | import json 11 | 12 | 13 | start_token = "" 14 | end_token = "" 15 | blank_token = "" 16 | 17 | 18 | def save_checkpoint(state, filename): 19 | print("Saving model to {}".format(filename)) 20 | torch.save(state, filename) 21 | 22 | 23 | def save_step(model, vocab, optimizer, opt, length, lrs): 24 | if cfg.test_save: 25 | name = "{}.pickle".format(utils.make_name( 26 | opt, prefix="garbage/models/", is_dir=False, eval_=True)) 27 | else: 28 | name = "{}.pickle".format(utils.make_name( 29 | opt, prefix="models/", is_dir=False, eval_=True)) 30 | save_checkpoint({ 31 | "epoch": length, "state_dict": model.state_dict(), 32 | "optimizer": optimizer.state_dict(), "opt": opt, 33 | "vocab": vocab, "epoch_learning_rates": lrs}, 34 | name) 35 | 36 | 37 | def save_eval_file(opt, stats, eval_type="losses", split="dev", ext="pickle"): 38 | if cfg.test_save: 39 | name = "{}/{}.{}".format(utils.make_name( 40 | opt, prefix="garbage/{}/".format(eval_type), 41 | is_dir=True, eval_=True), split, ext) 42 | else: 43 | name = "{}/{}.{}".format(utils.make_name( 44 | opt, prefix="results/{}/".format(eval_type), 45 | is_dir=True, eval_=True), split, ext) 46 | print("Saving {} {} to {}".format(split, eval_type, name)) 47 | 48 | if ext == "pickle": 49 | with open(name, "wb") as f: 50 | pickle.dump(stats, f) 51 | elif ext == "txt": 52 | with open(name, "w") as f: 53 | f.write(stats) 54 | elif ext == "json": 55 | with open(name, "w") as f: 56 | json.dump(stats, f) 57 | else: 58 | raise 59 | 60 | 61 | def load_checkpoint(filename, gpu=True): 62 | if os.path.exists(filename): 63 | checkpoint = torch.load( 64 | filename, map_location=lambda storage, loc: storage) 65 | else: 66 | print("No model found at {}".format(filename)) 67 | return checkpoint 68 | 69 | 70 | def make_data_loader(opt, *args): 71 | if opt.dataset == "atomic": 72 | return atomic_data.GenerationDataLoader(opt, *args) 73 | elif opt.dataset == "conceptnet": 74 | return conceptnet_data.GenerationDataLoader(opt, *args) 75 | 76 | 
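# set_max_sizes (below) records how many examples each split of the data
# loader holds in data_loader.total_size[split], so training and evaluation
# loops can tell when a split has been exhausted; passing force_split limits
# this bookkeeping to a single split.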
77 | def set_max_sizes(data_loader, force_split=None): 78 | data_loader.total_size = {} 79 | if force_split is not None: 80 | data_loader.total_size[force_split] = \ 81 | data_loader.sequences[force_split]["total"].size(0) 82 | return 83 | for split in data_loader.sequences: 84 | data_loader.total_size[split] = \ 85 | data_loader.sequences[split]["total"].size(0) 86 | -------------------------------------------------------------------------------- /comet-commonsense/src/evaluate/atomic_evaluate.py: -------------------------------------------------------------------------------- 1 | import src.train.batch as batch 2 | import src.evaluate.evaluate as base_evaluate 3 | import numpy as np 4 | 5 | def make_evaluator(opt, *args): 6 | if opt.exp == "generation": 7 | return AtomicGenerationEvaluator(opt, *args) 8 | else: 9 | return AtomicClassificationEvaluator(opt, *args) 10 | 11 | 12 | class AtomicGenerationEvaluator(base_evaluate.Evaluator): 13 | def __init__(self, opt, model, data_loader): 14 | super(AtomicGenerationEvaluator, self).__init__( 15 | opt, model, data_loader) 16 | 17 | self.batch = batch.batch_atomic_generate 18 | 19 | def initialize_losses(self): 20 | average_loss = {"total_micro": 0, "total_macro": 0} 21 | nums = {"total_micro": 0, "total_macro": 0} 22 | return average_loss, nums 23 | 24 | def compute_final_scores(self, average_loss, nums): 25 | average_loss["total_macro"] /= nums["total_macro"] 26 | average_loss["total_micro"] /= nums["total_micro"] 27 | 28 | average_loss["ppl_macro"] = np.exp(average_loss["total_macro"]) 29 | average_loss["ppl_micro"] = np.exp(average_loss["total_micro"]) 30 | 31 | return average_loss 32 | 33 | def counter(self, nums): 34 | return nums["total_macro"] 35 | 36 | def print_result(self, split, epoch_losses): 37 | print("{} Loss: \t {}".format( 38 | split, epoch_losses["total_micro"])) 39 | print("{} Perplexity: \t {}".format( 40 | split, epoch_losses["ppl_micro"])) 41 | -------------------------------------------------------------------------------- /comet-commonsense/src/evaluate/evaluate.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | 4 | import utils.utils as utils 5 | import src.data.config as cfg 6 | 7 | 8 | class Evaluator(object): 9 | def __init__(self, opt, model, data_loader): 10 | super(Evaluator, self).__init__() 11 | 12 | self.data_loader = data_loader 13 | self.model = model 14 | 15 | self.batch_variables = { 16 | "model": model, 17 | "data": data_loader 18 | } 19 | 20 | self.opt = opt 21 | 22 | def validate(self, l, split="dev", losses={}, keyset=None): 23 | self.batch_variables["split"] = split 24 | print("Evaluating {}".format(split)) 25 | 26 | epoch_losses = self.epoch( 27 | self.opt, self.model, self.data_loader, split, keyset) 28 | 29 | self.print_result(split, epoch_losses) 30 | 31 | for loss_name, loss_val in epoch_losses.items(): 32 | losses.setdefault(loss_name, {}) 33 | losses[loss_name][l] = loss_val 34 | 35 | def epoch(self, opt, model, data_loader, split, keyset=None): 36 | average_loss, nums = self.initialize_losses() 37 | 38 | data_loader.reset_offsets(splits=split, shuffle=False) 39 | 40 | # Set evaluation mode 41 | model.eval() 42 | 43 | start = time.time() 44 | 45 | # Initialize progress bar 46 | bar = utils.set_progress_bar( 47 | data_loader.total_size[split]) 48 | 49 | reset = False 50 | 51 | with torch.no_grad(): 52 | while not reset: 53 | 54 | start = data_loader.offset_summary(split) 55 | 56 | outputs = self.batch( 57 | opt, nums, 
average_loss, 58 | self.batch_variables, eval_mode=True) 59 | 60 | end = data_loader.offset_summary(split) 61 | 62 | reset = outputs["reset"] 63 | 64 | if not reset: 65 | bar.update(end - start) 66 | else: 67 | print(end) 68 | 69 | if cfg.toy and self.counter(nums) > 100: 70 | break 71 | if (opt.eval.es != "full" and 72 | (self.counter(nums) > opt.eval.es)): 73 | break 74 | 75 | nums = outputs["nums"] 76 | 77 | torch.cuda.synchronize() 78 | 79 | print("{} evaluation completed in: {} s".format( 80 | split.capitalize(), time.time() - start)) 81 | 82 | average_loss = self.compute_final_scores( 83 | average_loss, nums) 84 | 85 | return average_loss 86 | -------------------------------------------------------------------------------- /comet-commonsense/src/evaluate/generate.py: -------------------------------------------------------------------------------- 1 | import src.data.data as data 2 | import src.data.config as cfg 3 | import src.evaluate.sampler as sampling 4 | 5 | 6 | def do_gen_run(opt, generator, l, split="dev", scores={}): 7 | # Generate sequences for examples in evaluation set using 8 | # current trained model 9 | 10 | if opt.eval.gs == "full": 11 | sequences, avg_scores, indiv_scores = generator.generate(split) 12 | else: 13 | sequences, avg_scores, indiv_scores = generator.generate_some(split) 14 | 15 | if avg_scores is not None: 16 | # Record scores from generated sequences 17 | for score_name, score_val in avg_scores.items(): 18 | scores.setdefault(score_name, {}) 19 | scores[score_name].setdefault(l, []) 20 | scores[score_name][l] += [score_val] 21 | 22 | # Save generated sequences 23 | save_sequences(opt, sequences, avg_scores, indiv_scores, 24 | l, split, opt.eval.gs == "full", 25 | generator.data_loader) 26 | 27 | 28 | def save_sequences(opt, sequences, avg_scores, indiv_scores, 29 | l, split, full, data_loader): 30 | # This seems a bit roundabout since l = opt.train.dynamic in train.py 31 | # But it's in case we start checkpointing outside of epoch boundaries 32 | opt.train.dynamic.epoch = l 33 | 34 | if cfg.save: 35 | if full: 36 | names = {"gens": "gens", "scores": "scores", 37 | "indiv": "indiv.scores"} 38 | else: 39 | names = {"gens": "gens.small", "scores": "scores.small", 40 | "indiv": "indiv.scores.small"} 41 | # Save generated sequences 42 | data.save_eval_file(opt, sequences, names["gens"], split) 43 | 44 | if avg_scores is not None: 45 | # Save average scores over evaluation set for generated sequences 46 | # Scores computed are the ones the generator was initialized with 47 | data.save_eval_file(opt, avg_scores, names["scores"], split) 48 | 49 | if split == "dev": 50 | # Save individual scores 51 | data.save_eval_file( 52 | opt, indiv_scores, names["indiv"], split) 53 | 54 | 55 | class Generator(object): 56 | def __init__(self, opt, model, data_loader, scorers, reward_function=None): 57 | super(Generator, self).__init__() 58 | self.opt = opt 59 | 60 | self.model = model 61 | self.data_loader = data_loader 62 | 63 | self.sampler = sampling.make_sampler( 64 | opt.eval.sample, opt, data_loader) 65 | 66 | 67 | def generate(self, split="dev"): 68 | pass 69 | 70 | def generate_batch(self, sequences, split, verbose=False, bs=32): 71 | pass 72 | 73 | -------------------------------------------------------------------------------- /comet-commonsense/src/evaluate/utils.py: -------------------------------------------------------------------------------- 1 | 2 | def update_classification_losses(losses, nums, name, bs, loss): 3 | if not isinstance(loss, float): 4 | 
print(type(loss)) 5 | raise 6 | 7 | nums[name] += bs 8 | 9 | losses[name] += loss * bs 10 | 11 | 12 | def update_generation_losses(losses, nums, micro, macro, bs, length, loss): 13 | # Update Losses 14 | nums[macro] += bs 15 | 16 | if isinstance(length, int): 17 | update_indiv_generation_losses( 18 | losses, nums, micro, macro, bs, length, loss) 19 | else: 20 | update_tensor_generation_losses( 21 | losses, nums, micro, macro, bs, length, loss) 22 | 23 | 24 | def update_indiv_generation_losses(losses, nums, micro, 25 | macro, bs, length, loss): 26 | nums[micro] += bs * length 27 | 28 | batch_loss = loss * bs 29 | 30 | losses[micro] += batch_loss 31 | losses[macro] += batch_loss / length 32 | 33 | 34 | def update_tensor_generation_losses(losses, nums, micro, 35 | macro, bs, length, loss): 36 | nums[micro] += length.sum().item() 37 | 38 | losses[micro] += loss.sum().item() 39 | losses[macro] += (loss / length.float()).sum().item() 40 | -------------------------------------------------------------------------------- /comet-commonsense/src/main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import argparse 4 | 5 | sys.path.append(os.getcwd()) 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--experiment_type", type=str, default='atomic', 9 | choices=["atomic", "conceptnet"]) 10 | parser.add_argument("--experiment_num", type=str, default="0") 11 | 12 | args = parser.parse_args() 13 | 14 | if args.experiment_type == "atomic": 15 | from main_atomic import main 16 | main(args.experiment_num) 17 | if args.experiment_type == "conceptnet": 18 | from main_conceptnet import main 19 | main(args.experiment_num) 20 | -------------------------------------------------------------------------------- /comet-commonsense/src/models/models.py: -------------------------------------------------------------------------------- 1 | from src.models.gpt import (LMModel, DEFAULT_CONFIG, load_openai_pretrained_model) 2 | import torch.nn as nn 3 | 4 | 5 | def make_model(opt, n_vocab, n_ctx, n_special, load=True, 6 | return_acts=True, return_probs=False, 7 | clf_token="", answer_size=None): 8 | print(n_ctx) 9 | if opt.exp == "generation": 10 | model = LMModel( 11 | opt.net, n_vocab, n_ctx, return_acts=return_acts, 12 | return_probs=return_probs) 13 | elif opt.exp == "classification": 14 | model = ClfModel( 15 | opt.net, n_vocab, n_ctx, clf_token, answer_size) 16 | if load: 17 | print("LOADING PRETRAINED TRANSFORMER") 18 | load_openai_pretrained_model( 19 | model.transformer, n_ctx=n_ctx, n_special=n_special) 20 | return model 21 | 22 | 23 | def multi_gpu(model, devices): 24 | return nn.DataParallel(model, device_ids=devices) 25 | 26 | 27 | def load_state_dict(model, state_dict): 28 | try: 29 | model.load_state_dict(state_dict) 30 | except RuntimeError: 31 | new_state_dict = {i[len("module."):]: j for i, j in state_dict.items()} 32 | model.load_state_dict(new_state_dict) 33 | -------------------------------------------------------------------------------- /comet-commonsense/src/models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def prepare_position_embeddings(opt, encoder_vocab, sequences): 5 | vocab_size = len(encoder_vocab) 6 | num_positions = sequences.size(-2) 7 | position_embeddings = torch.LongTensor( 8 | range(vocab_size, vocab_size + num_positions)).to(sequences.device) 9 | sequences = sequences.repeat(1, 1, 2) 10 | sequences[:, :, 1] = 
position_embeddings 11 | return sequences 12 | 13 | -------------------------------------------------------------------------------- /comet-commonsense/src/train/atomic_train.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import src.train.train as base_train 4 | import src.train.batch as batch 5 | import src.evaluate.atomic_evaluate as evaluate 6 | # import src.evaluate.atomic_generate as gen 7 | 8 | 9 | def make_trainer(opt, *args): 10 | return AtomicGenerationIteratorTrainer(opt, *args) 11 | 12 | 13 | class AtomicGenerationIteratorTrainer(base_train.IteratorTrainer): 14 | def __init__(self, opt, *args): 15 | super(AtomicGenerationIteratorTrainer, self).__init__(opt, *args) 16 | 17 | self.initialize_losses(opt.data.get("categories", [])) 18 | 19 | def set_evaluator(self, opt, model, data_loader): 20 | self.evaluator = evaluate.make_evaluator( 21 | opt, model, data_loader) 22 | 23 | # def set_generator(self, opt, model, data_loader, scores, reward=None): 24 | # self.generator = gen.make_generator( 25 | # opt, model, data_loader, scores, reward) 26 | 27 | def set_sampler(self, opt): 28 | if opt.train.static.samp not in self.samplers: 29 | self.samplers[opt.train.static.samp] = sampling.make_sampler( 30 | opt.train.static.samp, opt, self.data_loader, batch_mode=True) 31 | self.batch_variables["sampler"] = self.samplers 32 | 33 | def batch(self, opt, *args): 34 | outputs = batch.batch_atomic_generate(opt, *args) 35 | 36 | token_loss = outputs["loss"] 37 | nums = outputs["nums"] 38 | reset = outputs["reset"] 39 | 40 | return token_loss, nums, reset 41 | 42 | def initialize_losses(self, categories): 43 | self.losses["train"] = { 44 | "total_micro": [0], 45 | "total_macro": [0] 46 | } 47 | 48 | nums = {"total_micro": 0, "total_macro": 0} 49 | 50 | for category in categories: 51 | micro_name = "{}_micro".format(category) 52 | macro_name = "{}_macro".format(category) 53 | 54 | self.losses["train"][micro_name] = [0] 55 | self.losses["train"][macro_name] = [0] 56 | 57 | nums[micro_name] = 0 58 | nums[macro_name] = 0 59 | 60 | return nums 61 | 62 | def update_top_score(self, opt): 63 | print(self.top_score) 64 | if self.top_score is None: 65 | self.top_score = (self.opt.train.dynamic.epoch, 66 | self.get_tracked_score()) 67 | elif self.get_tracked_score() < self.top_score[-1]: 68 | self.top_score = (self.opt.train.dynamic.epoch, 69 | self.get_tracked_score()) 70 | print(self.top_score) 71 | 72 | def get_tracked_score(self): 73 | return self.losses["dev"]["total_micro"][self.opt.train.dynamic.epoch] 74 | 75 | def counter(self, nums): 76 | return nums["total_macro"] 77 | -------------------------------------------------------------------------------- /comet-commonsense/src/train/conceptnet_train.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | import src.data.config as cfg 5 | 6 | import src.train.atomic_train as base_train 7 | import src.train.batch as batch_utils 8 | import src.evaluate.conceptnet_evaluate as evaluate 9 | import src.evaluate.conceptnet_generate as gen 10 | 11 | 12 | def make_trainer(opt, *args): 13 | return ConceptNetGenerationIteratorTrainer(opt, *args) 14 | 15 | 16 | class ConceptNetGenerationIteratorTrainer( 17 | base_train.AtomicGenerationIteratorTrainer): 18 | def set_evaluator(self, opt, model, data_loader): 19 | self.evaluator = evaluate.make_evaluator( 20 | opt, model, data_loader) 21 | 22 | def set_generator(self, opt, model, 
data_loader): 23 | self.generator = gen.make_generator( 24 | opt, model, data_loader) 25 | 26 | def batch(self, opt, *args): 27 | outputs = batch_utils.batch_atomic_generate(opt, *args) 28 | 29 | token_loss = outputs["loss"] 30 | nums = outputs["nums"] 31 | reset = outputs["reset"] 32 | 33 | return token_loss, nums, reset 34 | 35 | def update_top_score(self, opt): 36 | print(self.top_score) 37 | 38 | tracked_scores = self.get_tracked_score() 39 | 40 | if self.top_score is None: 41 | self.top_score = \ 42 | self.top_score = {"epoch": {}, "score": {}} 43 | self.top_score["epoch"]["total_micro"] = self.opt.train.dynamic.epoch 44 | self.top_score["score"]["total_micro"] = tracked_scores["total_micro"] 45 | else: 46 | if tracked_scores["total_micro"] < self.top_score["score"]["total_micro"]: 47 | self.top_score["epoch"]["total_micro"] = self.opt.train.dynamic.epoch 48 | self.top_score["score"]["total_micro"] = tracked_scores["total_micro"] 49 | 50 | print(self.top_score) 51 | 52 | def get_tracked_score(self): 53 | return { 54 | "total_micro": self.losses["dev"]["total_micro"][self.opt.train.dynamic.epoch] 55 | } 56 | 57 | def decide_to_save(self): 58 | to_save = cfg.save and not cfg.toy 59 | 60 | curr_epoch = self.opt.train.dynamic.epoch 61 | 62 | to_save = to_save or cfg.test_save 63 | print(cfg.save_strategy) 64 | if cfg.save_strategy == "best": 65 | if ((self.top_score["epoch"]["total_micro"] != curr_epoch)): 66 | to_save = False 67 | return to_save 68 | -------------------------------------------------------------------------------- /comet-commonsense/src/train/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim 3 | import torch.nn.functional as F 4 | 5 | import copy 6 | 7 | 8 | def update_generation_losses(losses, nums, micro, macro, bs, length, loss): 9 | # Update Losses 10 | losses[micro] += \ 11 | [copy.deepcopy(losses[micro][-1])] 12 | losses[macro] += \ 13 | [copy.deepcopy(losses[macro][-1])] 14 | 15 | losses[micro][-1] *= nums[micro] 16 | losses[macro][-1] *= nums[macro] 17 | 18 | nums[macro] += bs 19 | 20 | if isinstance(length, int): 21 | update_indiv_generation_losses( 22 | losses, nums, micro, macro, bs, length, loss) 23 | else: 24 | update_tensor_generation_losses( 25 | losses, nums, micro, macro, bs, length, loss) 26 | 27 | 28 | def update_indiv_generation_losses(losses, nums, micro, 29 | macro, bs, length, loss): 30 | nums[micro] += (bs * length) 31 | 32 | batch_loss = loss * bs 33 | 34 | losses[micro][-1] += batch_loss 35 | losses[micro][-1] /= nums[micro] 36 | losses[macro][-1] += batch_loss / length 37 | losses[macro][-1] /= nums[macro] 38 | 39 | 40 | def update_tensor_generation_losses(losses, nums, micro, 41 | macro, bs, length, loss): 42 | nums[micro] += length.sum().item() 43 | 44 | losses[micro][-1] += loss.sum().item() 45 | losses[micro][-1] /= nums[micro] 46 | losses[macro][-1] += (loss / length.float()).sum().item() 47 | losses[macro][-1] /= nums[macro] 48 | 49 | 50 | def modify_output_for_loss_fn(loss_fn, output, dim): 51 | if loss_fn == "ce": 52 | return output 53 | if loss_fn == "mse": 54 | return F.softmax(output, dim=dim) 55 | if loss_fn == "nll": 56 | return F.log_softmax(output, dim=dim) 57 | if loss_fn in ["bce", "wbce", "wbce1"]: 58 | return torch.sigmoid(output) 59 | -------------------------------------------------------------------------------- /convert_to_literal.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | import 
yaml 3 | import string 4 | import json 5 | import os 6 | import sys 7 | import ast 8 | from urllib.parse import quote 9 | import math 10 | from pytorch_pretrained_bert import OpenAIGPTTokenizer, OpenAIGPTModel, OpenAIGPTLMHeadModel 11 | import ast 12 | import torch 13 | import nltk 14 | 15 | sys.path.append(os.getcwd()+'/comet-commonsense') 16 | 17 | def getParams(): 18 | model = OpenAIGPTLMHeadModel.from_pretrained('openai-gpt') 19 | model.eval() 20 | tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt') 21 | return model,tokenizer 22 | 23 | def getscore(sentence,tokenizer,model): 24 | tokenize_input = tokenizer.tokenize(sentence) 25 | tensor_input = torch.tensor([tokenizer.convert_tokens_to_ids(tokenize_input)]) 26 | loss=model(tensor_input, lm_labels=tensor_input) 27 | return loss 28 | 29 | def getCommonSense(utterance): 30 | os.system('python comet-commonsense/scripts/generate/generate_conceptnet_arbitrary.py --model_file comet-commonsense/pretrained_models/conceptnet_pretrained_model.pickle --input "'+utterance+'" --output_file output.json --device 0 --sampling_algorithm beam-5') 31 | output = json.load(open('output.json', "r")) 32 | return output[0]['HasProperty']['beams'] 33 | 34 | def create_literal(simile): 35 | model,tokenizer = getParams() 36 | vehicle = simile.split(' like a ')[1] 37 | tenor_event_comp = simile.split(' like a ')[0]+' like a' 38 | vehicle_property = getCommonSense(vehicle) 39 | scores = [] 40 | for elem in vehicle_property: 41 | if elem not in vehicle: 42 | scores.append((getscore(tenor_event_comp+' '+elem,tokenizer,model),tenor_event_comp+' '+elem)) 43 | scores.sort(key = lambda x: x[0],reverse=True) 44 | best_literal = scores[0][1] 45 | return best_literal 46 | 47 | 48 | print(create_literal('Rare and forgotten words are like a strong spice')) 49 | 50 | 51 | -------------------------------------------------------------------------------- /create_bpe.sh: -------------------------------------------------------------------------------- 1 | for SPLIT in train val 2 | do 3 | for LANG in source target 4 | do 5 | python -m examples.roberta.multiprocessing_bpe_encoder \ 6 | --encoder-json encoder.json \ 7 | --vocab-bpe vocab.bpe \ 8 | --inputs "simile/$SPLIT.$LANG" \ 9 | --outputs "simile/$SPLIT.bpe.$LANG" \ 10 | --workers 60 \ 11 | --keep-empty; 12 | done 13 | done 14 | -------------------------------------------------------------------------------- /fairseq.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /fairseq.egg-info/entry_points.txt: -------------------------------------------------------------------------------- 1 | [console_scripts] 2 | fairseq-eval-lm = fairseq_cli.eval_lm:cli_main 3 | fairseq-generate = fairseq_cli.generate:cli_main 4 | fairseq-interactive = fairseq_cli.interactive:cli_main 5 | fairseq-preprocess = fairseq_cli.preprocess:cli_main 6 | fairseq-score = fairseq_cli.score:main 7 | fairseq-train = fairseq_cli.train:cli_main 8 | fairseq-validate = fairseq_cli.validate:cli_main 9 | 10 | -------------------------------------------------------------------------------- /fairseq.egg-info/not-zip-safe: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /fairseq.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | cffi 2 | 
cython 3 | numpy 4 | regex 5 | sacrebleu 6 | torch 7 | tqdm 8 | -------------------------------------------------------------------------------- /fairseq.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | examples 2 | fairseq 3 | fairseq_cli 4 | tests 5 | -------------------------------------------------------------------------------- /fairseq/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | __all__ = ['pdb'] 7 | __version__ = '0.9.0' 8 | 9 | import fairseq.criterions # noqa 10 | import fairseq.models # noqa 11 | import fairseq.modules # noqa 12 | import fairseq.optim # noqa 13 | import fairseq.optim.lr_scheduler # noqa 14 | import fairseq.pdb # noqa 15 | import fairseq.tasks # noqa 16 | -------------------------------------------------------------------------------- /fairseq/clib/libbleu/module.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2017-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include 10 | 11 | 12 | static PyMethodDef method_def[] = { 13 | {NULL, NULL, 0, NULL} 14 | }; 15 | 16 | static struct PyModuleDef module_def = { 17 | PyModuleDef_HEAD_INIT, 18 | "libbleu", /* name of module */ 19 | NULL, /* module documentation, may be NULL */ 20 | -1, /* size of per-interpreter state of the module, 21 | or -1 if the module keeps state in global variables. */ 22 | method_def 23 | }; 24 | 25 | 26 | #if PY_MAJOR_VERSION == 2 27 | PyMODINIT_FUNC init_libbleu() 28 | #else 29 | PyMODINIT_FUNC PyInit_libbleu() 30 | #endif 31 | { 32 | PyObject *m = PyModule_Create(&module_def); 33 | if (!m) { 34 | return NULL; 35 | } 36 | return m; 37 | } 38 | -------------------------------------------------------------------------------- /fairseq/clib/libnat_cuda/binding.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2017-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | /* 10 | This code is partially adpoted from https://github.com/1ytic/pytorch-edit-distance 11 | */ 12 | 13 | #include "edit_dist.h" 14 | #include 15 | 16 | #ifndef TORCH_CHECK 17 | #define TORCH_CHECK AT_CHECK 18 | #endif 19 | 20 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 21 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 22 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 23 | 24 | 25 | torch::Tensor LevenshteinDistance( 26 | torch::Tensor source, 27 | torch::Tensor target, 28 | torch::Tensor source_length, 29 | torch::Tensor target_length) { 30 | 31 | CHECK_INPUT(source); 32 | CHECK_INPUT(target); 33 | CHECK_INPUT(source_length); 34 | CHECK_INPUT(target_length); 35 | return LevenshteinDistanceCuda(source, target, source_length, target_length); 36 | } 37 | 38 | torch::Tensor GenerateDeletionLabel( 39 | torch::Tensor source, 40 | torch::Tensor operations) { 41 | 42 | CHECK_INPUT(source); 43 | CHECK_INPUT(operations); 44 | return GenerateDeletionLabelCuda(source, operations); 45 | } 46 | 47 | std::pair GenerateInsertionLabel( 48 | torch::Tensor target, 49 | torch::Tensor operations) { 50 | 51 | CHECK_INPUT(target); 52 | CHECK_INPUT(operations); 53 | return GenerateInsertionLabelCuda(target, operations); 54 | } 55 | 56 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 57 | m.def("levenshtein_distance", &LevenshteinDistance, "Levenshtein distance"); 58 | m.def("generate_deletion_labels", &GenerateDeletionLabel, "Generate Deletion Label"); 59 | m.def("generate_insertion_labels", &GenerateInsertionLabel, "Generate Insertion Label"); 60 | } 61 | -------------------------------------------------------------------------------- /fairseq/clib/libnat_cuda/edit_dist.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2017-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | 13 | torch::Tensor LevenshteinDistanceCuda( 14 | torch::Tensor source, 15 | torch::Tensor target, 16 | torch::Tensor source_length, 17 | torch::Tensor target_length); 18 | 19 | torch::Tensor GenerateDeletionLabelCuda( 20 | torch::Tensor source, 21 | torch::Tensor operations); 22 | 23 | std::pair GenerateInsertionLabelCuda( 24 | torch::Tensor source, 25 | torch::Tensor operations); 26 | -------------------------------------------------------------------------------- /fairseq/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import importlib 7 | import os 8 | 9 | from fairseq import registry 10 | from fairseq.criterions.fairseq_criterion import FairseqCriterion 11 | 12 | 13 | build_criterion, register_criterion, CRITERION_REGISTRY = registry.setup_registry( 14 | '--criterion', 15 | base_class=FairseqCriterion, 16 | default='cross_entropy', 17 | ) 18 | 19 | 20 | # automatically import any Python files in the criterions/ directory 21 | for file in os.listdir(os.path.dirname(__file__)): 22 | if file.endswith('.py') and not file.startswith('_'): 23 | module = file[:file.find('.py')] 24 | importlib.import_module('fairseq.criterions.' 
+ module) 25 | -------------------------------------------------------------------------------- /fairseq/criterions/fairseq_criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from typing import Any, Dict, List 7 | 8 | from torch.nn.modules.loss import _Loss 9 | 10 | from fairseq import metrics 11 | from fairseq import utils1 as utils 12 | 13 | 14 | class FairseqCriterion(_Loss): 15 | 16 | def __init__(self, args, task): 17 | super().__init__() 18 | self.args = args 19 | self.task = task 20 | self.padding_idx = task.target_dictionary.pad() if task.target_dictionary is not None else -100 21 | 22 | @staticmethod 23 | def add_args(parser): 24 | """Add criterion-specific arguments to the parser.""" 25 | pass 26 | 27 | @classmethod 28 | def build_criterion(cls, args, task): 29 | return cls(args, task) 30 | 31 | def forward(self, model, sample, reduce=True): 32 | """Compute the loss for the given sample. 33 | 34 | Returns a tuple with three elements: 35 | 1) the loss 36 | 2) the sample size, which is used as the denominator for the gradient 37 | 3) logging outputs to display while training 38 | """ 39 | raise NotImplementedError 40 | 41 | @staticmethod 42 | def aggregate_logging_outputs( 43 | logging_outputs: List[Dict[str, Any]], 44 | ) -> Dict[str, Any]: 45 | """Aggregate logging outputs from data parallel training.""" 46 | utils.deprecation_warning( 47 | 'The aggregate_logging_outputs API is deprecated. ' 48 | 'Please use the reduce_metrics API instead.' 49 | ) 50 | raise NotImplementedError 51 | 52 | @classmethod 53 | def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None: 54 | """Aggregate logging outputs from data parallel training.""" 55 | utils.deprecation_warning( 56 | 'Criterions should implement the reduce_metrics API. ' 57 | 'Falling back to deprecated aggregate_logging_outputs API.' 58 | ) 59 | agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs) 60 | for k, v in agg_logging_outputs.items(): 61 | if k in {'nsentences', 'ntokens', 'sample_size'}: 62 | continue 63 | metrics.log_scalar(k, v) 64 | 65 | @staticmethod 66 | def logging_outputs_can_be_summed() -> bool: 67 | """ 68 | Whether the logging outputs returned by `forward` can be summed 69 | across workers prior to calling `reduce_metrics`. Setting this 70 | to True will improves distributed training speed. 71 | """ 72 | return False 73 | -------------------------------------------------------------------------------- /fairseq/data/append_token_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from . 
import BaseWrapperDataset 10 | 11 | 12 | class AppendTokenDataset(BaseWrapperDataset): 13 | 14 | def __init__(self, dataset, token=None): 15 | super().__init__(dataset) 16 | self.token = token 17 | if token is not None: 18 | self._sizes = np.array(dataset.sizes) + 1 19 | else: 20 | self._sizes = dataset.sizes 21 | 22 | def __getitem__(self, idx): 23 | item = self.dataset[idx] 24 | if self.token is not None: 25 | item = torch.cat([item, item.new([self.token])]) 26 | return item 27 | 28 | @property 29 | def sizes(self): 30 | return self._sizes 31 | 32 | def num_tokens(self, index): 33 | n = self.dataset.num_tokens(index) 34 | if self.token is not None: 35 | n += 1 36 | return n 37 | 38 | def size(self, index): 39 | n = self.dataset.size(index) 40 | if self.token is not None: 41 | n += 1 42 | return n 43 | -------------------------------------------------------------------------------- /fairseq/data/audio/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuhinjubcse/SimileGeneration-EMNLP2020/cb40a16409912e1eefb2204e6b1ab953fbe8bdc6/fairseq/data/audio/__init__.py -------------------------------------------------------------------------------- /fairseq/data/base_wrapper_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from torch.utils.data.dataloader import default_collate 7 | 8 | from . import FairseqDataset 9 | 10 | 11 | class BaseWrapperDataset(FairseqDataset): 12 | 13 | def __init__(self, dataset): 14 | super().__init__() 15 | self.dataset = dataset 16 | 17 | def __getitem__(self, index): 18 | return self.dataset[index] 19 | 20 | def __len__(self): 21 | return len(self.dataset) 22 | 23 | def collater(self, samples): 24 | if hasattr(self.dataset, 'collater'): 25 | return self.dataset.collater(samples) 26 | else: 27 | return default_collate(samples) 28 | 29 | @property 30 | def sizes(self): 31 | return self.dataset.sizes 32 | 33 | def num_tokens(self, index): 34 | return self.dataset.num_tokens(index) 35 | 36 | def size(self, index): 37 | return self.dataset.size(index) 38 | 39 | def ordered_indices(self): 40 | return self.dataset.ordered_indices() 41 | 42 | @property 43 | def supports_prefetch(self): 44 | return getattr(self.dataset, 'supports_prefetch', False) 45 | 46 | def prefetch(self, indices): 47 | self.dataset.prefetch(indices) 48 | 49 | def set_epoch(self, epoch): 50 | super().set_epoch(epoch) 51 | if hasattr(self.dataset, 'set_epoch'): 52 | self.dataset.set_epoch(epoch) 53 | -------------------------------------------------------------------------------- /fairseq/data/colorize_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | from . 
import BaseWrapperDataset 9 | 10 | 11 | class ColorizeDataset(BaseWrapperDataset): 12 | """ Adds 'colors' property to net input that is obtained from the provided color getter for use by models """ 13 | def __init__(self, dataset, color_getter): 14 | super().__init__(dataset) 15 | self.color_getter = color_getter 16 | 17 | def collater(self, samples): 18 | base_collate = super().collater(samples) 19 | if len(base_collate) > 0: 20 | base_collate["net_input"]["colors"] = torch.tensor( 21 | list(self.color_getter(self.dataset, s["id"]) for s in samples), 22 | dtype=torch.long, 23 | ) 24 | return base_collate 25 | -------------------------------------------------------------------------------- /fairseq/data/concat_sentences_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | from . import FairseqDataset 9 | 10 | 11 | class ConcatSentencesDataset(FairseqDataset): 12 | 13 | def __init__(self, *datasets): 14 | super().__init__() 15 | self.datasets = datasets 16 | assert all(len(ds) == len(datasets[0]) for ds in datasets), \ 17 | 'datasets must have the same length' 18 | 19 | def __getitem__(self, index): 20 | return torch.cat([ds[index] for ds in self.datasets]) 21 | 22 | def __len__(self): 23 | return len(self.datasets[0]) 24 | 25 | def collater(self, samples): 26 | return self.datasets[0].collater(samples) 27 | 28 | @property 29 | def sizes(self): 30 | return sum(ds.sizes for ds in self.datasets) 31 | 32 | def num_tokens(self, index): 33 | return sum(ds.num_tokens(index) for ds in self.datasets) 34 | 35 | def size(self, index): 36 | return sum(ds.size(index) for ds in self.datasets) 37 | 38 | def ordered_indices(self): 39 | return self.datasets[0].ordered_indices() 40 | 41 | @property 42 | def supports_prefetch(self): 43 | return any( 44 | getattr(ds, 'supports_prefetch', False) for ds in self.datasets 45 | ) 46 | 47 | def prefetch(self, indices): 48 | for ds in self.datasets: 49 | if getattr(ds, 'supports_prefetch', False): 50 | ds.prefetch(indices) 51 | 52 | def set_epoch(self, epoch): 53 | super().set_epoch(epoch) 54 | for ds in self.datasets: 55 | if hasattr(ds, 'set_epoch'): 56 | ds.set_epoch(epoch) 57 | -------------------------------------------------------------------------------- /fairseq/data/data_utils_fast.cpython-37m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuhinjubcse/SimileGeneration-EMNLP2020/cb40a16409912e1eefb2204e6b1ab953fbe8bdc6/fairseq/data/data_utils_fast.cpython-37m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /fairseq/data/data_utils_fast.pyx: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
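#
# The helper below greedily packs example indices into batches, closing a batch
# once the padded token count (batch size times the longest example so far)
# would exceed --max-tokens or the sentence count would exceed --max-sentences.
# A minimal usage sketch of ``batch_by_size_fast`` (the lengths and limits here
# are made up for illustration; the compiled extension must be built first):
#
#   import numpy as np
#   lengths = [7, 9, 12, 30, 31]                      # tokens per example
#   indices = np.array([0, 1, 2, 3, 4], dtype=np.int64)
#   batches = batch_by_size_fast(indices, lambda i: lengths[i],
#                                max_tokens=64, max_sentences=0, bsz_mult=1)
#   # -> [[0, 1, 2], [3, 4]]: adding index 3 to the first batch would cost
#   #    4 * 30 = 120 padded tokens, exceeding max_tokens, so the batch closes.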
6 | 7 | import numpy as np 8 | 9 | cimport cython 10 | cimport numpy as np 11 | 12 | DTYPE = np.int64 13 | ctypedef np.int64_t DTYPE_t 14 | 15 | 16 | cdef _is_batch_full(list batch, long num_tokens, long max_tokens, long max_sentences): 17 | if len(batch) == 0: 18 | return 0 19 | if max_sentences > 0 and len(batch) == max_sentences: 20 | return 1 21 | if max_tokens > 0 and num_tokens > max_tokens: 22 | return 1 23 | return 0 24 | 25 | 26 | @cython.cdivision(True) 27 | cpdef list batch_by_size_fast( 28 | np.ndarray[DTYPE_t, ndim=1] indices, 29 | num_tokens_fn, 30 | long max_tokens, 31 | long max_sentences, 32 | int bsz_mult, 33 | ): 34 | cdef long sample_len = 0 35 | cdef list sample_lens = [] 36 | cdef list batch = [] 37 | cdef list batches = [] 38 | cdef long mod_len 39 | cdef long i 40 | cdef long idx 41 | cdef long num_tokens 42 | cdef DTYPE_t[:] indices_view = indices 43 | 44 | for i in range(len(indices_view)): 45 | idx = indices_view[i] 46 | num_tokens = num_tokens_fn(idx) 47 | sample_lens.append(num_tokens) 48 | sample_len = max(sample_len, num_tokens) 49 | 50 | assert max_tokens <= 0 or sample_len <= max_tokens, ( 51 | "sentence at index {} of size {} exceeds max_tokens " 52 | "limit of {}!".format(idx, sample_len, max_tokens) 53 | ) 54 | num_tokens = (len(batch) + 1) * sample_len 55 | 56 | if _is_batch_full(batch, num_tokens, max_tokens, max_sentences): 57 | mod_len = max( 58 | bsz_mult * (len(batch) // bsz_mult), 59 | len(batch) % bsz_mult, 60 | ) 61 | batches.append(batch[:mod_len]) 62 | batch = batch[mod_len:] 63 | sample_lens = sample_lens[mod_len:] 64 | sample_len = max(sample_lens) if len(sample_lens) > 0 else 0 65 | batch.append(idx) 66 | if len(batch) > 0: 67 | batches.append(batch) 68 | return batches 69 | -------------------------------------------------------------------------------- /fairseq/data/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | import importlib 8 | import os 9 | 10 | from fairseq import registry 11 | 12 | 13 | build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY = registry.setup_registry( 14 | '--tokenizer', 15 | default=None, 16 | ) 17 | 18 | 19 | build_bpe, register_bpe, BPE_REGISTRY = registry.setup_registry( 20 | '--bpe', 21 | default=None, 22 | ) 23 | 24 | 25 | # automatically import any Python files in the encoders/ directory 26 | for file in os.listdir(os.path.dirname(__file__)): 27 | if file.endswith('.py') and not file.startswith('_'): 28 | module = file[:file.find('.py')] 29 | importlib.import_module('fairseq.data.encoders.' + module) 30 | -------------------------------------------------------------------------------- /fairseq/data/encoders/fastbpe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
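#
# A minimal usage sketch of the fastBPE wrapper defined below (the codes path
# is a hypothetical placeholder; requires ``pip install fastBPE``):
#
#   from argparse import Namespace
#   bpe = fastBPE(Namespace(bpe_codes='/path/to/bpe.codes'))
#   pieces = bpe.encode('hello world')   # subword string using "@@ " continuation markers
#   print(bpe.decode(pieces))            # typically round-trips back to 'hello world'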
5 | 6 | from fairseq import file_utils 7 | from fairseq.data.encoders import register_bpe 8 | 9 | 10 | @register_bpe('fastbpe') 11 | class fastBPE(object): 12 | 13 | @staticmethod 14 | def add_args(parser): 15 | # fmt: off 16 | parser.add_argument('--bpe-codes', type=str, 17 | help='path to fastBPE BPE') 18 | # fmt: on 19 | 20 | def __init__(self, args): 21 | if args.bpe_codes is None: 22 | raise ValueError('--bpe-codes is required for --bpe=subword_nmt') 23 | codes = file_utils.cached_path(args.bpe_codes) 24 | try: 25 | import fastBPE 26 | self.bpe = fastBPE.fastBPE(codes) 27 | self.bpe_symbol = "@@ " 28 | except ImportError: 29 | raise ImportError('Please install fastBPE with: pip install fastBPE') 30 | 31 | def encode(self, x: str) -> str: 32 | return self.bpe.apply([x])[0] 33 | 34 | def decode(self, x: str) -> str: 35 | return (x + ' ').replace(self.bpe_symbol, '').rstrip() 36 | -------------------------------------------------------------------------------- /fairseq/data/encoders/gpt2_bpe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from fairseq import file_utils 7 | from fairseq.data.encoders import register_bpe 8 | 9 | from .gpt2_bpe_utils import get_encoder 10 | 11 | 12 | DEFAULT_ENCODER_JSON = 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json' 13 | DEFAULT_VOCAB_BPE = 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe' 14 | 15 | 16 | @register_bpe('gpt2') 17 | class GPT2BPE(object): 18 | 19 | @staticmethod 20 | def add_args(parser): 21 | # fmt: off 22 | parser.add_argument('--gpt2-encoder-json', type=str, 23 | default=DEFAULT_ENCODER_JSON, 24 | help='path to encoder.json') 25 | parser.add_argument('--gpt2-vocab-bpe', type=str, 26 | default=DEFAULT_VOCAB_BPE, 27 | help='path to vocab.bpe') 28 | # fmt: on 29 | 30 | def __init__(self, args): 31 | encoder_json = file_utils.cached_path( 32 | getattr(args, 'gpt2_encoder_json', DEFAULT_ENCODER_JSON) 33 | ) 34 | vocab_bpe = file_utils.cached_path( 35 | getattr(args, 'gpt2_vocab_bpe', DEFAULT_VOCAB_BPE) 36 | ) 37 | self.bpe = get_encoder(encoder_json, vocab_bpe) 38 | 39 | def encode(self, x: str) -> str: 40 | return ' '.join(map(str, self.bpe.encode(x))) 41 | 42 | def decode(self, x: str) -> str: 43 | return self.bpe.decode([ 44 | int(tok) if tok not in {'', ''} else tok 45 | for tok in x.split() 46 | ]) 47 | 48 | def is_beginning_of_word(self, x: str) -> bool: 49 | return self.decode(x).startswith(' ') 50 | -------------------------------------------------------------------------------- /fairseq/data/encoders/hf_bert_bpe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from fairseq.data.encoders import register_bpe 7 | 8 | 9 | @register_bpe('bert') 10 | class BertBPE(object): 11 | 12 | @staticmethod 13 | def add_args(parser): 14 | # fmt: off 15 | parser.add_argument('--bpe-cased', action='store_true', 16 | help='set for cased BPE', 17 | default=False) 18 | parser.add_argument('--bpe-vocab-file', type=str, 19 | help='bpe vocab file.') 20 | # fmt: on 21 | 22 | def __init__(self, args): 23 | try: 24 | from pytorch_transformers import BertTokenizer 25 | from pytorch_transformers.tokenization_utils import clean_up_tokenization 26 | except ImportError: 27 | raise ImportError( 28 | 'Please install 1.0.0 version of pytorch_transformers' 29 | 'with: pip install pytorch-transformers' 30 | ) 31 | 32 | if 'bpe_vocab_file' in args: 33 | self.bert_tokenizer = BertTokenizer( 34 | args.bpe_vocab_file, 35 | do_lower_case=not args.bpe_cased 36 | ) 37 | else: 38 | vocab_file_name = 'bert-base-cased' if args.bpe_cased else 'bert-base-uncased' 39 | self.bert_tokenizer = BertTokenizer.from_pretrained(vocab_file_name) 40 | self.clean_up_tokenization = clean_up_tokenization 41 | 42 | def encode(self, x: str) -> str: 43 | return ' '.join(self.bert_tokenizer.tokenize(x)) 44 | 45 | def decode(self, x: str) -> str: 46 | return self.clean_up_tokenization( 47 | self.bert_tokenizer.convert_tokens_to_string(x.split(' ')) 48 | ) 49 | 50 | def is_beginning_of_word(self, x: str) -> bool: 51 | return not x.startswith('##') 52 | -------------------------------------------------------------------------------- /fairseq/data/encoders/moses_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
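#
# A minimal usage sketch of the Moses wrapper defined below (the Namespace
# fields mirror the flags registered in ``add_args``; requires
# ``pip install sacremoses``):
#
#   from argparse import Namespace
#   tok = MosesTokenizer(Namespace(moses_source_lang='en', moses_target_lang='en',
#                                  moses_no_dash_splits=False, moses_no_escape=False))
#   s = tok.encode("Hello, world!")   # tokenized, space-separated string
#   print(tok.decode(s))              # detokenized back to plain text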
5 | 6 | from fairseq.data.encoders import register_tokenizer 7 | 8 | 9 | @register_tokenizer('moses') 10 | class MosesTokenizer(object): 11 | 12 | @staticmethod 13 | def add_args(parser): 14 | # fmt: off 15 | parser.add_argument('--moses-source-lang', metavar='SRC', 16 | help='source language') 17 | parser.add_argument('--moses-target-lang', metavar='TARGET', 18 | help='target language') 19 | parser.add_argument('--moses-no-dash-splits', action='store_true', default=False, 20 | help='don\'t apply dash split rules') 21 | parser.add_argument('--moses-no-escape', action='store_true', default=False, 22 | help='don\'t perform HTML escaping on apostrophy, quotes, etc.') 23 | # fmt: on 24 | 25 | def __init__(self, args): 26 | self.args = args 27 | 28 | if getattr(args, 'moses_source_lang', None) is None: 29 | args.moses_source_lang = getattr(args, 'source_lang', 'en') 30 | if getattr(args, 'moses_target_lang', None) is None: 31 | args.moses_target_lang = getattr(args, 'target_lang', 'en') 32 | 33 | try: 34 | from sacremoses import MosesTokenizer, MosesDetokenizer 35 | self.tok = MosesTokenizer(args.moses_source_lang) 36 | self.detok = MosesDetokenizer(args.moses_target_lang) 37 | except ImportError: 38 | raise ImportError('Please install Moses tokenizer with: pip install sacremoses') 39 | 40 | def encode(self, x: str) -> str: 41 | return self.tok.tokenize( 42 | x, 43 | aggressive_dash_splits=(not self.args.moses_no_dash_splits), 44 | return_str=True, 45 | escape=(not self.args.moses_no_escape), 46 | ) 47 | 48 | def decode(self, x: str) -> str: 49 | return self.detok.detokenize(x.split()) 50 | -------------------------------------------------------------------------------- /fairseq/data/encoders/nltk_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from fairseq.data.encoders import register_tokenizer 7 | 8 | 9 | @register_tokenizer('nltk') 10 | class NLTKTokenizer(object): 11 | 12 | def __init__(self, source_lang=None, target_lang=None): 13 | try: 14 | from nltk.tokenize import word_tokenize 15 | self.word_tokenize = word_tokenize 16 | except ImportError: 17 | raise ImportError('Please install nltk with: pip install nltk') 18 | 19 | def encode(self, x: str) -> str: 20 | return ' '.join(self.word_tokenize(x)) 21 | 22 | def decode(self, x: str) -> str: 23 | return x 24 | -------------------------------------------------------------------------------- /fairseq/data/encoders/sentencepiece_bpe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
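#
# A minimal usage sketch of the SentencePiece wrapper defined below (the model
# path is a hypothetical placeholder; requires ``pip install sentencepiece``):
#
#   from argparse import Namespace
#   bpe = SentencepieceBPE(Namespace(sentencepiece_vocab='/path/to/sentencepiece.bpe.model'))
#   pieces = bpe.encode('hello world')   # '\u2581'-prefixed pieces joined by spaces
#   print(bpe.decode(pieces))            # typically round-trips back to 'hello world'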
5 | 6 | from fairseq import file_utils 7 | from fairseq.data.encoders import register_bpe 8 | 9 | 10 | @register_bpe('sentencepiece') 11 | class SentencepieceBPE(object): 12 | 13 | @staticmethod 14 | def add_args(parser): 15 | # fmt: off 16 | parser.add_argument('--sentencepiece-vocab', type=str, 17 | help='path to sentencepiece vocab') 18 | # fmt: on 19 | 20 | def __init__(self, args): 21 | vocab = file_utils.cached_path(args.sentencepiece_vocab) 22 | try: 23 | import sentencepiece as spm 24 | self.sp = spm.SentencePieceProcessor() 25 | self.sp.Load(vocab) 26 | except ImportError: 27 | raise ImportError('Please install sentencepiece with: pip install sentencepiece') 28 | 29 | def encode(self, x: str) -> str: 30 | return ' '.join(self.sp.EncodeAsPieces(x)) 31 | 32 | def decode(self, x: str) -> str: 33 | return x.replace(' ', '').replace('\u2581', ' ').strip() 34 | 35 | def is_beginning_of_word(self, x: str) -> bool: 36 | if x in ['', '', '', '']: 37 | # special elements are always considered beginnings 38 | # HACK: this logic is already present in fairseq/tasks/masked_lm.py 39 | # but these special tokens are also contained in the sentencepiece 40 | # vocabulary which causes duplicate special tokens. This hack makes 41 | # sure that they are all taken into account. 42 | return True 43 | return x.startswith('\u2581') 44 | -------------------------------------------------------------------------------- /fairseq/data/encoders/space_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import re 7 | 8 | from fairseq.data.encoders import register_tokenizer 9 | 10 | 11 | @register_tokenizer('space') 12 | class SpaceTokenizer(object): 13 | 14 | def __init__(self, source_lang=None, target_lang=None): 15 | self.space_tok = re.compile(r"\s+") 16 | 17 | def encode(self, x: str) -> str: 18 | return self.space_tok.sub(' ', x) 19 | 20 | def decode(self, x: str) -> str: 21 | return x 22 | -------------------------------------------------------------------------------- /fairseq/data/encoders/subword_nmt_bpe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from fairseq import file_utils 7 | from fairseq.data.encoders import register_bpe 8 | 9 | 10 | @register_bpe('subword_nmt') 11 | class SubwordNMTBPE(object): 12 | 13 | @staticmethod 14 | def add_args(parser): 15 | # fmt: off 16 | parser.add_argument('--bpe-codes', type=str, 17 | help='path to subword NMT BPE') 18 | parser.add_argument('--bpe-separator', default='@@', 19 | help='BPE separator') 20 | # fmt: on 21 | 22 | def __init__(self, args): 23 | if args.bpe_codes is None: 24 | raise ValueError('--bpe-codes is required for --bpe=subword_nmt') 25 | codes = file_utils.cached_path(args.bpe_codes) 26 | try: 27 | from subword_nmt import apply_bpe 28 | bpe_parser = apply_bpe.create_parser() 29 | bpe_args = bpe_parser.parse_args([ 30 | '--codes', codes, 31 | '--separator', args.bpe_separator, 32 | ]) 33 | self.bpe = apply_bpe.BPE( 34 | bpe_args.codes, 35 | bpe_args.merges, 36 | bpe_args.separator, 37 | None, 38 | bpe_args.glossaries, 39 | ) 40 | self.bpe_symbol = bpe_args.separator + ' ' 41 | except ImportError: 42 | raise ImportError('Please install subword_nmt with: pip install subword-nmt') 43 | 44 | def encode(self, x: str) -> str: 45 | return self.bpe.process_line(x) 46 | 47 | def decode(self, x: str) -> str: 48 | return (x + ' ').replace(self.bpe_symbol, '').rstrip() 49 | -------------------------------------------------------------------------------- /fairseq/data/encoders/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | from fairseq.data import encoders 8 | 9 | 10 | def get_whole_word_mask(args, dictionary): 11 | bpe = encoders.build_bpe(args) 12 | if bpe is not None: 13 | def is_beginning_of_word(i): 14 | if i < dictionary.nspecial: 15 | # special elements are always considered beginnings 16 | return True 17 | tok = dictionary[i] 18 | if tok.startswith('madeupword'): 19 | return True 20 | try: 21 | return bpe.is_beginning_of_word(tok) 22 | except ValueError: 23 | return True 24 | mask_whole_words = torch.ByteTensor(list( 25 | map(is_beginning_of_word, range(len(dictionary))) 26 | )) 27 | return mask_whole_words 28 | return None 29 | -------------------------------------------------------------------------------- /fairseq/data/fairseq_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | import torch.utils.data 8 | 9 | 10 | class EpochListening: 11 | """Mixin for receiving updates whenever the epoch increments.""" 12 | def set_epoch(self, epoch): 13 | """Will receive the updated epoch number at the beginning of the epoch. 14 | """ 15 | pass 16 | 17 | 18 | class FairseqDataset(torch.utils.data.Dataset, EpochListening): 19 | """A dataset that provides helpers for batching.""" 20 | 21 | def __getitem__(self, index): 22 | raise NotImplementedError 23 | 24 | def __len__(self): 25 | raise NotImplementedError 26 | 27 | def collater(self, samples): 28 | """Merge a list of samples to form a mini-batch. 
29 | 30 | Args: 31 | samples (List[dict]): samples to collate 32 | 33 | Returns: 34 | dict: a mini-batch suitable for forwarding with a Model 35 | """ 36 | raise NotImplementedError 37 | 38 | def num_tokens(self, index): 39 | """Return the number of tokens in a sample. This value is used to 40 | enforce ``--max-tokens`` during batching.""" 41 | raise NotImplementedError 42 | 43 | def size(self, index): 44 | """Return an example's size as a float or tuple. This value is used when 45 | filtering a dataset with ``--max-positions``.""" 46 | raise NotImplementedError 47 | 48 | def ordered_indices(self): 49 | """Return an ordered list of indices. Batches will be constructed based 50 | on this order.""" 51 | return np.arange(len(self)) 52 | 53 | @property 54 | def supports_prefetch(self): 55 | """Whether this dataset supports prefetching.""" 56 | return False 57 | 58 | def attr(self, attr: str, index: int): 59 | return getattr(self, attr, None) 60 | 61 | def prefetch(self, indices): 62 | """Prefetch the data required for this epoch.""" 63 | raise NotImplementedError 64 | 65 | 66 | class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening): 67 | """For datasets that need to be read sequentially, usually because the data 68 | is being streamed or otherwise can't be manipulated on a single machine. 69 | """ 70 | 71 | def __iter__(self): 72 | raise NotImplementedError 73 | -------------------------------------------------------------------------------- /fairseq/data/id_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | from . import FairseqDataset 9 | 10 | 11 | class IdDataset(FairseqDataset): 12 | 13 | def __getitem__(self, index): 14 | return index 15 | 16 | def __len__(self): 17 | return 0 18 | 19 | def collater(self, samples): 20 | return torch.tensor(samples) 21 | -------------------------------------------------------------------------------- /fairseq/data/legacy/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .masked_lm_dictionary import BertDictionary, MaskedLMDictionary 7 | from .block_pair_dataset import BlockPairDataset 8 | from .masked_lm_dataset import MaskedLMDataset 9 | 10 | __all__ = [ 11 | 'BertDictionary', 12 | 'BlockPairDataset', 13 | 'MaskedLMDataset', 14 | 'MaskedLMDictionary', 15 | ] 16 | -------------------------------------------------------------------------------- /fairseq/data/legacy/masked_lm_dictionary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from fairseq.data import Dictionary 7 | 8 | 9 | class MaskedLMDictionary(Dictionary): 10 | """ 11 | Dictionary for Masked Language Modelling tasks. This extends Dictionary by 12 | adding the mask symbol. 
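
    A minimal usage sketch (assuming the default special symbols below)::

        d = MaskedLMDictionary()
        assert d.mask() == d.index(d.mask_word)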
13 | """ 14 | def __init__( 15 | self, 16 | pad='', 17 | eos='', 18 | unk='', 19 | mask='', 20 | ): 21 | super().__init__(pad, eos, unk) 22 | self.mask_word = mask 23 | self.mask_index = self.add_symbol(mask) 24 | self.nspecial = len(self.symbols) 25 | 26 | def mask(self): 27 | """Helper to get index of mask symbol""" 28 | return self.mask_index 29 | 30 | 31 | class BertDictionary(MaskedLMDictionary): 32 | """ 33 | Dictionary for BERT task. This extends MaskedLMDictionary by adding support 34 | for cls and sep symbols. 35 | """ 36 | def __init__( 37 | self, 38 | pad='', 39 | eos='', 40 | unk='', 41 | mask='', 42 | cls='', 43 | sep='' 44 | ): 45 | super().__init__(pad, eos, unk, mask) 46 | self.cls_word = cls 47 | self.sep_word = sep 48 | self.cls_index = self.add_symbol(cls) 49 | self.sep_index = self.add_symbol(sep) 50 | self.nspecial = len(self.symbols) 51 | 52 | def cls(self): 53 | """Helper to get index of cls symbol""" 54 | return self.cls_index 55 | 56 | def sep(self): 57 | """Helper to get index of sep symbol""" 58 | return self.sep_index 59 | -------------------------------------------------------------------------------- /fairseq/data/list_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . import BaseWrapperDataset 7 | 8 | 9 | class ListDataset(BaseWrapperDataset): 10 | 11 | def __init__(self, dataset, sizes=None): 12 | super().__init__(dataset) 13 | self._sizes = sizes 14 | 15 | def __iter__(self): 16 | for x in self.dataset: 17 | yield x 18 | 19 | def collater(self, samples): 20 | return samples 21 | 22 | @property 23 | def sizes(self): 24 | return self._sizes 25 | 26 | def num_tokens(self, index): 27 | return self.sizes[index] 28 | 29 | def size(self, index): 30 | return self.sizes[index] 31 | 32 | def set_epoch(self, epoch): 33 | pass 34 | -------------------------------------------------------------------------------- /fairseq/data/lru_cache_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from functools import lru_cache 7 | 8 | from . import BaseWrapperDataset 9 | 10 | 11 | class LRUCacheDataset(BaseWrapperDataset): 12 | 13 | def __init__(self, dataset, token=None): 14 | super().__init__(dataset) 15 | 16 | @lru_cache(maxsize=8) 17 | def __getitem__(self, index): 18 | return self.dataset[index] 19 | 20 | @lru_cache(maxsize=8) 21 | def collater(self, samples): 22 | return self.dataset.collater(samples) 23 | -------------------------------------------------------------------------------- /fairseq/data/num_samples_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . 
import FairseqDataset 7 | 8 | 9 | class NumSamplesDataset(FairseqDataset): 10 | 11 | def __getitem__(self, index): 12 | return 1 13 | 14 | def __len__(self): 15 | return 0 16 | 17 | def collater(self, samples): 18 | return sum(samples) 19 | -------------------------------------------------------------------------------- /fairseq/data/numel_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from . import BaseWrapperDataset 10 | 11 | 12 | class NumelDataset(BaseWrapperDataset): 13 | 14 | def __init__(self, dataset, reduce=False): 15 | super().__init__(dataset) 16 | self.reduce = reduce 17 | 18 | def __getitem__(self, index): 19 | item = self.dataset[index] 20 | if torch.is_tensor(item): 21 | return torch.numel(item) 22 | else: 23 | return np.size(item) 24 | 25 | def __len__(self): 26 | return len(self.dataset) 27 | 28 | def collater(self, samples): 29 | if self.reduce: 30 | return sum(samples) 31 | else: 32 | return torch.tensor(samples) 33 | -------------------------------------------------------------------------------- /fairseq/data/offset_tokens_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . import BaseWrapperDataset 7 | 8 | 9 | class OffsetTokensDataset(BaseWrapperDataset): 10 | 11 | def __init__(self, dataset, offset): 12 | super().__init__(dataset) 13 | self.offset = offset 14 | 15 | def __getitem__(self, idx): 16 | return self.dataset[idx] + self.offset 17 | -------------------------------------------------------------------------------- /fairseq/data/pad_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from fairseq.data import data_utils 7 | 8 | from . import BaseWrapperDataset 9 | 10 | 11 | class PadDataset(BaseWrapperDataset): 12 | 13 | def __init__(self, dataset, pad_idx, left_pad): 14 | super().__init__(dataset) 15 | self.pad_idx = pad_idx 16 | self.left_pad = left_pad 17 | 18 | def collater(self, samples): 19 | return data_utils.collate_tokens(samples, self.pad_idx, left_pad=self.left_pad) 20 | 21 | 22 | class LeftPadDataset(PadDataset): 23 | 24 | def __init__(self, dataset, pad_idx): 25 | super().__init__(dataset, pad_idx, left_pad=True) 26 | 27 | 28 | class RightPadDataset(PadDataset): 29 | 30 | def __init__(self, dataset, pad_idx): 31 | super().__init__(dataset, pad_idx, left_pad=False) 32 | -------------------------------------------------------------------------------- /fairseq/data/prepend_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from . 
import BaseWrapperDataset 10 | 11 | 12 | class PrependDataset(BaseWrapperDataset): 13 | def __init__(self, dataset, prepend_getter, ensure_first_token_is=None): 14 | super().__init__(dataset) 15 | self.prepend_getter = prepend_getter 16 | self.ensure_first_token = ensure_first_token_is 17 | 18 | def __getitem__(self, idx): 19 | item = self.dataset[idx] 20 | is_tuple = isinstance(item, tuple) 21 | src = item[0] if is_tuple else item 22 | 23 | assert self.ensure_first_token is None or src[0] == self.ensure_first_token 24 | prepend_idx = self.prepend_getter(self.dataset, idx) 25 | assert isinstance(prepend_idx, int) 26 | src[0] = prepend_idx 27 | item = tuple((src,) + item[1:]) if is_tuple else src 28 | return item 29 | -------------------------------------------------------------------------------- /fairseq/data/prepend_token_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from . import BaseWrapperDataset 10 | 11 | 12 | class PrependTokenDataset(BaseWrapperDataset): 13 | 14 | def __init__(self, dataset, token=None): 15 | super().__init__(dataset) 16 | self.token = token 17 | if token is not None: 18 | self._sizes = np.array(dataset.sizes) + 1 19 | else: 20 | self._sizes = dataset.sizes 21 | 22 | def __getitem__(self, idx): 23 | item = self.dataset[idx] 24 | if self.token is not None: 25 | item = torch.cat([item.new([self.token]), item]) 26 | return item 27 | 28 | @property 29 | def sizes(self): 30 | return self._sizes 31 | 32 | def num_tokens(self, index): 33 | n = self.dataset.num_tokens(index) 34 | if self.token is not None: 35 | n += 1 36 | return n 37 | 38 | def size(self, index): 39 | n = self.dataset.size(index) 40 | if self.token is not None: 41 | n += 1 42 | return n 43 | -------------------------------------------------------------------------------- /fairseq/data/raw_label_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | from . import FairseqDataset 9 | 10 | 11 | class RawLabelDataset(FairseqDataset): 12 | 13 | def __init__(self, labels): 14 | super().__init__() 15 | self.labels = labels 16 | 17 | def __getitem__(self, index): 18 | return self.labels[index] 19 | 20 | def __len__(self): 21 | return len(self.labels) 22 | 23 | def collater(self, samples): 24 | return torch.tensor(samples) 25 | -------------------------------------------------------------------------------- /fairseq/data/replace_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . 
import BaseWrapperDataset 7 | 8 | 9 | class ReplaceDataset(BaseWrapperDataset): 10 | """Replaces tokens found in the dataset by a specified replacement token 11 | 12 | Args: 13 | dataset (~torch.utils.data.Dataset): dataset to replace tokens in 14 | replace_map(Dictionary[int,int]): map of token to replace -> replacement token 15 | offsets (List[int]): do not replace tokens before (from left if pos, right if neg) this offset. should be 16 | as many as the number of objects returned by the underlying dataset __getitem__ method. 17 | """ 18 | 19 | def __init__(self, dataset, replace_map, offsets): 20 | super().__init__(dataset) 21 | assert len(replace_map) > 0 22 | self.replace_map = replace_map 23 | self.offsets = offsets 24 | 25 | def __getitem__(self, index): 26 | item = self.dataset[index] 27 | is_tuple = isinstance(item, tuple) 28 | srcs = item if is_tuple else [item] 29 | 30 | for offset, src in zip(self.offsets, srcs): 31 | for k, v in self.replace_map.items(): 32 | src_off = src[offset:] if offset >= 0 else src[:offset] 33 | src_off.masked_fill_(src_off == k, v) 34 | 35 | item = srcs if is_tuple else srcs[0] 36 | return item 37 | -------------------------------------------------------------------------------- /fairseq/data/roll_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | from . import BaseWrapperDataset 9 | 10 | 11 | class RollDataset(BaseWrapperDataset): 12 | 13 | def __init__(self, dataset, shifts): 14 | super().__init__(dataset) 15 | self.shifts = shifts 16 | 17 | def __getitem__(self, index): 18 | item = self.dataset[index] 19 | return torch.roll(item, self.shifts) 20 | -------------------------------------------------------------------------------- /fairseq/data/sharded_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import itertools 7 | import os 8 | import random 9 | 10 | from . import BaseWrapperDataset 11 | from fairseq.data import data_utils 12 | 13 | 14 | class ShardedDataset(BaseWrapperDataset): 15 | """A :class:`~fairseq.data.FairseqDataset` wrapper that appends/prepends/strips EOS. 16 | 17 | Loads a dataset which has been sharded into multiple files. 
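    Shards are expected at ``path/shard0``, ``path/shard1``, ...; a random
    shard is selected for each training epoch, so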
each shard is only loaded for each specific epoch 18 | 19 | """ 20 | 21 | def __init__( 22 | self, 23 | dictionary, 24 | dataset_impl: str, 25 | path: str, 26 | split: str, 27 | epoch: int, 28 | name: str = None, 29 | combine: bool = False, 30 | seed: int = 0, 31 | ): 32 | self._name = name if name is not None else os.path.basename(path) 33 | num_shards = 0 34 | for i in itertools.count(): 35 | if not os.path.exists(os.path.join(path, "shard" + str(i))): 36 | break 37 | num_shards += 1 38 | 39 | if num_shards > 0 and split == "train": 40 | random.seed(seed ^ epoch) 41 | shard = random.randint(0, num_shards - 1) 42 | split_path = os.path.join(path, "shard" + str(shard), split) 43 | else: 44 | split_path = os.path.join(path, split) 45 | if os.path.isdir(split_path): 46 | split_path = os.path.join(split_path, split) 47 | 48 | dataset = data_utils.load_indexed_dataset( 49 | split_path, dictionary, dataset_impl, combine=combine 50 | ) 51 | if dataset is None: 52 | raise FileNotFoundError( 53 | "Dataset not found: {} ({})".format(split, split_path) 54 | ) 55 | 56 | super().__init__(dataset) 57 | 58 | @property 59 | def name(self): 60 | return self._name 61 | -------------------------------------------------------------------------------- /fairseq/data/sort_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | 8 | from . import BaseWrapperDataset 9 | 10 | 11 | class SortDataset(BaseWrapperDataset): 12 | 13 | def __init__(self, dataset, sort_order): 14 | super().__init__(dataset) 15 | if not isinstance(sort_order, (list, tuple)): 16 | sort_order = [sort_order] 17 | self.sort_order = sort_order 18 | 19 | assert all(len(so) == len(dataset) for so in sort_order) 20 | 21 | def ordered_indices(self): 22 | return np.lexsort(self.sort_order) 23 | -------------------------------------------------------------------------------- /fairseq/data/strip_token_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . import BaseWrapperDataset 7 | 8 | 9 | class StripTokenDataset(BaseWrapperDataset): 10 | 11 | def __init__(self, dataset, id_to_strip): 12 | super().__init__(dataset) 13 | self.id_to_strip = id_to_strip 14 | 15 | def __getitem__(self, index): 16 | item = self.dataset[index] 17 | return item[item.ne(self.id_to_strip)] 18 | -------------------------------------------------------------------------------- /fairseq/data/subsample_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | 8 | from . import BaseWrapperDataset 9 | 10 | 11 | class SubsampleDataset(BaseWrapperDataset): 12 | """Subsamples a given dataset by a specified ratio. Subsampling is done on the number of examples 13 | 14 | Args: 15 | dataset (~torch.utils.data.Dataset): dataset to subsample 16 | size_ratio(float): the ratio to subsample to. 
must be between 0 and 1 (exclusive) 17 | """ 18 | 19 | def __init__(self, dataset, size_ratio): 20 | super().__init__(dataset) 21 | assert size_ratio < 1 22 | self.actual_size = np.ceil(len(dataset) * size_ratio).astype(int) 23 | self.indices = np.random.choice( 24 | list(range(len(self.dataset))), self.actual_size, replace=False 25 | ) 26 | print( 27 | "subsampled dataset from {} to {} (ratio={})".format( 28 | len(self.dataset), self.actual_size, size_ratio 29 | ) 30 | ) 31 | 32 | def __getitem__(self, index): 33 | return self.dataset[self.indices[index]] 34 | 35 | def __len__(self): 36 | return self.actual_size 37 | 38 | def collater(self, samples): 39 | return self.dataset.collater(samples) 40 | 41 | @property 42 | def sizes(self): 43 | return self.dataset.sizes[self.indices] 44 | 45 | @property 46 | def name(self): 47 | return self.dataset.name 48 | 49 | def num_tokens(self, index): 50 | return self.dataset.num_tokens(self.indices[index]) 51 | 52 | def size(self, index): 53 | return self.dataset.size(self.indices[index]) 54 | 55 | def ordered_indices(self): 56 | """Return an ordered list of indices. Batches will be constructed based 57 | on this order.""" 58 | if self.shuffle: 59 | order = [np.random.permutation(len(self))] 60 | else: 61 | order = [np.arange(len(self))] 62 | order.append(self.sizes) 63 | return np.lexsort(order) 64 | 65 | def prefetch(self, indices): 66 | self.dataset.prefetch(self.indices[indices]) 67 | -------------------------------------------------------------------------------- /fairseq/data/token_block_utils_fast.cpython-37m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuhinjubcse/SimileGeneration-EMNLP2020/cb40a16409912e1eefb2204e6b1ab953fbe8bdc6/fairseq/data/token_block_utils_fast.cpython-37m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /fairseq/data/truncate_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | 8 | from . 
import BaseWrapperDataset 9 | 10 | 11 | class TruncateDataset(BaseWrapperDataset): 12 | 13 | def __init__(self, dataset, truncation_length): 14 | super().__init__(dataset) 15 | assert truncation_length is not None 16 | self.truncation_length = truncation_length 17 | self.dataset = dataset 18 | 19 | def __getitem__(self, index): 20 | item = self.dataset[index] 21 | item_len = item.size(0) 22 | if item_len > self.truncation_length: 23 | item = item[:self.truncation_length] 24 | return item 25 | 26 | @property 27 | def sizes(self): 28 | return np.minimum(self.dataset.sizes, self.truncation_length) 29 | 30 | def __len__(self): 31 | return len(self.dataset) 32 | -------------------------------------------------------------------------------- /fairseq/libbleu.cpython-37m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuhinjubcse/SimileGeneration-EMNLP2020/cb40a16409912e1eefb2204e6b1ab953fbe8bdc6/fairseq/libbleu.cpython-37m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /fairseq/libnat.cpython-37m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuhinjubcse/SimileGeneration-EMNLP2020/cb40a16409912e1eefb2204e6b1ab953fbe8bdc6/fairseq/libnat.cpython-37m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /fairseq/models/bart/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .hub_interface import * # noqa 7 | from .model import * # noqa 8 | -------------------------------------------------------------------------------- /fairseq/models/composite_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from fairseq.models import FairseqEncoder 7 | 8 | 9 | class CompositeEncoder(FairseqEncoder): 10 | """ 11 | A wrapper around a dictionary of :class:`FairseqEncoder` objects. 12 | 13 | We run forward on each encoder and return a dictionary of outputs. The first 14 | encoder's dictionary is used for initialization. 15 | 16 | Args: 17 | encoders (dict): a dictionary of :class:`FairseqEncoder` objects. 
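
    A minimal construction sketch (``text_enc`` and ``audio_enc`` are
    hypothetical :class:`FairseqEncoder` instances built over the same
    dictionary)::

        encoder = CompositeEncoder({'text': text_enc, 'audio': audio_enc})
        out = encoder(src_tokens, src_lengths)   # {'text': ..., 'audio': ...}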
18 | """ 19 | 20 | def __init__(self, encoders): 21 | super().__init__(next(iter(encoders.values())).dictionary) 22 | self.encoders = encoders 23 | for key in self.encoders: 24 | self.add_module(key, self.encoders[key]) 25 | 26 | def forward(self, src_tokens, src_lengths): 27 | """ 28 | Args: 29 | src_tokens (LongTensor): tokens in the source language of shape 30 | `(batch, src_len)` 31 | src_lengths (LongTensor): lengths of each source sentence of shape 32 | `(batch)` 33 | 34 | Returns: 35 | dict: 36 | the outputs from each Encoder 37 | """ 38 | encoder_out = {} 39 | for key in self.encoders: 40 | encoder_out[key] = self.encoders[key](src_tokens, src_lengths) 41 | return encoder_out 42 | 43 | def reorder_encoder_out(self, encoder_out, new_order): 44 | """Reorder encoder output according to new_order.""" 45 | for key in self.encoders: 46 | encoder_out[key] = self.encoders[key].reorder_encoder_out(encoder_out[key], new_order) 47 | return encoder_out 48 | 49 | def max_positions(self): 50 | return min([self.encoders[key].max_positions() for key in self.encoders]) 51 | 52 | def upgrade_state_dict(self, state_dict): 53 | for key in self.encoders: 54 | self.encoders[key].upgrade_state_dict(state_dict) 55 | return state_dict 56 | -------------------------------------------------------------------------------- /fairseq/models/distributed_fairseq_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import inspect 7 | 8 | import torch.nn as nn 9 | 10 | from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel 11 | from fairseq.models import BaseFairseqModel 12 | 13 | 14 | def DistributedFairseqModel(args, model): 15 | """ 16 | Wrap a *model* to support distributed data parallel training. 17 | 18 | This is similar to the built-in DistributedDataParallel, but allows 19 | additional configuration of the DistributedDataParallel class to 20 | use, and also provides easier access to the wrapped model by 21 | forwarding requests for missing attributes to the wrapped model. 
22 | 23 | Args: 24 | args (argparse.Namespace): fairseq args 25 | model (BaseFairseqModel): model to wrap 26 | """ 27 | # determine which DDP class to extend 28 | assert isinstance(model, nn.Module) 29 | if args.ddp_backend == 'c10d': 30 | ddp_class = nn.parallel.DistributedDataParallel 31 | init_kwargs = dict( 32 | module=model, 33 | device_ids=[args.device_id], 34 | output_device=args.device_id, 35 | broadcast_buffers=False, 36 | bucket_cap_mb=args.bucket_cap_mb, 37 | ) 38 | # Maintain backward compatibility 39 | if 'check_reduction' in inspect.getargspec(ddp_class)[0]: 40 | init_kwargs['check_reduction'] = True 41 | if 'find_unused_parameters' in inspect.getargspec(ddp_class)[0]: 42 | init_kwargs['find_unused_parameters'] = args.find_unused_parameters 43 | elif args.ddp_backend == 'no_c10d': 44 | ddp_class = LegacyDistributedDataParallel 45 | init_kwargs = dict( 46 | module=model, 47 | world_size=args.distributed_world_size, 48 | buffer_size=2**28, 49 | ) 50 | else: 51 | raise ValueError('Unknown --ddp-backend: ' + args.ddp_backend) 52 | 53 | class _DistributedFairseqModel(ddp_class): 54 | """Extend DistributedDataParallel to check for missing 55 | attributes in the wrapped module.""" 56 | 57 | def __init__(self, *args, **kwargs): 58 | super().__init__(*args, **kwargs) 59 | 60 | def __getattr__(self, name): 61 | wrapped_module = super().__getattr__('module') 62 | if hasattr(wrapped_module, name): 63 | return getattr(wrapped_module, name) 64 | return super().__getattr__(name) 65 | 66 | return _DistributedFairseqModel(**init_kwargs) 67 | -------------------------------------------------------------------------------- /fairseq/models/fairseq_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch.nn as nn 7 | 8 | 9 | class FairseqEncoder(nn.Module): 10 | """Base class for encoders.""" 11 | 12 | def __init__(self, dictionary): 13 | super().__init__() 14 | self.dictionary = dictionary 15 | 16 | def forward(self, src_tokens, src_lengths=None, **kwargs): 17 | """ 18 | Args: 19 | src_tokens (LongTensor): tokens in the source language of shape 20 | `(batch, src_len)` 21 | src_lengths (LongTensor): lengths of each source sentence of shape 22 | `(batch)` 23 | """ 24 | raise NotImplementedError 25 | 26 | def reorder_encoder_out(self, encoder_out, new_order): 27 | """ 28 | Reorder encoder output according to `new_order`. 29 | 30 | Args: 31 | encoder_out: output from the ``forward()`` method 32 | new_order (LongTensor): desired order 33 | 34 | Returns: 35 | `encoder_out` rearranged according to `new_order` 36 | """ 37 | raise NotImplementedError 38 | 39 | def max_positions(self): 40 | """Maximum input length supported by the encoder.""" 41 | return 1e6 # an arbitrary large number 42 | 43 | def upgrade_state_dict(self, state_dict): 44 | """Upgrade a (possibly old) state dict for new versions of fairseq.""" 45 | return state_dict 46 | -------------------------------------------------------------------------------- /fairseq/models/model_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from typing import List, Optional 7 | 8 | import torch 9 | from torch import Tensor 10 | 11 | 12 | @torch.jit.script 13 | def script_skip_tensor_list(x: List[Tensor], mask): 14 | res = [xi[mask] if xi.size(0) == mask.size(0) else xi[:, mask] for xi in x] 15 | outputs = [] 16 | for i, t in enumerate(res): 17 | if t.numel() != 0: 18 | outputs.append(t) 19 | else: 20 | outputs.append(x[i]) 21 | return outputs 22 | 23 | 24 | @torch.jit.script 25 | def script_skip_tensor(x: Tensor, mask): 26 | # None case 27 | if x.size(0) == 0: 28 | return x 29 | res = x[mask] if x.size(0) == mask.size(0) else x[:, mask] 30 | if res.numel() == 0: 31 | return x 32 | else: 33 | return res 34 | 35 | 36 | @torch.jit.script 37 | def expand_2d_or_3d_tensor(x, trg_dim: int, padding_idx: int): 38 | """ 39 | Expand 2D/3D tensor on dim=1 40 | """ 41 | if x is None: 42 | return None 43 | 44 | assert x.dim() == 2 or x.dim() == 3 45 | assert trg_dim >= x.size(1), (trg_dim, x.size()) 46 | if trg_dim == x.size(1): 47 | return x 48 | 49 | dims = [x.size(0), trg_dim - x.size(1)] 50 | if x.dim() == 3: 51 | dims.append(x.size(2)) 52 | x = torch.cat([x, torch.zeros(dims).to(x).fill_(padding_idx)], 1) 53 | 54 | return x 55 | 56 | 57 | @torch.jit.script 58 | def coalesce(x: Optional[Tensor], y: Tensor) -> Tensor: 59 | return x if x is not None else y 60 | 61 | 62 | @torch.jit.script 63 | def fill_tensors(x: Optional[Tensor], mask, y: Optional[Tensor], padding_idx: int) -> Optional[Tensor]: 64 | """ 65 | Filling tensor x with y at masked positions (dim=0). 66 | """ 67 | if x is None or x.size()[0] == 0 or y is None: 68 | return x 69 | assert x.dim() == y.dim() and mask.size(0) == x.size(0) 70 | assert x.dim() == 2 or (x.dim() == 3 and x.size(2) == y.size(2)) 71 | 72 | n_selected = mask.sum() 73 | if n_selected == 0: 74 | return x 75 | assert n_selected == y.size(0) 76 | if n_selected == x.size(0): 77 | return y 78 | 79 | if x.size(1) < y.size(1): 80 | x = expand_2d_or_3d_tensor(x, y.size(1), padding_idx) 81 | x[mask] = y 82 | elif x.size(1) > y.size(1): 83 | x[mask] = torch.tensor(padding_idx).type_as(x) 84 | if x.dim() == 2: 85 | x[mask, :y.size(1)] = y 86 | else: 87 | x[mask, :y.size(1), :] = y 88 | else: 89 | x[mask] = y 90 | return x 91 | -------------------------------------------------------------------------------- /fairseq/models/nat/__init__.py: -------------------------------------------------------------------------------- 1 | from .fairseq_nat_model import * 2 | from .nonautoregressive_transformer import * 3 | from .nat_crf_transformer import * 4 | from .iterative_nonautoregressive_transformer import * 5 | from .cmlm_transformer import * 6 | from .levenshtein_transformer import * 7 | from .insertion_transformer import * 8 | -------------------------------------------------------------------------------- /fairseq/models/roberta/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .hub_interface import * # noqa 7 | from .model import * # noqa 8 | -------------------------------------------------------------------------------- /fairseq/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .adaptive_input import AdaptiveInput 7 | from .adaptive_softmax import AdaptiveSoftmax 8 | from .beamable_mm import BeamableMM 9 | from .character_token_embedder import CharacterTokenEmbedder 10 | from .conv_tbc import ConvTBC 11 | from .downsampled_multihead_attention import DownsampledMultiHeadAttention 12 | from .dynamic_convolution import DynamicConv, DynamicConv1dTBC 13 | from .dynamic_crf_layer import DynamicCRF 14 | from .gelu import gelu, gelu_accurate 15 | from .grad_multiply import GradMultiply 16 | from .highway import Highway 17 | from .layer_norm import LayerNorm 18 | from .learned_positional_embedding import LearnedPositionalEmbedding 19 | from .lightweight_convolution import LightweightConv, LightweightConv1dTBC 20 | from .linearized_convolution import LinearizedConvolution 21 | from .logsumexp_moe import LogSumExpMoE 22 | from .mean_pool_gating_network import MeanPoolGatingNetwork 23 | from .multihead_attention import MultiheadAttention 24 | from .positional_embedding import PositionalEmbedding 25 | from .scalar_bias import ScalarBias 26 | from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding 27 | from .transformer_sentence_encoder_layer import TransformerSentenceEncoderLayer 28 | from .transformer_sentence_encoder import TransformerSentenceEncoder 29 | from .unfold import unfold1d 30 | from .transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer 31 | from .vggblock import VGGBlock 32 | 33 | __all__ = [ 34 | 'AdaptiveInput', 35 | 'AdaptiveSoftmax', 36 | 'BeamableMM', 37 | 'CharacterTokenEmbedder', 38 | 'ConvTBC', 39 | 'DownsampledMultiHeadAttention', 40 | 'DynamicConv1dTBC', 41 | 'DynamicConv', 42 | 'DynamicCRF', 43 | 'gelu', 44 | 'gelu_accurate', 45 | 'GradMultiply', 46 | 'Highway', 47 | 'LayerNorm', 48 | 'LearnedPositionalEmbedding', 49 | 'LightweightConv1dTBC', 50 | 'LightweightConv', 51 | 'LinearizedConvolution', 52 | 'LogSumExpMoE', 53 | 'MeanPoolGatingNetwork', 54 | 'MultiheadAttention', 55 | 'PositionalEmbedding', 56 | 'ScalarBias', 57 | 'SinusoidalPositionalEmbedding', 58 | 'TransformerSentenceEncoderLayer', 59 | 'TransformerSentenceEncoder', 60 | 'TransformerDecoderLayer', 61 | 'TransformerEncoderLayer', 62 | 'VGGBlock', 63 | 'unfold1d', 64 | ] 65 | -------------------------------------------------------------------------------- /fairseq/modules/adaptive_input.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
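#
# Usage sketch (illustrative values only):
#     emb = AdaptiveInput(vocab_size=10000, padding_idx=0, initial_dim=512,
#                         factor=4.0, output_dim=512, cutoff=[2000, 6000])
#     tokens = torch.randint(0, 10000, (8, 16))
#     out = emb(tokens)   # FloatTensor of shape (8, 16, 512)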
5 | 6 | 7 | import torch 8 | from torch import nn 9 | 10 | from typing import List 11 | 12 | 13 | class AdaptiveInput(nn.Module): 14 | 15 | def __init__( 16 | self, 17 | vocab_size: int, 18 | padding_idx: int, 19 | initial_dim: int, 20 | factor: float, 21 | output_dim: int, 22 | cutoff: List[int], 23 | ): 24 | super().__init__() 25 | 26 | if vocab_size > cutoff[-1]: 27 | cutoff = cutoff + [vocab_size] 28 | else: 29 | assert vocab_size == cutoff[ 30 | -1], 'cannot specify cutoff larger than vocab size' 31 | 32 | self.cutoff = cutoff 33 | self.embedding_dim = output_dim 34 | self.padding_idx = padding_idx 35 | 36 | self.embeddings = nn.ModuleList() 37 | for i in range(len(self.cutoff)): 38 | prev = self.cutoff[i - 1] if i > 0 else 0 39 | size = self.cutoff[i] - prev 40 | dim = int(initial_dim // (factor ** i)) 41 | seq = nn.Sequential( 42 | nn.Embedding(size, dim, padding_idx), 43 | nn.Linear(dim, output_dim, bias=False) 44 | ) 45 | self.embeddings.append(seq) 46 | 47 | def init_weights(m): 48 | if isinstance(m, nn.Embedding): 49 | nn.init.normal_(m.weight, mean=0, std=m.weight.shape[1] ** -0.5) 50 | nn.init.constant_(m.weight[padding_idx], 0) 51 | elif hasattr(m, 'weight'): 52 | nn.init.xavier_uniform_(m.weight) 53 | 54 | self.apply(init_weights) 55 | 56 | self.register_buffer('_float_tensor', torch.FloatTensor(1)) 57 | 58 | def weights_for_band(self, band: int): 59 | return self.embeddings[band][0].weight, self.embeddings[band][1].weight 60 | 61 | def forward(self, input: torch.Tensor): 62 | result = self._float_tensor.new(input.shape + (self.embedding_dim,)) 63 | for i in range(len(self.cutoff)): 64 | mask = input.lt(self.cutoff[i]) 65 | if i > 0: 66 | mask.mul_(input.ge(self.cutoff[i - 1])) 67 | chunk_input = input[mask] - self.cutoff[i - 1] 68 | else: 69 | chunk_input = input[mask] 70 | if mask.any(): 71 | result[mask] = self.embeddings[i](chunk_input) 72 | return result 73 | -------------------------------------------------------------------------------- /fairseq/modules/beamable_mm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class BeamableMM(nn.Module): 11 | """This module provides an optimized MM for beam decoding with attention. 12 | 13 | It leverage the fact that the source-side of the input is replicated beam 14 | times and the target-side of the input is of width one. This layer speeds up 15 | inference by replacing the inputs {(bsz x 1 x nhu), (bsz x sz2 x nhu)} 16 | with smaller inputs {(bsz/beam x beam x nhu), (bsz/beam x sz2 x nhu)}. 
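    The optimized path is taken only at inference time (``self.training`` is
    False), when a beam size has been set and the target-side input carries a
    single time step; otherwise a regular batched matmul is performed.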
17 | """ 18 | def __init__(self, beam_size=None): 19 | super(BeamableMM, self).__init__() 20 | self.beam_size = beam_size 21 | 22 | def forward(self, input1, input2): 23 | if ( 24 | not self.training and # test mode 25 | self.beam_size is not None and # beam size is set 26 | input1.dim() == 3 and # only support batched input 27 | input1.size(1) == 1 # single time step update 28 | ): 29 | bsz, beam = input1.size(0), self.beam_size 30 | 31 | # bsz x 1 x nhu --> bsz/beam x beam x nhu 32 | input1 = input1[:, 0, :].unfold(0, beam, beam).transpose(2, 1) 33 | 34 | # bsz x sz2 x nhu --> bsz/beam x sz2 x nhu 35 | input2 = input2.unfold(0, beam, beam)[:, :, :, 0] 36 | 37 | # use non batched operation if bsz = beam 38 | if input1.size(0) == 1: 39 | output = torch.mm(input1[0, :, :], input2[0, :, :]) 40 | else: 41 | output = input1.bmm(input2) 42 | return output.view(bsz, 1, -1) 43 | else: 44 | return input1.bmm(input2) 45 | 46 | def set_beam_size(self, beam_size): 47 | self.beam_size = beam_size 48 | -------------------------------------------------------------------------------- /fairseq/modules/conv_tbc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | from torch.nn.modules.utils import _single 8 | 9 | 10 | class ConvTBC(torch.nn.Module): 11 | """1D convolution over an input of shape (time x batch x channel) 12 | 13 | The implementation uses gemm to perform the convolution. This implementation 14 | is faster than cuDNN for small kernel sizes. 15 | """ 16 | def __init__(self, in_channels, out_channels, kernel_size, padding=0): 17 | super(ConvTBC, self).__init__() 18 | self.in_channels = in_channels 19 | self.out_channels = out_channels 20 | self.kernel_size = _single(kernel_size) 21 | self.padding = _single(padding) 22 | 23 | self.weight = torch.nn.Parameter(torch.Tensor( 24 | self.kernel_size[0], in_channels, out_channels)) 25 | self.bias = torch.nn.Parameter(torch.Tensor(out_channels)) 26 | 27 | def forward(self, input): 28 | return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding[0]) 29 | 30 | def __repr__(self): 31 | s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}' 32 | ', padding={padding}') 33 | if self.bias is None: 34 | s += ', bias=False' 35 | s += ')' 36 | return s.format(name=self.__class__.__name__, **self.__dict__) 37 | -------------------------------------------------------------------------------- /fairseq/modules/dynamicconv_layer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .dynamicconv_layer import DynamicconvLayer # noqa 7 | -------------------------------------------------------------------------------- /fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
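 *
 * The wrappers in this file check that inputs are contiguous CUDA tensors
 * (CHECK_INPUT) before dispatching to the dynamicconv CUDA kernels.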
6 | */ 7 | 8 | #include 9 | #include 10 | 11 | std::vector dynamicconv_cuda_forward( 12 | at::Tensor input, 13 | at::Tensor filters, 14 | int padding_l); 15 | 16 | std::vector dynamicconv_cuda_backward( 17 | at::Tensor gradOutput, 18 | int padding_l, 19 | at::Tensor input, 20 | at::Tensor filters); 21 | 22 | 23 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 24 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 25 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 26 | 27 | std::vector dynamicconv_forward( 28 | at::Tensor input, 29 | at::Tensor filters, 30 | int padding_l) { 31 | 32 | CHECK_INPUT(input); 33 | CHECK_INPUT(filters); 34 | 35 | return dynamicconv_cuda_forward(input, filters, 36 | padding_l); 37 | } 38 | 39 | std::vector dynamicconv_backward( 40 | at::Tensor gradOutput, 41 | int padding_l, 42 | at::Tensor input, 43 | at::Tensor filters) { 44 | 45 | CHECK_INPUT(gradOutput); 46 | CHECK_INPUT(input); 47 | CHECK_INPUT(filters); 48 | 49 | return dynamicconv_cuda_backward(gradOutput, padding_l, 50 | input, filters); 51 | } 52 | 53 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 54 | m.def("forward", &dynamicconv_forward, "dynamicconv forward (CUDA)"); 55 | m.def("backward", &dynamicconv_backward, "dynamicconv backward (CUDA)"); 56 | } 57 | -------------------------------------------------------------------------------- /fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cuh: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | #include 9 | #include 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | #include 23 | #include 24 | #include 25 | 26 | #define SHFL_MASK 0xffffffff 27 | 28 | template 29 | __global__ 30 | void dynamicconv_forward_kernel(const scalar_t* input, 31 | const scalar_t* weight, 32 | int minibatch, 33 | int sequenceLength, 34 | int numFeatures, 35 | int numFiltersInBlock, 36 | int numHeads, 37 | scalar_t* output); 38 | 39 | template 40 | __global__ 41 | void dynamicconv_backward_kernel( 42 | const scalar_t* gradOutput, // B * C * T 43 | const scalar_t* input, // B * C * T 44 | const scalar_t* weight, 45 | int minibatch, 46 | int sequenceLength, 47 | int numFeatures, 48 | int numFiltersInBlock, 49 | int numHeads, 50 | scalar_t* gradWeight, 51 | scalar_t* gradInput); // B * H * k * T 52 | -------------------------------------------------------------------------------- /fairseq/modules/dynamicconv_layer/dynamiconv_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | std::vector dynamicconv_cpu_forward( 5 | float* input, 6 | float* filters, 7 | int padding_l); 8 | 9 | std::vector dynamicconv_cpu_backward( 10 | float* gradOutput, 11 | int padding_l, 12 | float* input, 13 | float* filters); 14 | 15 | std::vector dynamicconv_forward( 16 | float* input, 17 | float* filters, 18 | int padding_l) { 19 | 20 | return dynamicconv_cpu_forward(input, filters, padding_l); 21 | } 22 | 23 | std::vector dynamicconv_backward( 24 | float* gradOutput, 25 | int padding_l, 26 | float* input, 27 | float* filters) { 28 | 29 | return dynamicconv_cpu_backward(gradOutput, padding_l, input, filters); 30 | } 31 | 32 | 
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 33 | m.def("forward", &dynamicconv_forward, "dynamicconv forward (CPU)"); 34 | m.def("backward", &dynamicconv_backward, "dynamicconv backward (CPU)"); 35 | } 36 | -------------------------------------------------------------------------------- /fairseq/modules/dynamicconv_layer/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from setuptools import setup 8 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 9 | 10 | setup( 11 | name='dynamicconv_layer', 12 | ext_modules=[ 13 | CUDAExtension( 14 | name='dynamicconv_cuda', 15 | sources=[ 16 | 'dynamicconv_cuda.cpp', 17 | 'dynamicconv_cuda_kernel.cu', 18 | ], 19 | ), 20 | ], 21 | cmdclass={ 22 | 'build_ext': BuildExtension 23 | }) 24 | -------------------------------------------------------------------------------- /fairseq/modules/gelu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | """ 6 | See "Gaussian Error Linear Units (GELUs)" by Dan Hendrycks and Kevin Gimpel with 7 | the corresponding GitHub repo: https://github.com/hendrycks/GELUs 8 | """ 9 | 10 | import math 11 | 12 | import torch 13 | 14 | 15 | def gelu_accurate(x): 16 | if not hasattr(gelu_accurate, "_a"): 17 | gelu_accurate._a = math.sqrt(2 / math.pi) 18 | return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) 19 | 20 | 21 | def gelu(x: torch.Tensor) -> torch.Tensor: 22 | if hasattr(torch.nn.functional, 'gelu'): 23 | return torch.nn.functional.gelu(x.float()).type_as(x) 24 | else: 25 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 26 | -------------------------------------------------------------------------------- /fairseq/modules/grad_multiply.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | 9 | class GradMultiply(torch.autograd.Function): 10 | @staticmethod 11 | def forward(ctx, x, scale): 12 | ctx.scale = scale 13 | res = x.new(x) 14 | return res 15 | 16 | @staticmethod 17 | def backward(ctx, grad): 18 | return grad * ctx.scale, None 19 | -------------------------------------------------------------------------------- /fairseq/modules/highway.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | from torch import nn 9 | 10 | 11 | class Highway(torch.nn.Module): 12 | """ 13 | A `Highway layer `_. 14 | Adopted from the AllenNLP implementation. 
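    Each layer projects the input to twice its width and splits the result into
    a transform half and a gate half; the output is g * x + (1 - g) * relu(transform),
    where g = sigmoid(gate).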
15 | """ 16 | 17 | def __init__( 18 | self, 19 | input_dim: int, 20 | num_layers: int = 1 21 | ): 22 | super(Highway, self).__init__() 23 | self.input_dim = input_dim 24 | self.layers = nn.ModuleList([nn.Linear(input_dim, input_dim * 2) 25 | for _ in range(num_layers)]) 26 | self.activation = nn.ReLU() 27 | 28 | self.reset_parameters() 29 | 30 | def reset_parameters(self): 31 | for layer in self.layers: 32 | # As per comment in AllenNLP: 33 | # We should bias the highway layer to just carry its input forward. We do that by 34 | # setting the bias on `B(x)` to be positive, because that means `g` will be biased to 35 | # be high, so we will carry the input forward. The bias on `B(x)` is the second half 36 | # of the bias vector in each Linear layer. 37 | nn.init.constant_(layer.bias[self.input_dim:], 1) 38 | 39 | nn.init.constant_(layer.bias[:self.input_dim], 0) 40 | nn.init.xavier_normal_(layer.weight) 41 | 42 | def forward( 43 | self, 44 | x: torch.Tensor 45 | ): 46 | for layer in self.layers: 47 | projection = layer(x) 48 | proj_x, gate = projection.chunk(2, dim=-1) 49 | proj_x = self.activation(proj_x) 50 | gate = torch.sigmoid(gate) 51 | x = gate * x + (gate.new_tensor([1]) - gate) * proj_x 52 | return x 53 | -------------------------------------------------------------------------------- /fairseq/modules/layer_norm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | 9 | def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False): 10 | if not export and torch.cuda.is_available(): 11 | try: 12 | from apex.normalization import FusedLayerNorm 13 | return FusedLayerNorm(normalized_shape, eps, elementwise_affine) 14 | except ImportError: 15 | pass 16 | return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) 17 | -------------------------------------------------------------------------------- /fairseq/modules/learned_positional_embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch.nn as nn 7 | 8 | from fairseq import utils1 as utils 9 | 10 | 11 | class LearnedPositionalEmbedding(nn.Embedding): 12 | """ 13 | This module learns positional embeddings up to a fixed maximum size. 14 | Padding ids are ignored by either offsetting based on padding_idx 15 | or by setting padding_idx to None and ensuring that the appropriate 16 | position ids are passed to the forward function. 17 | """ 18 | 19 | def __init__( 20 | self, 21 | num_embeddings: int, 22 | embedding_dim: int, 23 | padding_idx: int, 24 | ): 25 | super().__init__(num_embeddings, embedding_dim, padding_idx) 26 | self.onnx_trace = False 27 | 28 | def forward(self, input, incremental_state=None, positions=None): 29 | """Input is expected to be of size [bsz x seqlen].""" 30 | assert ( 31 | (positions is None) or (self.padding_idx is None) 32 | ), "If positions is pre-computed then padding_idx should not be set." 
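        # Positions are either supplied by the caller or derived from the input
        # tokens below; incremental decoding fills in a single position per step.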
33 | 34 | if positions is None: 35 | if incremental_state is not None: 36 | # positions is the same for every token when decoding a single step 37 | # Without the int() cast, it doesn't work in some cases when exporting to ONNX 38 | positions = input.data.new(1, 1).fill_(int(self.padding_idx + input.size(1))) 39 | else: 40 | positions = utils.make_positions( 41 | input, self.padding_idx, onnx_trace=self.onnx_trace, 42 | ) 43 | return super().forward(positions) 44 | 45 | def max_positions(self): 46 | """Maximum number of supported positions.""" 47 | if self.padding_idx is not None: 48 | return self.num_embeddings - self.padding_idx - 1 49 | else: 50 | return self.num_embeddings 51 | -------------------------------------------------------------------------------- /fairseq/modules/lightconv_layer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .lightconv_layer import LightconvLayer # noqa 7 | -------------------------------------------------------------------------------- /fairseq/modules/lightconv_layer/lightconv_cuda.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | #include 9 | #include 10 | 11 | std::vector lightconv_cuda_forward( 12 | at::Tensor input, 13 | at::Tensor filters, 14 | int padding_l); 15 | 16 | std::vector lightconv_cuda_backward( 17 | at::Tensor gradOutput, 18 | int padding_l, 19 | at::Tensor input, 20 | at::Tensor filters); 21 | 22 | 23 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 24 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 25 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 26 | 27 | std::vector lightconv_forward( 28 | at::Tensor input, 29 | at::Tensor filters, 30 | int padding_l) { 31 | 32 | CHECK_INPUT(input); 33 | CHECK_INPUT(filters); 34 | 35 | return lightconv_cuda_forward(input, filters, padding_l); 36 | } 37 | 38 | std::vector lightconv_backward( 39 | at::Tensor gradOutput, 40 | int padding_l, 41 | at::Tensor input, 42 | at::Tensor filters) { 43 | 44 | CHECK_INPUT(gradOutput); 45 | CHECK_INPUT(input); 46 | CHECK_INPUT(filters); 47 | 48 | return lightconv_cuda_backward(gradOutput, padding_l, input, filters); 49 | } 50 | 51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 52 | m.def("forward", &lightconv_forward, "lighconv forward (CUDA)"); 53 | m.def("backward", &lightconv_backward, "lighconv backward (CUDA)"); 54 | } 55 | -------------------------------------------------------------------------------- /fairseq/modules/lightconv_layer/lightconv_cuda.cuh: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
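 *
 * Declarations of the CUDA kernels used by the lightconv extension; the
 * extension is built together with lightconv_cuda_kernel.cu (see setup.py).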
6 | */ 7 | 8 | #include 9 | #include 10 | 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | 21 | #include 22 | #include 23 | 24 | #define SHFL_MASK 0xffffffff 25 | 26 | template 27 | __global__ 28 | void lightconv_forward_kernel(const scalar_t* input, 29 | const scalar_t* filters, 30 | int minibatch, int sequenceLength, 31 | int numFeatures, int numFiltersInBlock, 32 | scalar_t* output); 33 | 34 | template 35 | __global__ 36 | void lightconv_grad_wrt_input_kernel( 37 | const scalar_t* input, 38 | const scalar_t* filters, 39 | int minibatch, 40 | int sequenceLength, 41 | int numFeatures, 42 | int numFiltersInBlock, 43 | scalar_t* output); 44 | 45 | template 46 | __global__ 47 | void lightconv_grad_wrt_weights_firstpass_short_kernel( 48 | const scalar_t* input, 49 | const scalar_t* gradInput, 50 | int minibatch, 51 | int sequenceLength, 52 | int numFeatures, 53 | int numFiltersInBlock, 54 | int numHeads, 55 | float* output); 56 | 57 | template 58 | __global__ 59 | void lightconv_grad_wrt_weights_secondpass_short_kernel( 60 | const float* input, 61 | const int minibatch, 62 | const int numFiltersInBlock, 63 | scalar_t* output); 64 | 65 | template 66 | __global__ 67 | void lightconv_grad_wrt_weights_firstpass_kernel( 68 | const scalar_t* input, 69 | const scalar_t* gradInput, 70 | int minibatch, 71 | int sequenceLength, 72 | int numFeatures, 73 | int numFiltersInBlock, 74 | float* output); 75 | 76 | template 77 | __global__ 78 | void lightconv_grad_wrt_weights_secondpass_kernel( 79 | const float* input, 80 | const int minibatch, 81 | const int numFiltersInBlock, 82 | scalar_t* output); 83 | 84 | -------------------------------------------------------------------------------- /fairseq/modules/lightconv_layer/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from setuptools import setup 8 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 9 | 10 | setup( 11 | name='lightconv_layer', 12 | ext_modules=[ 13 | CUDAExtension('lightconv_cuda', [ 14 | 'lightconv_cuda.cpp', 15 | 'lightconv_cuda_kernel.cu', 16 | ]), 17 | ], 18 | cmdclass={ 19 | 'build_ext': BuildExtension 20 | }) 21 | -------------------------------------------------------------------------------- /fairseq/modules/logsumexp_moe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | 9 | class LogSumExpMoE(torch.autograd.Function): 10 | """Standard LogSumExp forward pass, but use *posterior* for the backward. 11 | 12 | See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade" 13 | (Shen et al., 2019) `_. 
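    In the backward pass the gradient w.r.t. *logp* is ``grad_output * posterior``
    (no gradient flows to *posterior*), rather than the softmax of *logp* that
    differentiating logsumexp directly would give.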
14 | """ 15 | 16 | @staticmethod 17 | def forward(ctx, logp, posterior, dim=-1): 18 | ctx.save_for_backward(posterior) 19 | ctx.dim = dim 20 | return torch.logsumexp(logp, dim=dim) 21 | 22 | @staticmethod 23 | def backward(ctx, grad_output): 24 | posterior, = ctx.saved_tensors 25 | grad_logp = grad_output.unsqueeze(ctx.dim) * posterior 26 | return grad_logp, None, None 27 | -------------------------------------------------------------------------------- /fairseq/modules/mean_pool_gating_network.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | 10 | class MeanPoolGatingNetwork(torch.nn.Module): 11 | """A simple mean-pooling gating network for selecting experts. 12 | 13 | This module applies mean pooling over an encoder's output and returns 14 | reponsibilities for each expert. The encoder format is expected to match 15 | :class:`fairseq.models.transformer.TransformerEncoder`. 16 | """ 17 | 18 | def __init__(self, embed_dim, num_experts, dropout=None): 19 | super().__init__() 20 | self.embed_dim = embed_dim 21 | self.num_experts = num_experts 22 | 23 | self.fc1 = torch.nn.Linear(embed_dim, embed_dim) 24 | self.dropout = torch.nn.Dropout(dropout) if dropout is not None else None 25 | self.fc2 = torch.nn.Linear(embed_dim, num_experts) 26 | 27 | def forward(self, encoder_out): 28 | if not ( 29 | hasattr(encoder_out, 'encoder_out') 30 | and hasattr(encoder_out, 'encoder_padding_mask') 31 | and encoder_out.encoder_out.size(2) == self.embed_dim 32 | ): 33 | raise ValueError('Unexpected format for encoder_out') 34 | 35 | # mean pooling over time 36 | encoder_padding_mask = encoder_out.encoder_padding_mask # B x T 37 | encoder_out = encoder_out.encoder_out.transpose(0, 1) # B x T x C 38 | if encoder_padding_mask is not None: 39 | encoder_out = encoder_out.clone() # required because of transpose above 40 | encoder_out[encoder_padding_mask] = 0 41 | ntokens = torch.sum(~encoder_padding_mask, dim=1, keepdim=True) 42 | x = torch.sum(encoder_out, dim=1) / ntokens.type_as(encoder_out) 43 | else: 44 | x = torch.mean(encoder_out, dim=1) 45 | 46 | x = torch.tanh(self.fc1(x)) 47 | if self.dropout is not None: 48 | x = self.dropout(x) 49 | x = self.fc2(x) 50 | return F.log_softmax(x, dim=-1, dtype=torch.float32).type_as(x) 51 | -------------------------------------------------------------------------------- /fairseq/modules/positional_embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch.nn as nn 7 | 8 | from .learned_positional_embedding import LearnedPositionalEmbedding 9 | from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding 10 | 11 | 12 | def PositionalEmbedding( 13 | num_embeddings: int, 14 | embedding_dim: int, 15 | padding_idx: int, 16 | learned: bool = False, 17 | ): 18 | if learned: 19 | # if padding_idx is specified then offset the embedding ids by 20 | # this index and adjust num_embeddings appropriately 21 | # TODO: The right place for this offset would be inside 22 | # LearnedPositionalEmbedding. 
Move this there for a cleaner implementation. 23 | if padding_idx is not None: 24 | num_embeddings = num_embeddings + padding_idx + 1 25 | m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx) 26 | nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) 27 | if padding_idx is not None: 28 | nn.init.constant_(m.weight[padding_idx], 0) 29 | else: 30 | m = SinusoidalPositionalEmbedding( 31 | embedding_dim, padding_idx, init_size=num_embeddings + padding_idx + 1, 32 | ) 33 | return m 34 | -------------------------------------------------------------------------------- /fairseq/modules/scalar_bias.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | # 6 | 7 | import torch 8 | 9 | 10 | class ScalarBias(torch.autograd.Function): 11 | """ 12 | Adds a vector of scalars, used in self-attention mechanism to allow 13 | the model to optionally attend to this vector instead of the past 14 | """ 15 | 16 | @staticmethod 17 | def forward(ctx, input, dim, bias_init): 18 | size = list(input.size()) 19 | size[dim] += 1 20 | output = input.new(*size).fill_(bias_init) 21 | output.narrow(dim, 1, size[dim] - 1).copy_(input) 22 | ctx.dim = dim 23 | return output 24 | 25 | @staticmethod 26 | def backward(ctx, grad): 27 | return grad.narrow(ctx.dim, 1, grad.size(ctx.dim) - 1), None, None 28 | 29 | 30 | def scalar_bias(input, dim, bias_init=0): 31 | return ScalarBias.apply(input, dim, bias_init) 32 | -------------------------------------------------------------------------------- /fairseq/modules/sparse_transformer_sentence_encoder_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
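#
# Construction sketch (illustrative defaults):
#     layer = SparseTransformerSentenceEncoderLayer(
#         embedding_dim=768, num_attention_heads=8,
#         is_bidirectional=True, stride=32, expressivity=8,
#     )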
5 | 6 | from fairseq.modules import TransformerSentenceEncoderLayer 7 | from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention 8 | 9 | 10 | class SparseTransformerSentenceEncoderLayer(TransformerSentenceEncoderLayer): 11 | """ 12 | Implements a Sprase Transformer Encoder Layer (see SparseMultiheadAttention) 13 | """ 14 | 15 | def __init__( 16 | self, 17 | embedding_dim: int = 768, 18 | ffn_embedding_dim: int = 3072, 19 | num_attention_heads: int = 8, 20 | dropout: float = 0.1, 21 | attention_dropout: float = 0.1, 22 | activation_dropout: float = 0.1, 23 | activation_fn: str = 'relu', 24 | add_bias_kv: bool = False, 25 | add_zero_attn: bool = False, 26 | export: bool = False, 27 | is_bidirectional: bool = True, 28 | stride: int = 32, 29 | expressivity: int = 8, 30 | ) -> None: 31 | 32 | super().__init__( 33 | embedding_dim, ffn_embedding_dim, num_attention_heads, dropout, 34 | attention_dropout, activation_dropout, activation_fn, add_bias_kv, 35 | add_zero_attn, export 36 | ) 37 | 38 | self.self_attn = SparseMultiheadAttention( 39 | self.embedding_dim, 40 | num_attention_heads, 41 | dropout=attention_dropout, 42 | add_bias_kv=add_bias_kv, 43 | add_zero_attn=add_zero_attn, 44 | self_attention=True, 45 | is_bidirectional=is_bidirectional, 46 | stride=stride, 47 | expressivity=expressivity, 48 | ) 49 | -------------------------------------------------------------------------------- /fairseq/modules/unfold.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch.nn.functional as F 7 | 8 | 9 | def unfold1d(x, kernel_size, padding_l, pad_value=0): 10 | '''unfold T x B x C to T x B x C x K''' 11 | if kernel_size > 1: 12 | T, B, C = x.size() 13 | x = F.pad(x, (0, 0, 0, 0, padding_l, kernel_size - 1 - padding_l), value=pad_value) 14 | x = x.as_strided((T, B, C, kernel_size), (B*C, C, 1, B*C)) 15 | else: 16 | x = x.unsqueeze(3) 17 | return x 18 | -------------------------------------------------------------------------------- /fairseq/optim/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import importlib 7 | import os 8 | 9 | from fairseq import registry 10 | from fairseq.optim.fairseq_optimizer import FairseqOptimizer 11 | from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer 12 | from fairseq.optim.bmuf import FairseqBMUF # noqa 13 | 14 | 15 | __all__ = [ 16 | 'FairseqOptimizer', 17 | 'FP16Optimizer', 18 | 'MemoryEfficientFP16Optimizer', 19 | ] 20 | 21 | 22 | build_optimizer, register_optimizer, OPTIMIZER_REGISTRY = registry.setup_registry( 23 | '--optimizer', 24 | base_class=FairseqOptimizer, 25 | default='nag', 26 | ) 27 | 28 | 29 | # automatically import any Python files in the optim/ directory 30 | for file in os.listdir(os.path.dirname(__file__)): 31 | if file.endswith('.py') and not file.startswith('_'): 32 | module = file[:file.find('.py')] 33 | importlib.import_module('fairseq.optim.' 
+ module) 34 | -------------------------------------------------------------------------------- /fairseq/optim/adadelta.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch.optim 7 | 8 | from . import FairseqOptimizer, register_optimizer 9 | 10 | 11 | @register_optimizer('adadelta') 12 | class Adadelta(FairseqOptimizer): 13 | def __init__(self, args, params): 14 | super().__init__(args) 15 | self._optimizer = torch.optim.Adadelta(params, **self.optimizer_config) 16 | 17 | @staticmethod 18 | def add_args(parser): 19 | """Add optimizer-specific arguments to the parser.""" 20 | # fmt: off 21 | parser.add_argument('--adadelta-rho', type=float, default=0.9, metavar='RHO', 22 | help='coefficient used for computing a running average of squared gradients') 23 | parser.add_argument('--adadelta-eps', type=float, default=1e-6, metavar='EPS', 24 | help='term added to the denominator to improve numerical stability') 25 | parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', 26 | help='weight decay') 27 | parser.add_argument('--anneal-eps', action='store_true', help='flag to anneal eps') 28 | # fmt: on 29 | 30 | @property 31 | def optimizer_config(self): 32 | """ 33 | Return a kwarg dictionary that will be used to override optimizer 34 | args stored in checkpoints. This allows us to load a checkpoint and 35 | resume training using a different set of optimizer args, e.g., with a 36 | different learning rate. 37 | """ 38 | return { 39 | 'lr': self.args.lr[0], 40 | 'rho': self.args.adadelta_rho, 41 | 'eps': self.args.adadelta_eps, 42 | 'weight_decay': self.args.weight_decay, 43 | } 44 | -------------------------------------------------------------------------------- /fairseq/optim/adagrad.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch.optim 7 | 8 | from . import FairseqOptimizer, register_optimizer 9 | 10 | 11 | @register_optimizer('adagrad') 12 | class Adagrad(FairseqOptimizer): 13 | def __init__(self, args, params): 14 | super().__init__(args) 15 | self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config) 16 | 17 | @staticmethod 18 | def add_args(parser): 19 | """Add optimizer-specific arguments to the parser.""" 20 | # fmt: off 21 | parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', 22 | help='weight decay') 23 | # fmt: on 24 | 25 | @property 26 | def optimizer_config(self): 27 | """ 28 | Return a kwarg dictionary that will be used to override optimizer 29 | args stored in checkpoints. This allows us to load a checkpoint and 30 | resume training using a different set of optimizer args, e.g., with a 31 | different learning rate. 32 | """ 33 | return { 34 | 'lr': self.args.lr[0], 35 | 'weight_decay': self.args.weight_decay, 36 | } 37 | -------------------------------------------------------------------------------- /fairseq/optim/fused_lamb.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from fairseq.optim import FairseqOptimizer, register_optimizer 7 | 8 | 9 | @register_optimizer('lamb') 10 | class FairseqLAMB(FairseqOptimizer): 11 | """LAMB optimizer.""" 12 | 13 | def __init__(self, args, params): 14 | super().__init__(args) 15 | try: 16 | from apex.optimizers import FusedLAMB 17 | self._optimizer = FusedLAMB(params, **self.optimizer_config) 18 | except ImportError: 19 | raise ImportError('Please install apex to use LAMB optimizer') 20 | 21 | @staticmethod 22 | def add_args(parser): 23 | """Add optimizer-specific arguments to the parser.""" 24 | # fmt: off 25 | parser.add_argument('--lamb-betas', default='(0.9, 0.999)', metavar='B', 26 | help='betas for LAMB optimizer') 27 | parser.add_argument('--lamb-eps', type=float, default=1e-8, metavar='D', 28 | help='epsilon for LAMB optimizer') 29 | parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', 30 | help='weight decay') 31 | # fmt: on 32 | 33 | @property 34 | def optimizer_config(self): 35 | """ 36 | Return a kwarg dictionary that will be used to override optimizer 37 | args stored in checkpoints. This allows us to load a checkpoint and 38 | resume training using a different set of optimizer args, e.g., with a 39 | different learning rate. 40 | """ 41 | return { 42 | 'lr': self.args.lr[0], 43 | 'betas': eval(self.args.lamb_betas), 44 | 'eps': self.args.lamb_eps, 45 | 'weight_decay': self.args.weight_decay, 46 | } 47 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import importlib 7 | import os 8 | 9 | from fairseq import registry 10 | from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import FairseqLRScheduler 11 | 12 | 13 | build_lr_scheduler, register_lr_scheduler, LR_SCHEDULER_REGISTRY = registry.setup_registry( 14 | '--lr-scheduler', 15 | base_class=FairseqLRScheduler, 16 | default='fixed', 17 | ) 18 | 19 | # automatically import any Python files in the optim/lr_scheduler/ directory 20 | for file in os.listdir(os.path.dirname(__file__)): 21 | if file.endswith('.py') and not file.startswith('_'): 22 | module = file[:file.find('.py')] 23 | importlib.import_module('fairseq.optim.lr_scheduler.' + module) 24 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .. 
import FairseqOptimizer 7 | 8 | 9 | class FairseqLRScheduler(object): 10 | 11 | def __init__(self, args, optimizer): 12 | super().__init__() 13 | if not isinstance(optimizer, FairseqOptimizer): 14 | raise ValueError('optimizer must be an instance of FairseqOptimizer') 15 | self.args = args 16 | self.optimizer = optimizer 17 | self.best = None 18 | 19 | @staticmethod 20 | def add_args(parser): 21 | """Add arguments to the parser for this LR scheduler.""" 22 | pass 23 | 24 | def state_dict(self): 25 | """Return the LR scheduler state dict.""" 26 | return {'best': self.best} 27 | 28 | def load_state_dict(self, state_dict): 29 | """Load an LR scheduler state dict.""" 30 | self.best = state_dict['best'] 31 | 32 | def step(self, epoch, val_loss=None): 33 | """Update the learning rate at the end of the given epoch.""" 34 | if val_loss is not None: 35 | if self.best is None: 36 | self.best = val_loss 37 | else: 38 | self.best = min(self.best, val_loss) 39 | 40 | def step_update(self, num_updates): 41 | """Update the learning rate after each update.""" 42 | return self.optimizer.get_lr() 43 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/fixed_schedule.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . import FairseqLRScheduler, register_lr_scheduler 7 | 8 | 9 | @register_lr_scheduler('fixed') 10 | class FixedSchedule(FairseqLRScheduler): 11 | """Decay the LR on a fixed schedule.""" 12 | 13 | def __init__(self, args, optimizer): 14 | super().__init__(args, optimizer) 15 | 16 | # set defaults 17 | args.warmup_updates = getattr(args, 'warmup_updates', 0) or 0 18 | 19 | self.lr = args.lr[0] 20 | if args.warmup_updates > 0: 21 | self.warmup_factor = 1. 
/ args.warmup_updates 22 | else: 23 | self.warmup_factor = 1 24 | 25 | @staticmethod 26 | def add_args(parser): 27 | """Add arguments to the parser for this LR scheduler.""" 28 | # fmt: off 29 | parser.add_argument('--force-anneal', '--fa', type=int, metavar='N', 30 | help='force annealing at specified epoch') 31 | parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS', 32 | help='shrink factor for annealing, lr_new = (lr * lr_shrink)') 33 | parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', 34 | help='warmup the learning rate linearly for the first N updates') 35 | # fmt: on 36 | 37 | def get_next_lr(self, epoch): 38 | lrs = self.args.lr 39 | if self.args.force_anneal is None or epoch < self.args.force_anneal: 40 | # use fixed LR schedule 41 | next_lr = lrs[min(epoch, len(lrs) - 1)] 42 | else: 43 | # annneal based on lr_shrink 44 | next_lr = lrs[-1] * self.args.lr_shrink ** (epoch + 1 - self.args.force_anneal) 45 | return next_lr 46 | 47 | def step(self, epoch, val_loss=None): 48 | """Update the learning rate at the end of the given epoch.""" 49 | super().step(epoch, val_loss) 50 | self.lr = self.get_next_lr(epoch) 51 | self.optimizer.set_lr(self.warmup_factor * self.lr) 52 | return self.optimizer.get_lr() 53 | 54 | def step_update(self, num_updates): 55 | """Update the learning rate after each update.""" 56 | if self.args.warmup_updates > 0 and num_updates < self.args.warmup_updates: 57 | self.warmup_factor = (num_updates + 1) / float(self.args.warmup_updates) 58 | self.optimizer.set_lr(self.warmup_factor * self.lr) 59 | return self.optimizer.get_lr() 60 | -------------------------------------------------------------------------------- /fairseq/optim/sgd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch.optim 7 | 8 | from . import FairseqOptimizer, register_optimizer 9 | 10 | 11 | @register_optimizer('sgd') 12 | class SGD(FairseqOptimizer): 13 | def __init__(self, args, params): 14 | super().__init__(args) 15 | self._optimizer = torch.optim.SGD(params, **self.optimizer_config) 16 | 17 | @staticmethod 18 | def add_args(parser): 19 | """Add optimizer-specific arguments to the parser.""" 20 | # fmt: off 21 | parser.add_argument('--momentum', default=0.0, type=float, metavar='M', 22 | help='momentum factor') 23 | parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', 24 | help='weight decay') 25 | # fmt: on 26 | 27 | @property 28 | def optimizer_config(self): 29 | """ 30 | Return a kwarg dictionary that will be used to override optimizer 31 | args stored in checkpoints. This allows us to load a checkpoint and 32 | resume training using a different set of optimizer args, e.g., with a 33 | different learning rate. 34 | """ 35 | return { 36 | 'lr': self.args.lr[0], 37 | 'momentum': self.args.momentum, 38 | 'weight_decay': self.args.weight_decay, 39 | } 40 | -------------------------------------------------------------------------------- /fairseq/pdb.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import multiprocessing 7 | import os 8 | import pdb 9 | import sys 10 | 11 | 12 | __all__ = ['set_trace'] 13 | 14 | 15 | _stdin = [None] 16 | _stdin_lock = multiprocessing.Lock() 17 | try: 18 | _stdin_fd = sys.stdin.fileno() 19 | except Exception: 20 | _stdin_fd = None 21 | 22 | 23 | class MultiprocessingPdb(pdb.Pdb): 24 | """A Pdb wrapper that works in a multiprocessing environment. 25 | 26 | Usage: `from fairseq import pdb; pdb.set_trace()` 27 | """ 28 | 29 | def __init__(self): 30 | pdb.Pdb.__init__(self, nosigint=True) 31 | 32 | def _cmdloop(self): 33 | stdin_bak = sys.stdin 34 | with _stdin_lock: 35 | try: 36 | if _stdin_fd is not None: 37 | if not _stdin[0]: 38 | _stdin[0] = os.fdopen(_stdin_fd) 39 | sys.stdin = _stdin[0] 40 | self.cmdloop() 41 | finally: 42 | sys.stdin = stdin_bak 43 | 44 | 45 | def set_trace(): 46 | pdb = MultiprocessingPdb() 47 | pdb.set_trace(sys._getframe().f_back) 48 | -------------------------------------------------------------------------------- /fairseq/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import argparse 7 | 8 | 9 | REGISTRIES = {} 10 | 11 | 12 | def setup_registry( 13 | registry_name: str, 14 | base_class=None, 15 | default=None, 16 | ): 17 | assert registry_name.startswith('--') 18 | registry_name = registry_name[2:].replace('-', '_') 19 | 20 | REGISTRY = {} 21 | REGISTRY_CLASS_NAMES = set() 22 | 23 | # maintain a registry of all registries 24 | if registry_name in REGISTRIES: 25 | return # registry already exists 26 | REGISTRIES[registry_name] = { 27 | 'registry': REGISTRY, 28 | 'default': default, 29 | } 30 | 31 | def build_x(args, *extra_args, **extra_kwargs): 32 | choice = getattr(args, registry_name, None) 33 | if choice is None: 34 | return None 35 | cls = REGISTRY[choice] 36 | if hasattr(cls, 'build_' + registry_name): 37 | builder = getattr(cls, 'build_' + registry_name) 38 | else: 39 | builder = cls 40 | set_defaults(args, cls) 41 | return builder(args, *extra_args, **extra_kwargs) 42 | 43 | def register_x(name): 44 | 45 | def register_x_cls(cls): 46 | if name in REGISTRY: 47 | raise ValueError('Cannot register duplicate {} ({})'.format(registry_name, name)) 48 | if cls.__name__ in REGISTRY_CLASS_NAMES: 49 | raise ValueError( 50 | 'Cannot register {} with duplicate class name ({})'.format( 51 | registry_name, cls.__name__, 52 | ) 53 | ) 54 | if base_class is not None and not issubclass(cls, base_class): 55 | raise ValueError('{} must extend {}'.format(cls.__name__, base_class.__name__)) 56 | REGISTRY[name] = cls 57 | REGISTRY_CLASS_NAMES.add(cls.__name__) 58 | return cls 59 | 60 | return register_x_cls 61 | 62 | return build_x, register_x, REGISTRY 63 | 64 | 65 | def set_defaults(args, cls): 66 | """Helper to set default arguments based on *add_args*.""" 67 | if not hasattr(cls, 'add_args'): 68 | return 69 | parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, allow_abbrev=False) 70 | cls.add_args(parser) 71 | # copied from argparse.py: 72 | defaults = argparse.Namespace() 73 | for action in parser._actions: 74 | if action.dest is not argparse.SUPPRESS: 75 | if not hasattr(defaults, action.dest): 76 | if action.default is not argparse.SUPPRESS: 77 | setattr(defaults, action.dest, action.default) 78 | for key, default_value in vars(defaults).items(): 79 | if not 
hasattr(args, key): 80 | setattr(args, key, default_value) 81 | -------------------------------------------------------------------------------- /fairseq/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import argparse 7 | import importlib 8 | import os 9 | 10 | from .fairseq_task import FairseqTask 11 | 12 | TASK_REGISTRY = {} 13 | TASK_CLASS_NAMES = set() 14 | 15 | 16 | def setup_task(args, **kwargs): 17 | return TASK_REGISTRY[args.task].setup_task(args, **kwargs) 18 | 19 | 20 | def register_task(name): 21 | """ 22 | New tasks can be added to fairseq with the 23 | :func:`~fairseq.tasks.register_task` function decorator. 24 | 25 | For example:: 26 | 27 | @register_task('classification') 28 | class ClassificationTask(FairseqTask): 29 | (...) 30 | 31 | .. note:: 32 | 33 | All Tasks must implement the :class:`~fairseq.tasks.FairseqTask` 34 | interface. 35 | 36 | Please see the 37 | 38 | Args: 39 | name (str): the name of the task 40 | """ 41 | 42 | def register_task_cls(cls): 43 | if name in TASK_REGISTRY: 44 | raise ValueError('Cannot register duplicate task ({})'.format(name)) 45 | if not issubclass(cls, FairseqTask): 46 | raise ValueError('Task ({}: {}) must extend FairseqTask'.format(name, cls.__name__)) 47 | if cls.__name__ in TASK_CLASS_NAMES: 48 | raise ValueError('Cannot register task with duplicate class name ({})'.format(cls.__name__)) 49 | TASK_REGISTRY[name] = cls 50 | TASK_CLASS_NAMES.add(cls.__name__) 51 | return cls 52 | 53 | return register_task_cls 54 | 55 | 56 | # automatically import any Python files in the tasks/ directory 57 | for file in os.listdir(os.path.dirname(__file__)): 58 | if file.endswith('.py') and not file.startswith('_'): 59 | task_name = file[:file.find('.py')] 60 | importlib.import_module('fairseq.tasks.' + task_name) 61 | 62 | # expose `task_parser` for sphinx 63 | if task_name in TASK_REGISTRY: 64 | parser = argparse.ArgumentParser(add_help=False) 65 | group_task = parser.add_argument_group('Task name') 66 | # fmt: off 67 | group_task.add_argument('--task', metavar=task_name, 68 | help='Enable this task with: ``--task=' + task_name + '``') 69 | # fmt: on 70 | group_args = parser.add_argument_group('Additional command-line arguments') 71 | TASK_REGISTRY[task_name].add_args(group_args) 72 | globals()[task_name + '_parser'] = parser 73 | 74 | 75 | def get_task(name): 76 | return TASK_REGISTRY[name] 77 | -------------------------------------------------------------------------------- /fairseq/tasks/audio_pretraining.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import os 7 | 8 | from fairseq.data import FileAudioDataset 9 | from . import FairseqTask, register_task 10 | 11 | 12 | @register_task('audio_pretraining') 13 | class AudioPretrainingTask(FairseqTask): 14 | """ 15 | 16 | """ 17 | 18 | @staticmethod 19 | def add_args(parser): 20 | """Add task-specific arguments to the parser.""" 21 | parser.add_argument('data', help='path to data directory') 22 | parser.add_argument('--sample-rate', default=16000, type=int, 23 | help='target sample rate. 
audio files will be up/down sampled to this rate') 24 | parser.add_argument('--max-sample-size', default=None, type=int, 25 | help='max sample size to crop to for batching. default = min sample length') 26 | parser.add_argument('--min-sample-size', default=None, type=int, 27 | help='min sample size to crop to for batching. default = same as --max-sample-size') 28 | 29 | def __init__(self, args): 30 | super().__init__(args) 31 | 32 | @classmethod 33 | def setup_task(cls, args, **kwargs): 34 | """Setup the task (e.g., load dictionaries). 35 | 36 | Args: 37 | args (argparse.Namespace): parsed command-line arguments 38 | """ 39 | return cls(args) 40 | 41 | def load_dataset(self, split, **kwargs): 42 | """Load a given dataset split. 43 | 44 | Args: 45 | split (str): name of the split (e.g., train, valid, test) 46 | """ 47 | 48 | manifest = os.path.join(self.args.data, '{}.tsv'.format(split)) 49 | self.datasets[split] = FileAudioDataset(manifest, 50 | sample_rate=self.args.sample_rate, 51 | max_sample_size=self.args.max_sample_size, 52 | min_sample_size=self.args.min_sample_size) 53 | 54 | @property 55 | def target_dictionary(self): 56 | """Return the :class:`~fairseq.data.Dictionary` for the language 57 | model.""" 58 | return None 59 | -------------------------------------------------------------------------------- /fairseq/tasks/translation_from_pretrained_xlm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary 7 | from fairseq.tasks.translation import TranslationTask 8 | 9 | from . import register_task 10 | 11 | 12 | @register_task("translation_from_pretrained_xlm") 13 | class TranslationFromPretrainedXLMTask(TranslationTask): 14 | """ 15 | Same as TranslationTask except use the MaskedLMDictionary class so that 16 | we can load data that was binarized with the MaskedLMDictionary class. 17 | 18 | This task should be used for the entire training pipeline when we want to 19 | train an NMT model from a pretrained XLM checkpoint: binarizing NMT data, 20 | training NMT with the pretrained XLM checkpoint, and subsequent evaluation 21 | of that trained model. 22 | """ 23 | 24 | @classmethod 25 | def load_dictionary(cls, filename): 26 | """Load the masked LM dictionary from the filename 27 | 28 | Args: 29 | filename (str): the filename 30 | """ 31 | return MaskedLMDictionary.load(filename) 32 | -------------------------------------------------------------------------------- /fairseq/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
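#
# Behaviour sketch: tokenize_line collapses runs of whitespace and splits on
# spaces, e.g. tokenize_line("  He ran\tvery   fast \n") -> ["He", "ran", "very", "fast"]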
5 | 6 | import re 7 | 8 | SPACE_NORMALIZER = re.compile(r"\s+") 9 | 10 | 11 | def tokenize_line(line): 12 | line = SPACE_NORMALIZER.sub(" ", line) 13 | line = line.strip() 14 | return line.split() 15 | -------------------------------------------------------------------------------- /fairseq_cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuhinjubcse/SimileGeneration-EMNLP2020/cb40a16409912e1eefb2204e6b1ab953fbe8bdc6/fairseq_cli/__init__.py -------------------------------------------------------------------------------- /fairseq_cli/generate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from fairseq.models.bart import BARTModel 3 | import os 4 | import time 5 | import numpy as np 6 | os.environ['CUDA_VISIBLE_DEVICES']="1" 7 | 8 | bart = BARTModel.from_pretrained('checkpoint-similenew/',checkpoint_file='checkpoint_best.pt',data_name_or_path='similenew') 9 | 10 | bart.cuda() 11 | bart.eval() 12 | 13 | np.random.seed(4) 14 | torch.manual_seed(4) 15 | 16 | count = 1 17 | bsz = 1 18 | t = 0.7 19 | for val in [5]: 20 | with open('literal.txt') as source, open('simile.hypo', 'w') as fout: 21 | sline = source.readline().strip() 22 | slines = [sline] 23 | for sline in source: 24 | if count % bsz == 0: 25 | with torch.no_grad(): 26 | hypotheses_batch = bart.sample(slines, sampling=True, sampling_topk=val ,temperature=t ,lenpen=2.0, max_len_b=30, min_len=7, no_repeat_ngram_size=3) 27 | for hypothesis in hypotheses_batch: 28 | fout.write(hypothesis.replace('\n','') + '\n') 29 | fout.flush() 30 | slines = [] 31 | 32 | slines.append(sline.strip()) 33 | count += 1 34 | if slines != []: 35 | hypotheses_batch = bart.sample(slines, sampling=True, sampling_topk=val ,temperature=t ,lenpen=2.0, max_len_b=30, min_len=7, no_repeat_ngram_size=3) 36 | for hypothesis in hypotheses_batch: 37 | fout.write(hypothesis.replace('\n','') + '\n') 38 | fout.flush() 39 | -------------------------------------------------------------------------------- /finetune.sh: -------------------------------------------------------------------------------- 1 | TOTAL_NUM_UPDATES=20000 2 | WARMUP_UPDATES=500 3 | LR=3e-05 4 | MAX_TOKENS=1024 5 | UPDATE_FREQ=16 6 | BART_PATH=/nas/home/fairseq/bart.large/model.pt 7 | 8 | python train.py simile\ 9 | --restore-file $BART_PATH \ 10 | --max-tokens $MAX_TOKENS \ 11 | --task translation \ 12 | --source-lang source --target-lang target \ 13 | --truncate-source \ 14 | --truncate-target \ 15 | --layernorm-embedding \ 16 | --share-all-embeddings \ 17 | --share-decoder-input-output-embed \ 18 | --reset-optimizer --reset-dataloader --reset-meters \ 19 | --required-batch-size-multiple 1 \ 20 | --arch bart_large \ 21 | --criterion label_smoothed_cross_entropy \ 22 | --label-smoothing 0.1 \ 23 | --dropout 0.1 --attention-dropout 0.1 \ 24 | --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-08 \ 25 | --clip-norm 0.1 \ 26 | --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \ 27 | --memory-efficient-fp16 --update-freq $UPDATE_FREQ \ 28 | --save-dir "checkpoint-simile" \ 29 | --ddp-backend=no_c10d \ 30 | --skip-invalid-size-inputs-valid-test \ 31 | --find-unused-parameters; 32 | -------------------------------------------------------------------------------- /gen2.source: -------------------------------------------------------------------------------- 1 | They were well-cooked, and 
actually tasted nice 2 | It was obscene, but she was drawn to it, fascinated 3 | Rory looked at her as if she were crazy 4 | My eyes teared up and I could feel my face turning red 5 | She smirked when she saw me looking at it, eyes huge and mouth open 6 | I made quick work of one with my knife and slipped inside undetected 7 | He adjusted himself groggily, but upon seeing me, he sprang upright 8 | I needed to keep my cool, but the King surely wasn't making it easy 9 | He was executed the next day, and Sparrow was declared dead 10 | Then there weren’t any more parties as the house became silent 11 | The man’s pale skin appeared human, but appearances were extremely deceptive 12 | I knew there was pain but this was intense 13 | It wouldn't always be noticeable, but it was always palpable 14 | But when we made contact with them they were very hesitant 15 | He had a wide smile, but his eyes were cold, yet intense 16 | The knocking on the door was more assertive now, more confident 17 | They put up a brave face as governments become increasingly authoritarian 18 | An ordinary citizen coming to power in this way is unprecedented 19 | Yet as the days would pass the user would become increasingly distressed 20 | After decades of drought the Western United states is declared uninhabitable 21 | You both have something important to say, but time is short 22 | You're the only one who notices, but you're remarkably unfazed 23 | The child's job is to attract followers to make their deity more powerful 24 | Everyone is a robot subtly testing you, but you are starting to get suspicious 25 | You have no idea why people suddenly have become sick 26 | Their nuclear program however, has fallen flat 27 | You can freely switch between them but the rest go completely lifeless 28 | The suspect on your biggest case is squeaky clean, but you know they're guilty 29 | Several thousand years in the future, humans evolved to be emotionless 30 | Amen has fascinating joyful dreams every night, while his real life is unbearable 31 | All contact with the surface has been lost, and the radiation topside is lethal 32 | You are at a business dinner with your boss when your phone rings out loud 33 | A global catastrophe causes all major technology to be rendered useless 34 | We have learned that the ending of the Earth is imminent and unavoidable 35 | On your way to work, you notice the world is unnaturally calm 36 | It was a moonless nights, the air was still and the crickets were silent 37 | Reincarnation has been proven, memories are now retrievable 38 | We're starting to discover life is oddly permanent 39 | In the near future e-sports have become extremely popular and equally profitable 40 | Your dog is very excited to see you and acting quite peculiar 41 | Telling lies to the young is wrong -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from fairseq.models.bart import BARTModel 3 | import os 4 | import time 5 | import numpy as np 6 | os.environ['CUDA_VISIBLE_DEVICES']="1" 7 | 8 | bart = BARTModel.from_pretrained('checkpoint-simile/',checkpoint_file='checkpoint_best.pt',data_name_or_path='simile') 9 | #If you want to use pretrained BART model use this 10 | # bart = BARTModel.from_pretrained('bart.large', checkpoint_file='model.pt',task='translation',data_name_or_path='your data') 11 | # https://github.com/pytorch/fairseq/issues/1944 make changes in hub_interface.py as 
mentioned in this issue 12 | 13 | 14 | bart.cuda() 15 | bart.eval() 16 | 17 | np.random.seed(4) 18 | torch.manual_seed(4) 19 | 20 | count = 1 21 | bsz = 1 22 | t = 0.7 23 | for val in [5]: 24 | with open('literal.txt') as source, open('simile.hypo', 'w') as fout: 25 | sline = source.readline().strip() 26 | slines = [sline] 27 | for sline in source: 28 | if count % bsz == 0: 29 | with torch.no_grad(): 30 | hypotheses_batch = bart.sample(slines, sampling=True, sampling_topk=val ,temperature=t ,lenpen=2.0, max_len_b=30, min_len=7, no_repeat_ngram_size=3) 31 | for hypothesis in hypotheses_batch: 32 | fout.write(hypothesis.replace('\n','') + '\n') 33 | fout.flush() 34 | slines = [] 35 | 36 | slines.append(sline.strip()) 37 | count += 1 38 | if slines != []: 39 | hypotheses_batch = bart.sample(slines, sampling=True, sampling_topk=val ,temperature=t ,lenpen=2.0, max_len_b=30, min_len=7, no_repeat_ngram_size=3) 40 | for hypothesis in hypotheses_batch: 41 | fout.write(hypothesis.replace('\n','') + '\n') 42 | fout.flush() 43 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import functools 7 | 8 | from fairseq.hub_utils import BPEHubInterface as bpe # noqa 9 | from fairseq.hub_utils import TokenizerHubInterface as tokenizer # noqa 10 | from fairseq.models import MODEL_REGISTRY 11 | 12 | 13 | dependencies = [ 14 | 'numpy', 15 | 'regex', 16 | 'requests', 17 | 'torch', 18 | ] 19 | 20 | 21 | # torch.hub doesn't build Cython components, so if they are not found then try 22 | # to build them here 23 | try: 24 | import fairseq.data.token_block_utils_fast 25 | except (ImportError, ModuleNotFoundError): 26 | try: 27 | import cython 28 | import os 29 | from setuptools import sandbox 30 | sandbox.run_setup( 31 | os.path.join(os.path.dirname(__file__), 'setup.py'), 32 | ['build_ext', '--inplace'], 33 | ) 34 | except (ImportError, ModuleNotFoundError): 35 | print( 36 | 'Unable to build Cython components. Please make sure Cython is ' 37 | 'installed if the torch.hub model you are loading depends on it.' 
38 | ) 39 | 40 | 41 | for _model_type, _cls in MODEL_REGISTRY.items(): 42 | for model_name in _cls.hub_models().keys(): 43 | globals()[model_name] = functools.partial( 44 | _cls.from_pretrained, 45 | model_name, 46 | ) 47 | # to simplify the interface we only expose named models 48 | # globals()[_model_type] = _cls.from_pretrained 49 | -------------------------------------------------------------------------------- /preprocess.sh: -------------------------------------------------------------------------------- 1 | fairseq-preprocess \ 2 | --source-lang "source" \ 3 | --target-lang "target" \ 4 | --trainpref "simile/train.bpe" \ 5 | --validpref "simile/val.bpe" \ 6 | --destdir "simile/" \ 7 | --workers 60 \ 8 | --srcdict dict.txt \ 9 | --tgtdict dict.txt; 10 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuhinjubcse/SimileGeneration-EMNLP2020/cb40a16409912e1eefb2204e6b1ab953fbe8bdc6/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/compare_namespaces.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Helper script to compare two argparse.Namespace objects.""" 3 | 4 | from argparse import Namespace # noqa 5 | 6 | 7 | def main(): 8 | 9 | ns1 = eval(input('Namespace 1: ')) 10 | ns2 = eval(input('Namespace 2: ')) 11 | 12 | def keys(ns): 13 | ks = set() 14 | for k in dir(ns): 15 | if not k.startswith('_'): 16 | ks.add(k) 17 | return ks 18 | 19 | k1 = keys(ns1) 20 | k2 = keys(ns2) 21 | 22 | def print_keys(ks, ns1, ns2=None): 23 | for k in ks: 24 | if ns2 is None: 25 | print('{}\t{}'.format(k, getattr(ns1, k, None))) 26 | else: 27 | print('{}\t{}\t{}'.format(k, getattr(ns1, k, None), getattr(ns2, k, None))) 28 | 29 | print('Keys unique to namespace 1:') 30 | print_keys(k1 - k2, ns1) 31 | print() 32 | 33 | print('Keys unique to namespace 2:') 34 | print_keys(k2 - k1, ns2) 35 | print() 36 | 37 | print('Overlapping keys with different values:') 38 | ks = [k for k in k1 & k2 if getattr(ns1, k, 'None') != getattr(ns2, k, 'None')] 39 | print_keys(ks, ns1, ns2) 40 | print() 41 | 42 | 43 | if __name__ == '__main__': 44 | main() 45 | -------------------------------------------------------------------------------- /scripts/compound_split_bleu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ $# -ne 1 ]; then 4 | echo "usage: $0 GENERATE_PY_OUTPUT" 5 | exit 1 6 | fi 7 | 8 | GEN=$1 9 | 10 | SYS=$GEN.sys 11 | REF=$GEN.ref 12 | 13 | if [ $(tail -n 1 $GEN | grep BLEU | wc -l) -ne 1 ]; then 14 | echo "not done generating" 15 | exit 16 | fi 17 | 18 | grep ^H $GEN | awk -F '\t' '{print $NF}' | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $SYS 19 | grep ^T $GEN | cut -f2- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $REF 20 | fairseq-score --sys $SYS --ref $REF 21 | -------------------------------------------------------------------------------- /scripts/convert_dictionary.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) Facebook, Inc. and its affiliates. 2 | -- 3 | -- This source code is licensed under the MIT license found in the 4 | -- LICENSE file in the root directory of this source tree. 
5 | -- 6 | -- Usage: convert_dictionary.lua <dict.th7> 7 | require 'fairseq' 8 | require 'torch' 9 | require 'paths' 10 | 11 | if #arg < 1 then 12 | print('usage: convert_dictionary.lua <dict.th7>') 13 | os.exit(1) 14 | end 15 | if not paths.filep(arg[1]) then 16 | print('error: file does not exist: ' .. arg[1]) 17 | os.exit(1) 18 | end 19 | 20 | dict = torch.load(arg[1]) 21 | dst = paths.basename(arg[1]):gsub('.th7', '.txt') 22 | assert(dst:match('.txt$')) 23 | 24 | f = io.open(dst, 'w') 25 | for idx, symbol in ipairs(dict.index_to_symbol) do 26 | if idx > dict.cutoff then 27 | break 28 | end 29 | f:write(symbol) 30 | f:write(' ') 31 | f:write(dict.index_to_freq[idx]) 32 | f:write('\n') 33 | end 34 | f:close() 35 | -------------------------------------------------------------------------------- /scripts/count_docs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Count the number of documents and average number of lines and tokens per 8 | document in a large file. Documents should be separated by a single empty line. 9 | """ 10 | 11 | import argparse 12 | import gzip 13 | import sys 14 | 15 | import numpy as np 16 | 17 | 18 | def main(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('input') 21 | parser.add_argument('--gzip', action='store_true') 22 | args = parser.parse_args() 23 | 24 | def gopen(): 25 | if args.gzip: 26 | return gzip.open(args.input, 'r') 27 | else: 28 | return open(args.input, 'r', encoding='utf-8') 29 | 30 | num_lines = [] 31 | num_toks = [] 32 | with gopen() as h: 33 | num_docs = 1 34 | num_lines_in_doc = 0 35 | num_toks_in_doc = 0 36 | for i, line in enumerate(h): 37 | if len(line.strip()) == 0: # empty line indicates new document 38 | num_docs += 1 39 | num_lines.append(num_lines_in_doc) 40 | num_toks.append(num_toks_in_doc) 41 | num_lines_in_doc = 0 42 | num_toks_in_doc = 0 43 | else: 44 | num_lines_in_doc += 1 45 | num_toks_in_doc += len(line.rstrip().split()) 46 | if i % 1000000 == 0: 47 | print(i, file=sys.stderr, end="", flush=True) 48 | elif i % 100000 == 0: 49 | print(".", file=sys.stderr, end="", flush=True) 50 | print(file=sys.stderr, flush=True) 51 | 52 | print("found {} docs".format(num_docs)) 53 | print("average num lines per doc: {}".format(np.mean(num_lines))) 54 | print("average num toks per doc: {}".format(np.mean(num_toks))) 55 | 56 | 57 | if __name__ == '__main__': 58 | main() 59 | -------------------------------------------------------------------------------- /scripts/read_binarized.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | 7 | import argparse 8 | 9 | from fairseq.data import data_utils, Dictionary, indexed_dataset 10 | 11 | 12 | def get_parser(): 13 | parser = argparse.ArgumentParser( 14 | description='writes text from binarized file to stdout') 15 | # fmt: off 16 | parser.add_argument('--dataset-impl', help='dataset implementation', 17 | choices=indexed_dataset.get_available_dataset_impl()) 18 | parser.add_argument('--dict', metavar='FP', help='dictionary containing known words', default=None) 19 | parser.add_argument('--input', metavar='FP', required=True, help='binarized file to read') 20 | # fmt: on 21 | 22 | return parser 23 | 24 | 25 | def main(): 26 | parser = get_parser() 27 | args = parser.parse_args() 28 | 29 | dictionary = Dictionary.load(args.dict) if args.dict is not None else None 30 | dataset = data_utils.load_indexed_dataset( 31 | args.input, 32 | dictionary, 33 | dataset_impl=args.dataset_impl, 34 | default='lazy', 35 | ) 36 | 37 | for tensor_line in dataset: 38 | if dictionary is None: 39 | line = ' '.join([str(int(x)) for x in tensor_line]) 40 | else: 41 | line = dictionary.string(tensor_line) 42 | 43 | print(line) 44 | 45 | 46 | if __name__ == '__main__': 47 | main() 48 | -------------------------------------------------------------------------------- /scripts/sacrebleu_pregen.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ $# -ne 4 ]; then 4 | echo "usage: $0 TESTSET SRCLANG TGTLANG GEN" 5 | exit 1 6 | fi 7 | 8 | TESTSET=$1 9 | SRCLANG=$2 10 | TGTLANG=$3 11 | 12 | GEN=$4 13 | 14 | echo 'Cloning Moses github repository (for tokenization scripts)...' 15 | git clone https://github.com/moses-smt/mosesdecoder.git 16 | 17 | SCRIPTS=mosesdecoder/scripts 18 | DETOKENIZER=$SCRIPTS/tokenizer/detokenizer.perl 19 | 20 | grep ^H $GEN \ 21 | | sed 's/^H\-//' \ 22 | | sort -n -k 1 \ 23 | | cut -f 3 \ 24 | | perl $DETOKENIZER -l $TGTLANG \ 25 | | sed "s/ - /-/g" \ 26 | > $GEN.sorted.detok 27 | 28 | sacrebleu --test-set $TESTSET --language-pair "${SRCLANG}-${TGTLANG}" < $GEN.sorted.detok 29 | -------------------------------------------------------------------------------- /scripts/shard_docs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Split a large file into shards while respecting document boundaries. Documents 8 | should be separated by a single empty line. 
9 | """ 10 | 11 | import argparse 12 | import contextlib 13 | 14 | 15 | def main(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('input') 18 | parser.add_argument('--num-shards', type=int) 19 | args = parser.parse_args() 20 | 21 | assert args.num_shards is not None and args.num_shards > 1 22 | 23 | with open(args.input, 'r', encoding='utf-8') as h: 24 | with contextlib.ExitStack() as stack: 25 | outputs = [ 26 | stack.enter_context(open(args.input + ".shard" + str(i), "w", encoding="utf-8")) 27 | for i in range(args.num_shards) 28 | ] 29 | 30 | doc = [] 31 | first_doc = [True]*args.num_shards 32 | def output_doc(i): 33 | if not first_doc[i]: 34 | outputs[i].write("\n") 35 | first_doc[i] = False 36 | for line in doc: 37 | outputs[i].write(line) 38 | doc.clear() 39 | 40 | num_docs = 0 41 | for line in h: 42 | if line.strip() == "": # empty line indicates new document 43 | output_doc(num_docs % args.num_shards) 44 | num_docs += 1 45 | else: 46 | doc.append(line) 47 | output_doc(num_docs % args.num_shards) 48 | 49 | 50 | if __name__ == '__main__': 51 | main() 52 | -------------------------------------------------------------------------------- /scripts/split_train_valid_docs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Split a large file into a train and valid set while respecting document 8 | boundaries. Documents should be separated by a single empty line. 9 | """ 10 | 11 | import argparse 12 | import random 13 | import sys 14 | 15 | 16 | def main(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('input') 19 | parser.add_argument('sample_output', help='train output file') 20 | parser.add_argument('remainder_output', help='valid output file') 21 | parser.add_argument('-k', type=int, help="remainder size") 22 | parser.add_argument('--lines', action='store_true', 23 | help='split lines instead of docs') 24 | args = parser.parse_args() 25 | 26 | assert args.k is not None 27 | 28 | sample = [] 29 | remainder = [] 30 | num_docs = [0] 31 | 32 | def update_sample(doc): 33 | if len(sample) < args.k: 34 | sample.append(doc.copy()) 35 | else: 36 | i = num_docs[0] 37 | j = random.randrange(i + 1) 38 | if j < args.k: 39 | remainder.append(sample[j]) 40 | sample[j] = doc.copy() 41 | else: 42 | remainder.append(doc.copy()) 43 | num_docs[0] += 1 44 | doc.clear() 45 | 46 | with open(args.input, 'r', encoding='utf-8') as h: 47 | doc = [] 48 | for i, line in enumerate(h): 49 | if line.strip() == "": # empty line indicates new document 50 | update_sample(doc) 51 | else: 52 | doc.append(line) 53 | if args.lines: 54 | update_sample(doc) 55 | if i % 1000000 == 0: 56 | print(i, file=sys.stderr, end="", flush=True) 57 | elif i % 100000 == 0: 58 | print(".", file=sys.stderr, end="", flush=True) 59 | if len(doc) > 0: 60 | update_sample(doc) 61 | print(file=sys.stderr, flush=True) 62 | 63 | assert len(sample) == args.k 64 | 65 | with open(args.sample_output, 'w', encoding='utf-8') as out: 66 | first = True 67 | for doc in sample: 68 | if not first and not args.lines: 69 | out.write("\n") 70 | first = False 71 | for line in doc: 72 | out.write(line) 73 | 74 | with open(args.remainder_output, 'w', encoding='utf-8') as out: 75 | first = True 76 | for doc in remainder: 77 | if not first and not args.lines: 78 | 
out.write("\n") 79 | first = False 80 | for line in doc: 81 | out.write(line) 82 | 83 | 84 | if __name__ == '__main__': 85 | main() 86 | -------------------------------------------------------------------------------- /scripts/spm_decode.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | from __future__ import absolute_import, division, print_function, unicode_literals 9 | 10 | import argparse 11 | 12 | import sentencepiece as spm 13 | 14 | 15 | def main(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--model", required=True, 18 | help="sentencepiece model to use for decoding") 19 | parser.add_argument("--input", required=True, help="input file to decode") 20 | parser.add_argument("--input_format", choices=["piece", "id"], default="piece") 21 | args = parser.parse_args() 22 | 23 | sp = spm.SentencePieceProcessor() 24 | sp.Load(args.model) 25 | 26 | if args.input_format == "piece": 27 | def decode(l): 28 | return "".join(sp.DecodePieces(l)) 29 | elif args.input_format == "id": 30 | def decode(l): 31 | return "".join(sp.DecodeIds(l)) 32 | else: 33 | raise NotImplementedError 34 | 35 | def tok2int(tok): 36 | # remap reference-side <unk> (represented as <<unk>>) to 0 37 | return int(tok) if tok != "<<unk>>" else 0 38 | 39 | with open(args.input, "r", encoding="utf-8") as h: 40 | for line in h: 41 | if args.input_format == "id": 42 | print(decode(list(map(tok2int, line.rstrip().split())))) 43 | elif args.input_format == "piece": 44 | print(decode(line.rstrip().split())) 45 | 46 | if __name__ == "__main__": 47 | main() 48 | -------------------------------------------------------------------------------- /scripts/spm_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | from __future__ import absolute_import, division, print_function, unicode_literals 9 | 10 | import sys 11 | 12 | import sentencepiece as spm 13 | 14 | 15 | if __name__ == "__main__": 16 | spm.SentencePieceTrainer.Train(" ".join(sys.argv[1:])) 17 | -------------------------------------------------------------------------------- /scripts/wav2vec_manifest.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Data pre-processing: build vocabularies and binarize training data.
8 | """ 9 | 10 | import argparse 11 | import glob 12 | import os 13 | import soundfile 14 | import random 15 | 16 | 17 | def get_parser(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('root', metavar='DIR', help='root directory containing flac files to index') 20 | parser.add_argument('--valid-percent', default=0.01, type=float, metavar='D', 21 | help='percentage of data to use as validation set (between 0 and 1)') 22 | parser.add_argument('--dest', default='.', type=str, metavar='DIR', help='output directory') 23 | parser.add_argument('--ext', default='flac', type=str, metavar='EXT', help='extension to look for') 24 | parser.add_argument('--seed', default=42, type=int, metavar='N', help='random seed') 25 | parser.add_argument('--path-must-contain', default=None, type=str, metavar='FRAG', 26 | help='if set, path must contain this substring for a file to be included in the manifest') 27 | return parser 28 | 29 | 30 | def main(args): 31 | assert args.valid_percent >= 0 and args.valid_percent <= 1. 32 | 33 | dir_path = os.path.realpath(args.root) 34 | search_path = os.path.join(dir_path, '**/*.' + args.ext) 35 | rand = random.Random(args.seed) 36 | 37 | with open(os.path.join(args.dest, 'train.tsv'), 'w') as train_f, open( 38 | os.path.join(args.dest, 'valid.tsv'), 'w') as valid_f: 39 | print(dir_path, file=train_f) 40 | print(dir_path, file=valid_f) 41 | 42 | for fname in glob.iglob(search_path, recursive=True): 43 | file_path = os.path.realpath(fname) 44 | 45 | if args.path_must_contain and args.path_must_contain not in file_path: 46 | continue 47 | 48 | frames = soundfile.info(fname).frames 49 | dest = train_f if rand.random() > args.valid_percent else valid_f 50 | print('{}\t{}'.format(os.path.relpath(file_path, dir_path), frames), file=dest) 51 | 52 | 53 | if __name__ == '__main__': 54 | parser = get_parser() 55 | args = parser.parse_args() 56 | main(args) 57 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuhinjubcse/SimileGeneration-EMNLP2020/cb40a16409912e1eefb2204e6b1ab953fbe8bdc6/tests/__init__.py -------------------------------------------------------------------------------- /tests/speech_recognition/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuhinjubcse/SimileGeneration-EMNLP2020/cb40a16409912e1eefb2204e6b1ab953fbe8bdc6/tests/speech_recognition/__init__.py -------------------------------------------------------------------------------- /tests/speech_recognition/test_collaters.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import unittest 8 | 9 | import numpy as np 10 | import torch 11 | from examples.speech_recognition.data.collaters import Seq2SeqCollater 12 | 13 | 14 | class TestSeq2SeqCollator(unittest.TestCase): 15 | def test_collate(self): 16 | 17 | eos_idx = 1 18 | pad_idx = 0 19 | collater = Seq2SeqCollater( 20 | feature_index=0, label_index=1, pad_index=pad_idx, eos_index=eos_idx 21 | ) 22 | 23 | # 2 frames in the first sample and 3 frames in the second one 24 | frames1 = np.array([[7, 8], [9, 10]]) 25 | frames2 = np.array([[1, 2], [3, 4], [5, 6]]) 26 | target1 = np.array([4, 2, 3, eos_idx]) 27 | target2 = np.array([3, 2, eos_idx]) 28 | sample1 = {"id": 0, "data": [frames1, target1]} 29 | sample2 = {"id": 1, "data": [frames2, target2]} 30 | batch = collater.collate([sample1, sample2]) 31 | 32 | # collate sort inputs by frame's length before creating the batch 33 | self.assertTensorEqual(batch["id"], torch.tensor([1, 0])) 34 | self.assertEqual(batch["ntokens"], 7) 35 | self.assertTensorEqual( 36 | batch["net_input"]["src_tokens"], 37 | torch.tensor( 38 | [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [pad_idx, pad_idx]]] 39 | ), 40 | ) 41 | self.assertTensorEqual( 42 | batch["net_input"]["prev_output_tokens"], 43 | torch.tensor([[eos_idx, 3, 2, pad_idx], [eos_idx, 4, 2, 3]]), 44 | ) 45 | self.assertTensorEqual(batch["net_input"]["src_lengths"], torch.tensor([3, 2])) 46 | self.assertTensorEqual( 47 | batch["target"], 48 | torch.tensor([[3, 2, eos_idx, pad_idx], [4, 2, 3, eos_idx]]), 49 | ) 50 | self.assertEqual(batch["nsentences"], 2) 51 | 52 | def assertTensorEqual(self, t1, t2): 53 | self.assertEqual(t1.size(), t2.size(), "size mismatch") 54 | self.assertEqual(t1.ne(t2).long().sum(), 0) 55 | 56 | 57 | if __name__ == "__main__": 58 | unittest.main() 59 | -------------------------------------------------------------------------------- /tests/speech_recognition/test_cross_entropy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from examples.speech_recognition.criterions.cross_entropy_acc import CrossEntropyWithAccCriterion 8 | from .asr_test_base import CrossEntropyCriterionTestBase 9 | 10 | 11 | class CrossEntropyWithAccCriterionTest(CrossEntropyCriterionTestBase): 12 | def setUp(self): 13 | self.criterion_cls = CrossEntropyWithAccCriterion 14 | super().setUp() 15 | 16 | def test_cross_entropy_all_correct(self): 17 | sample = self.get_test_sample(correct=True, soft_target=False, aggregate=False) 18 | loss, sample_size, logging_output = self.criterion( 19 | self.model, sample, "sum", log_probs=True 20 | ) 21 | assert logging_output["correct"] == 20 22 | assert logging_output["total"] == 20 23 | assert logging_output["sample_size"] == 20 24 | assert logging_output["ntokens"] == 20 25 | 26 | def test_cross_entropy_all_wrong(self): 27 | sample = self.get_test_sample(correct=False, soft_target=False, aggregate=False) 28 | loss, sample_size, logging_output = self.criterion( 29 | self.model, sample, "sum", log_probs=True 30 | ) 31 | assert logging_output["correct"] == 0 32 | assert logging_output["total"] == 20 33 | assert logging_output["sample_size"] == 20 34 | assert logging_output["ntokens"] == 20 35 | -------------------------------------------------------------------------------- /tests/test_character_token_embedder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import unittest 8 | 9 | from fairseq.data import Dictionary 10 | from fairseq.modules import CharacterTokenEmbedder 11 | 12 | 13 | class TestCharacterTokenEmbedder(unittest.TestCase): 14 | def test_character_token_embedder(self): 15 | vocab = Dictionary() 16 | vocab.add_symbol('hello') 17 | vocab.add_symbol('there') 18 | 19 | embedder = CharacterTokenEmbedder(vocab, [(2, 16), (4, 32), (8, 64), (16, 2)], 64, 5, 2) 20 | 21 | test_sents = [['hello', 'unk', 'there'], ['there'], ['hello', 'there']] 22 | max_len = max(len(s) for s in test_sents) 23 | input = torch.LongTensor(len(test_sents), max_len + 2).fill_(vocab.pad()) 24 | for i in range(len(test_sents)): 25 | input[i][0] = vocab.eos() 26 | for j in range(len(test_sents[i])): 27 | input[i][j + 1] = vocab.index(test_sents[i][j]) 28 | input[i][j + 2] = vocab.eos() 29 | embs = embedder(input) 30 | 31 | assert embs.size() == (len(test_sents), max_len + 2, 5) 32 | self.assertAlmostEqual(embs[0][0], embs[1][0]) 33 | self.assertAlmostEqual(embs[0][0], embs[0][-1]) 34 | self.assertAlmostEqual(embs[0][1], embs[2][1]) 35 | self.assertAlmostEqual(embs[0][3], embs[1][1]) 36 | 37 | embs.sum().backward() 38 | assert embedder.char_embeddings.weight.grad is not None 39 | 40 | def assertAlmostEqual(self, t1, t2): 41 | self.assertEqual(t1.size(), t2.size(), "size mismatch") 42 | self.assertLess((t1 - t2).abs().max(), 1e-6) 43 | 44 | 45 | if __name__ == '__main__': 46 | unittest.main() 47 | -------------------------------------------------------------------------------- /tests/test_concat_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import unittest 7 | 8 | import torch 9 | from fairseq.data import LanguagePairDataset, TokenBlockDataset 10 | from fairseq.data.concat_dataset import ConcatDataset 11 | from tests.test_train import mock_dict 12 | 13 | 14 | class TestConcatDataset(unittest.TestCase): 15 | def setUp(self): 16 | d = mock_dict() 17 | tokens_1 = torch.LongTensor([1]).view(1, -1) 18 | tokens_ds1 = TokenBlockDataset( 19 | tokens_1, 20 | sizes=[tokens_1.size(-1)], 21 | block_size=1, 22 | pad=0, 23 | eos=1, 24 | include_targets=False, 25 | ) 26 | self.dataset_1 = LanguagePairDataset( 27 | tokens_ds1, tokens_ds1.sizes, d, shuffle=False 28 | ) 29 | tokens_2 = torch.LongTensor([2]).view(1, -1) 30 | tokens_ds2 = TokenBlockDataset( 31 | tokens_2, 32 | sizes=[tokens_2.size(-1)], 33 | block_size=1, 34 | pad=0, 35 | eos=1, 36 | include_targets=False, 37 | ) 38 | self.dataset_2 = LanguagePairDataset( 39 | tokens_ds2, tokens_ds2.sizes, d, shuffle=False 40 | ) 41 | 42 | def test_concat_dataset_basics(self): 43 | d = ConcatDataset( 44 | [self.dataset_1, self.dataset_2] 45 | ) 46 | assert(len(d) == 2) 47 | assert(d[0]['source'][0] == 1) 48 | assert(d[1]['source'][0] == 2) 49 | 50 | d = ConcatDataset( 51 | [self.dataset_1, self.dataset_2], sample_ratios=[1, 2] 52 | ) 53 | assert(len(d) == 3) 54 | assert(d[0]['source'][0] == 1) 55 | assert(d[1]['source'][0] == 2) 56 | assert(d[2]['source'][0] == 2) 57 | 58 | d = ConcatDataset( 59 | [self.dataset_1, self.dataset_2], sample_ratios=[2, 1] 60 | ) 61 | assert(len(d) == 3) 62 | assert(d[0]['source'][0] == 1) 63 | assert(d[1]['source'][0] == 1) 64 | assert(d[2]['source'][0] == 2) 65 | -------------------------------------------------------------------------------- /tests/test_convtbc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import torch 7 | import unittest 8 | from fairseq.modules import ConvTBC 9 | import torch.nn as nn 10 | 11 | 12 | class TestConvTBC(unittest.TestCase): 13 | 14 | def test_convtbc(self): 15 | # ksz, in_channels, out_channels 16 | conv_tbc = ConvTBC(4, 5, kernel_size=3, padding=1) 17 | # out_channels, in_channels, ksz 18 | conv1d = nn.Conv1d(4, 5, kernel_size=3, padding=1) 19 | 20 | conv_tbc.weight.data.copy_(conv1d.weight.data.transpose(0, 2)) 21 | conv_tbc.bias.data.copy_(conv1d.bias.data) 22 | 23 | input_tbc = torch.randn(7, 2, 4, requires_grad=True) 24 | input1d = input_tbc.data.transpose(0, 1).transpose(1, 2) 25 | input1d.requires_grad = True 26 | 27 | output_tbc = conv_tbc(input_tbc) 28 | output1d = conv1d(input1d) 29 | 30 | self.assertAlmostEqual(output_tbc.data.transpose(0, 1).transpose(1, 2), output1d.data) 31 | 32 | grad_tbc = torch.randn(output_tbc.size()) 33 | grad1d = grad_tbc.transpose(0, 1).transpose(1, 2).contiguous() 34 | 35 | output_tbc.backward(grad_tbc) 36 | output1d.backward(grad1d) 37 | 38 | self.assertAlmostEqual(conv_tbc.weight.grad.data.transpose(0, 2), conv1d.weight.grad.data) 39 | self.assertAlmostEqual(conv_tbc.bias.grad.data, conv1d.bias.grad.data) 40 | self.assertAlmostEqual(input_tbc.grad.data.transpose(0, 1).transpose(1, 2), input1d.grad.data) 41 | 42 | def assertAlmostEqual(self, t1, t2): 43 | self.assertEqual(t1.size(), t2.size(), "size mismatch") 44 | self.assertLess((t1 - t2).abs().max(), 1e-4) 45 | 46 | 47 | if __name__ == '__main__': 48 | unittest.main() 49 | -------------------------------------------------------------------------------- /tests/test_dictionary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import tempfile 7 | import unittest 8 | 9 | import torch 10 | 11 | from fairseq.data import Dictionary 12 | 13 | 14 | class TestDictionary(unittest.TestCase): 15 | 16 | def test_finalize(self): 17 | txt = [ 18 | 'A B C D', 19 | 'B C D', 20 | 'C D', 21 | 'D', 22 | ] 23 | ref_ids1 = list(map(torch.IntTensor, [ 24 | [4, 5, 6, 7, 2], 25 | [5, 6, 7, 2], 26 | [6, 7, 2], 27 | [7, 2], 28 | ])) 29 | ref_ids2 = list(map(torch.IntTensor, [ 30 | [7, 6, 5, 4, 2], 31 | [6, 5, 4, 2], 32 | [5, 4, 2], 33 | [4, 2], 34 | ])) 35 | 36 | # build dictionary 37 | d = Dictionary() 38 | for line in txt: 39 | d.encode_line(line, add_if_not_exist=True) 40 | 41 | def get_ids(dictionary): 42 | ids = [] 43 | for line in txt: 44 | ids.append(dictionary.encode_line(line, add_if_not_exist=False)) 45 | return ids 46 | 47 | def assertMatch(ids, ref_ids): 48 | for toks, ref_toks in zip(ids, ref_ids): 49 | self.assertEqual(toks.size(), ref_toks.size()) 50 | self.assertEqual(0, (toks != ref_toks).sum().item()) 51 | 52 | ids = get_ids(d) 53 | assertMatch(ids, ref_ids1) 54 | 55 | # check finalized dictionary 56 | d.finalize() 57 | finalized_ids = get_ids(d) 58 | assertMatch(finalized_ids, ref_ids2) 59 | 60 | # write to disk and reload 61 | with tempfile.NamedTemporaryFile(mode='w') as tmp_dict: 62 | d.save(tmp_dict.name) 63 | d = Dictionary.load(tmp_dict.name) 64 | reload_ids = get_ids(d) 65 | assertMatch(reload_ids, ref_ids2) 66 | assertMatch(finalized_ids, reload_ids) 67 | 68 | 69 | if __name__ == '__main__': 70 | unittest.main() 71 | -------------------------------------------------------------------------------- /tests/test_file_io.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the MIT license found in the 2 | # LICENSE file in the root directory of this source tree. 3 | 4 | import sys 5 | import tempfile 6 | import os 7 | import shutil 8 | 9 | from typing import Optional 10 | 11 | import unittest 12 | from unittest.mock import MagicMock 13 | 14 | 15 | class TestFileIO(unittest.TestCase): 16 | 17 | _tmpdir: Optional[str] = None 18 | _tmpfile: Optional[str] = None 19 | _tmpfile_contents = "Hello, World" 20 | 21 | @classmethod 22 | def setUpClass(cls) -> None: 23 | cls._tmpdir = tempfile.mkdtemp() 24 | with open(os.path.join(cls._tmpdir, "test.txt"), "w") as f: 25 | cls._tmpfile = f.name 26 | f.write(cls._tmpfile_contents) 27 | f.flush() 28 | 29 | @classmethod 30 | def tearDownClass(cls) -> None: 31 | # Cleanup temp working dir. 32 | if cls._tmpdir is not None: 33 | shutil.rmtree(cls._tmpdir) # type: ignore 34 | 35 | def test_file_io(self): 36 | from fairseq.file_io import PathManager 37 | with PathManager.open(os.path.join(self._tmpdir, "test.txt"), "r") as f: 38 | s = f.read() 39 | self.assertEqual(s, self._tmpfile_contents) 40 | 41 | def test_file_io_oss(self): 42 | # Mock fvcore to simulate oss environment. 43 | sys.modules['fvcore'] = MagicMock() 44 | from fairseq.file_io import PathManager 45 | with PathManager.open(os.path.join(self._tmpdir, "test.txt"), "r") as f: 46 | s = f.read() 47 | self.assertEqual(s, self._tmpfile_contents) 48 | -------------------------------------------------------------------------------- /tests/test_iterators.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import unittest 7 | 8 | from fairseq.data import iterators 9 | 10 | 11 | class TestIterators(unittest.TestCase): 12 | 13 | def test_counting_iterator(self): 14 | x = list(range(10)) 15 | itr = iterators.CountingIterator(x) 16 | self.assertTrue(itr.has_next()) 17 | self.assertEqual(next(itr), 0) 18 | self.assertEqual(next(itr), 1) 19 | itr.skip(3) 20 | self.assertEqual(next(itr), 5) 21 | itr.skip(3) 22 | self.assertEqual(next(itr), 9) 23 | self.assertFalse(itr.has_next()) 24 | 25 | 26 | if __name__ == '__main__': 27 | unittest.main() 28 | -------------------------------------------------------------------------------- /tests/test_memory_efficient_fp16.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import argparse 7 | import unittest 8 | 9 | import torch 10 | 11 | from fairseq.optim.adam import FairseqAdam 12 | from fairseq.optim.fp16_optimizer import MemoryEfficientFP16Optimizer 13 | 14 | 15 | @unittest.skipIf(not torch.cuda.is_available(), 'test requires a GPU') 16 | class TestMemoryEfficientFP16(unittest.TestCase): 17 | 18 | def test_load_state_dict(self): 19 | # define simple FP16 model 20 | model = torch.nn.Linear(5, 5).cuda().half() 21 | params = list(model.parameters()) 22 | 23 | # initialize memory efficient FP16 optimizer 24 | optimizer = FairseqAdam( 25 | argparse.Namespace( 26 | lr=[0.00001], 27 | adam_betas='(0.9, 0.999)', 28 | adam_eps=1e-8, 29 | weight_decay=0.0, 30 | ), 31 | params, 32 | ) 33 | me_optimizer = MemoryEfficientFP16Optimizer( 34 | argparse.Namespace( 35 | fp16_init_scale=1, 36 | fp16_scale_window=1, 37 | fp16_scale_tolerance=1, 38 | threshold_loss_scale=1, 39 | min_loss_scale=1e-4, 40 | ), 41 | params, 42 | optimizer, 43 | ) 44 | 45 | # optimizer state is created in the first step 46 | loss = model(torch.rand(5).cuda().half()).sum() 47 | me_optimizer.backward(loss) 48 | me_optimizer.step() 49 | 50 | # reload state 51 | state = me_optimizer.state_dict() 52 | me_optimizer.load_state_dict(state) 53 | for k, v in me_optimizer.optimizer.state.items(): 54 | self.assertTrue(k.dtype == torch.float16) 55 | for v_i in v.values(): 56 | if torch.is_tensor(v_i): 57 | self.assertTrue(v_i.dtype == torch.float32) 58 | 59 | 60 | if __name__ == '__main__': 61 | unittest.main() 62 | -------------------------------------------------------------------------------- /tests/test_multihead_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import torch 7 | import unittest 8 | from fairseq.modules.multihead_attention import MultiheadAttention 9 | 10 | 11 | class TestMultiheadAttention(unittest.TestCase): 12 | def test_append_prev_key_padding_mask(self): 13 | bsz = 1 14 | src_len = 4 15 | 16 | cases = [ 17 | # no padding mask 18 | (None, None, None), 19 | # current padding mask only 20 | ( 21 | torch.tensor([[1]]).bool(), 22 | None, 23 | torch.tensor([[0, 0, 0, 1]]).bool(), 24 | ), 25 | # previous padding mask only 26 | ( 27 | None, 28 | torch.tensor([[0, 1, 0]]).bool(), 29 | torch.tensor([[0, 1, 0, 0]]).bool(), 30 | ), 31 | # both padding masks 32 | ( 33 | torch.tensor([[1]]).bool(), 34 | torch.tensor([[0, 1, 0]]).bool(), 35 | torch.tensor([[0, 1, 0, 1]]).bool(), 36 | ), 37 | ] 38 | for c in cases: 39 | key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( 40 | c[0], 41 | c[1], 42 | batch_size=bsz, 43 | src_len=src_len, 44 | static_kv=False, 45 | ) 46 | 47 | if key_padding_mask is not None: 48 | self.assertTrue( 49 | torch.all(torch.eq(key_padding_mask, c[2])), 50 | f'Unexpected resultant key padding mask: {key_padding_mask}' 51 | f' given current: {c[0]} and previous: {c[1]}', 52 | ) 53 | self.assertEqual(key_padding_mask.size(0), bsz) 54 | self.assertEqual(key_padding_mask.size(1), src_len) 55 | else: 56 | self.assertIsNone(c[2]) 57 | 58 | 59 | if __name__ == '__main__': 60 | unittest.main() 61 | -------------------------------------------------------------------------------- /tests/test_sparse_multihead_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import torch 7 | import unittest 8 | from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention 9 | 10 | 11 | class TestSparseMultiheadAttention(unittest.TestCase): 12 | def test_sparse_multihead_attention(self): 13 | attn_weights = torch.randn(1, 8, 8) 14 | bidirectional_sparse_mask = torch.tensor([ 15 | [0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0], 16 | [0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0], 17 | [0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0], 18 | [0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0], 19 | [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0], 20 | [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0], 21 | [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0], 22 | [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0] 23 | ]) 24 | 25 | bidirectional_attention = SparseMultiheadAttention(16, 1, stride=4, expressivity=1, is_bidirectional=True) 26 | bidirectional_attention_sparse_mask = bidirectional_attention.buffered_sparse_mask(attn_weights, 8, 8) 27 | torch.all(torch.eq(bidirectional_attention_sparse_mask, bidirectional_sparse_mask)) 28 | 29 | sparse_mask = torch.tensor([ 30 | [0, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf'), 31 | float('-inf'), float('-inf')], 32 | [0, 0, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf')], 33 | [0, 0, 0, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf')], 34 | [0, 0, 0, 0, float('-inf'), float('-inf'), float('-inf'), float('-inf')], 35 | [0, 0, 0, 0, 0, float('-inf'), float('-inf'), float('-inf')], 36 | [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, float('-inf'), float('-inf')], 37 | [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, float('-inf')], 38 | [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0], 39 | ]) 40 | 41 | attention = SparseMultiheadAttention(16, 1, stride=4, expressivity=1, is_bidirectional=False) 42 | attention_sparse_mask = attention.buffered_sparse_mask(attn_weights, 8, 8) 43 | 44 | torch.all(torch.eq(attention_sparse_mask, sparse_mask)) 45 | 46 | 47 | if __name__ == '__main__': 48 | unittest.main() 49 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import unittest 7 | 8 | import torch 9 | 10 | from fairseq import utils 11 | 12 | 13 | class TestUtils(unittest.TestCase): 14 | 15 | def test_convert_padding_direction(self): 16 | pad = 1 17 | left_pad = torch.LongTensor([ 18 | [2, 3, 4, 5, 6], 19 | [1, 7, 8, 9, 10], 20 | [1, 1, 1, 11, 12], 21 | ]) 22 | right_pad = torch.LongTensor([ 23 | [2, 3, 4, 5, 6], 24 | [7, 8, 9, 10, 1], 25 | [11, 12, 1, 1, 1], 26 | ]) 27 | 28 | self.assertAlmostEqual( 29 | right_pad, 30 | utils.convert_padding_direction( 31 | left_pad, 32 | pad, 33 | left_to_right=True, 34 | ), 35 | ) 36 | self.assertAlmostEqual( 37 | left_pad, 38 | utils.convert_padding_direction( 39 | right_pad, 40 | pad, 41 | right_to_left=True, 42 | ), 43 | ) 44 | 45 | def test_make_positions(self): 46 | pad = 1 47 | left_pad_input = torch.LongTensor([ 48 | [9, 9, 9, 9, 9], 49 | [1, 9, 9, 9, 9], 50 | [1, 1, 1, 9, 9], 51 | ]) 52 | left_pad_output = torch.LongTensor([ 53 | [2, 3, 4, 5, 6], 54 | [1, 2, 3, 4, 5], 55 | [1, 1, 1, 2, 3], 56 | ]) 57 | right_pad_input = torch.LongTensor([ 58 | [9, 9, 9, 9, 9], 59 | [9, 9, 9, 9, 1], 60 | [9, 9, 1, 1, 1], 61 | ]) 62 | right_pad_output = torch.LongTensor([ 63 | [2, 3, 4, 5, 6], 64 | [2, 3, 4, 5, 1], 65 | [2, 3, 1, 1, 1], 66 | ]) 67 | 68 | self.assertAlmostEqual( 69 | left_pad_output, 70 | utils.make_positions(left_pad_input, pad), 71 | ) 72 | self.assertAlmostEqual( 73 | right_pad_output, 74 | utils.make_positions(right_pad_input, pad), 75 | ) 76 | 77 | def assertAlmostEqual(self, t1, t2): 78 | self.assertEqual(t1.size(), t2.size(), "size mismatch") 79 | self.assertLess(utils.item((t1 - t2).abs().max()), 1e-4) 80 | 81 | 82 | if __name__ == '__main__': 83 | unittest.main() 84 | --------------------------------------------------------------------------------