├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── __pycache__ └── preprocess.cpython-36.pyc ├── build ├── lib.linux-x86_64-3.6 │ └── fairseq │ │ ├── data │ │ ├── data_utils_fast.cpython-36m-x86_64-linux-gnu.so │ │ └── token_block_utils_fast.cpython-36m-x86_64-linux-gnu.so │ │ └── libbleu.cpython-36m-x86_64-linux-gnu.so └── temp.linux-x86_64-3.6 │ └── fairseq │ ├── clib │ └── libbleu │ │ ├── libbleu.o │ │ └── module.o │ └── data │ ├── data_utils_fast.o │ └── token_block_utils_fast.o ├── docs ├── Makefile ├── _static │ └── theme_overrides.css ├── command_line_tools.rst ├── conf.py ├── criterions.rst ├── data.rst ├── docutils.conf ├── getting_started.rst ├── index.rst ├── lr_scheduler.rst ├── make.bat ├── models.rst ├── modules.rst ├── optim.rst ├── overview.rst ├── requirements.txt ├── tasks.rst ├── tutorial_classifying_names.rst └── tutorial_simple_lstm.rst ├── eval_lm.py ├── examples ├── .DS_Store ├── .gitignore ├── KEPLER │ ├── GLUE │ │ ├── CoLA.sh │ │ ├── MNLI.sh │ │ ├── MRPC.sh │ │ ├── QNLI.sh │ │ ├── QQP.sh │ │ ├── RTE.sh │ │ ├── SST-2.sh │ │ └── STS-B.sh │ ├── KE │ │ ├── evaluate_transe_inductive.py │ │ ├── evaluate_transe_transductive.py │ │ └── generate_embeddings.py │ ├── OpenEntity │ │ ├── pytorch_transformers │ │ │ ├── __init__.py │ │ │ ├── __main__.py │ │ │ ├── configuration_auto.py │ │ │ ├── configuration_bert.py │ │ │ ├── configuration_distilbert.py │ │ │ ├── configuration_gpt2.py │ │ │ ├── configuration_openai.py │ │ │ ├── configuration_roberta.py │ │ │ ├── configuration_transfo_xl.py │ │ │ ├── configuration_utils.py │ │ │ ├── configuration_xlm.py │ │ │ ├── configuration_xlnet.py │ │ │ ├── convert_gpt2_checkpoint_to_pytorch.py │ │ │ ├── convert_openai_checkpoint_to_pytorch.py │ │ │ ├── convert_pytorch_checkpoint_to_tf.py │ │ │ ├── convert_roberta_checkpoint_to_pytorch.py │ │ │ ├── convert_tf_checkpoint_to_pytorch.py │ │ │ ├── convert_transfo_xl_checkpoint_to_pytorch.py │ │ │ ├── convert_xlm_checkpoint_to_pytorch.py │ │ │ ├── 
convert_xlnet_checkpoint_to_pytorch.py │ │ │ ├── file_utils.py │ │ │ ├── modeling_auto.py │ │ │ ├── modeling_bert.py │ │ │ ├── modeling_distilbert.py │ │ │ ├── modeling_gpt2.py │ │ │ ├── modeling_openai.py │ │ │ ├── modeling_roberta.py │ │ │ ├── modeling_transfo_xl.py │ │ │ ├── modeling_transfo_xl_utilities.py │ │ │ ├── modeling_utils.py │ │ │ ├── modeling_xlm.py │ │ │ ├── modeling_xlnet.py │ │ │ ├── optimization.py │ │ │ ├── tests │ │ │ │ ├── __init__.py │ │ │ │ ├── configuration_common_test.py │ │ │ │ ├── conftest.py │ │ │ │ ├── fixtures │ │ │ │ │ ├── input.txt │ │ │ │ │ ├── sample_text.txt │ │ │ │ │ └── test_sentencepiece.model │ │ │ │ ├── modeling_auto_test.py │ │ │ │ ├── modeling_bert_test.py │ │ │ │ ├── modeling_common_test.py │ │ │ │ ├── modeling_distilbert_test.py │ │ │ │ ├── modeling_gpt2_test.py │ │ │ │ ├── modeling_openai_test.py │ │ │ │ ├── modeling_roberta_test.py │ │ │ │ ├── modeling_transfo_xl_test.py │ │ │ │ ├── modeling_xlm_test.py │ │ │ │ ├── modeling_xlnet_test.py │ │ │ │ ├── optimization_test.py │ │ │ │ ├── tokenization_auto_test.py │ │ │ │ ├── tokenization_bert_test.py │ │ │ │ ├── tokenization_dilbert_test.py │ │ │ │ ├── tokenization_gpt2_test.py │ │ │ │ ├── tokenization_openai_test.py │ │ │ │ ├── tokenization_roberta_test.py │ │ │ │ ├── tokenization_tests_commons.py │ │ │ │ ├── tokenization_transfo_xl_test.py │ │ │ │ ├── tokenization_utils_test.py │ │ │ │ ├── tokenization_xlm_test.py │ │ │ │ └── tokenization_xlnet_test.py │ │ │ ├── tokenization_auto.py │ │ │ ├── tokenization_bert.py │ │ │ ├── tokenization_distilbert.py │ │ │ ├── tokenization_gpt2.py │ │ │ ├── tokenization_openai.py │ │ │ ├── tokenization_roberta.py │ │ │ ├── tokenization_transfo_xl.py │ │ │ ├── tokenization_utils.py │ │ │ ├── tokenization_xlm.py │ │ │ └── tokenization_xlnet.py │ │ ├── run_openentity.sh │ │ ├── run_typing.py │ │ └── utils_glue.py │ ├── Pretrain │ │ ├── KGpreprocess.py │ │ ├── convert.py │ │ └── splitDump.py │ └── TACRED │ │ └── TACRED.sh ├── __init__.py ├── 
__pycache__ │ └── __init__.cpython-36.pyc ├── backtranslation │ └── README.md ├── conv_seq2seq │ └── README.md ├── cross_lingual_language_model │ └── README.md ├── language_model │ ├── README.md │ ├── conv_lm │ │ └── README.md │ └── transformer_lm │ │ └── README.md ├── noisychannel │ ├── README.md │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ └── rerank_options.cpython-36.pyc │ ├── rerank.py │ ├── rerank_generate.py │ ├── rerank_options.py │ ├── rerank_score_bw.py │ ├── rerank_score_lm.py │ ├── rerank_tune.py │ └── rerank_utils.py ├── pay_less_attention_paper │ └── README.md ├── roberta │ ├── README.custom_classification.md │ ├── README.glue.md │ ├── README.md │ ├── README.pretraining.md │ ├── README.race.md │ ├── __pycache__ │ │ └── multiprocessing_bpe_encoder.cpython-36.pyc │ ├── commonsense_qa │ │ ├── README.md │ │ ├── __init__.py │ │ ├── commonsense_qa_task.py │ │ └── download_cqa_data.sh │ ├── multiprocessing_bpe_encoder.py │ ├── preprocess_GLUE_tasks.sh │ ├── preprocess_RACE.py │ ├── preprocess_RACE.sh │ └── wsc │ │ ├── README.md │ │ ├── __init__.py │ │ ├── wsc_criterion.py │ │ ├── wsc_task.py │ │ └── wsc_utils.py ├── scaling_nmt │ └── README.md ├── speech_recognition │ ├── README.md │ ├── __init__.py │ ├── criterions │ │ ├── __init__.py │ │ └── cross_entropy_acc.py │ ├── data │ │ ├── __init__.py │ │ ├── asr_dataset.py │ │ ├── collaters.py │ │ └── data_utils.py │ ├── datasets │ │ ├── asr_prep_json.py │ │ └── prepare-librispeech.sh │ ├── infer.py │ ├── models │ │ ├── __init__.py │ │ └── vggtransformer.py │ └── tasks │ │ ├── __init__.py │ │ └── speech_recognition.py ├── stories │ └── README.md ├── translation │ ├── README.md │ ├── prepare-iwslt14.sh │ ├── prepare-iwslt17-multilingual.sh │ ├── prepare-wmt14en2de.sh │ └── prepare-wmt14en2fr.sh ├── translation_moe │ ├── README.md │ └── score.py ├── wav2vec │ └── README.md └── wmt19 │ └── README.md ├── fairseq.egg-info ├── PKG-INFO ├── SOURCES.txt ├── dependency_links.txt ├── 
entry_points.txt ├── not-zip-safe ├── requires.txt └── top_level.txt ├── fairseq.gif ├── fairseq ├── .DS_Store ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── binarizer.cpython-36.pyc │ ├── checkpoint_utils.cpython-36.pyc │ ├── distributed_utils.cpython-36.pyc │ ├── file_utils.cpython-36.pyc │ ├── hub_utils.cpython-36.pyc │ ├── legacy_distributed_data_parallel.cpython-36.pyc │ ├── meters.cpython-36.pyc │ ├── options.cpython-36.pyc │ ├── pdb.cpython-36.pyc │ ├── registry.cpython-36.pyc │ ├── search.cpython-36.pyc │ ├── sequence_generator.cpython-36.pyc │ ├── tokenizer.cpython-36.pyc │ └── utils.cpython-36.pyc ├── binarizer.py ├── bleu.py ├── checkpoint_utils.py ├── clib │ └── libbleu │ │ ├── libbleu.cpp │ │ └── module.cpp ├── criterions │ ├── MLMetKE.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── MLMetKE.cpython-36.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── adaptive_loss.cpython-36.pyc │ │ ├── binary_cross_entropy.cpython-36.pyc │ │ ├── composite_loss.cpython-36.pyc │ │ ├── cross_entropy.cpython-36.pyc │ │ ├── fairseq_criterion.cpython-36.pyc │ │ ├── label_smoothed_cross_entropy.cpython-36.pyc │ │ ├── legacy_masked_lm.cpython-36.pyc │ │ ├── masked_lm.cpython-36.pyc │ │ ├── only_ke.cpython-36.pyc │ │ ├── relation_extraction.cpython-36.pyc │ │ ├── sentence_prediction.cpython-36.pyc │ │ ├── sentence_prediction_debug.cpython-36.pyc │ │ └── sentence_ranking.cpython-36.pyc │ ├── adaptive_loss.py │ ├── binary_cross_entropy.py │ ├── composite_loss.py │ ├── cross_entropy.py │ ├── fairseq_criterion.py │ ├── label_smoothed_cross_entropy.py │ ├── legacy_masked_lm.py │ ├── masked_lm.py │ ├── only_ke.py │ ├── relation_extraction.py │ ├── sentence_prediction.py │ ├── sentence_prediction_debug.py │ └── sentence_ranking.py ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── backtranslation_dataset.cpython-36.pyc │ │ ├── base_wrapper_dataset.cpython-36.pyc │ │ ├── concat_dataset.cpython-36.pyc │ │ ├── 
concat_sentences_dataset.cpython-36.pyc │ │ ├── data_utils.cpython-36.pyc │ │ ├── dictionary.cpython-36.pyc │ │ ├── fairseq_dataset.cpython-36.pyc │ │ ├── fake_numel_dataset.cpython-36.pyc │ │ ├── id_dataset.cpython-36.pyc │ │ ├── indexed_dataset.cpython-36.pyc │ │ ├── iterators.cpython-36.pyc │ │ ├── ke_dataset.cpython-36.pyc │ │ ├── ke_negative_dataset.cpython-36.pyc │ │ ├── language_pair_dataset.cpython-36.pyc │ │ ├── list_dataset.cpython-36.pyc │ │ ├── lm_context_window_dataset.cpython-36.pyc │ │ ├── lru_cache_dataset.cpython-36.pyc │ │ ├── mask_tokens_dataset.cpython-36.pyc │ │ ├── monolingual_dataset.cpython-36.pyc │ │ ├── multi_corpus_sampled_dataset.cpython-36.pyc │ │ ├── nested_dictionary_dataset.cpython-36.pyc │ │ ├── noising.cpython-36.pyc │ │ ├── num_samples_dataset.cpython-36.pyc │ │ ├── numel_dataset.cpython-36.pyc │ │ ├── offset_tokens_dataset.cpython-36.pyc │ │ ├── pad_dataset.cpython-36.pyc │ │ ├── plasma_utils.cpython-36.pyc │ │ ├── prepend_dataset.cpython-36.pyc │ │ ├── prepend_token_dataset.cpython-36.pyc │ │ ├── raw_label_dataset.cpython-36.pyc │ │ ├── replace_dataset.cpython-36.pyc │ │ ├── round_robin_zip_datasets.cpython-36.pyc │ │ ├── sharded_dataset.cpython-36.pyc │ │ ├── sort_dataset.cpython-36.pyc │ │ ├── strip_token_dataset.cpython-36.pyc │ │ ├── subsample_dataset.cpython-36.pyc │ │ ├── token_block_dataset.cpython-36.pyc │ │ ├── transform_eos_dataset.cpython-36.pyc │ │ ├── transform_eos_lang_pair_dataset.cpython-36.pyc │ │ └── truncate_dataset.cpython-36.pyc │ ├── audio │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ └── raw_audio_dataset.cpython-36.pyc │ │ └── raw_audio_dataset.py │ ├── backtranslation_dataset.py │ ├── base_wrapper_dataset.py │ ├── concat_dataset.py │ ├── concat_sentences_dataset.py │ ├── data_utils.py │ ├── data_utils_fast.cpp │ ├── data_utils_fast.cpython-36m-x86_64-linux-gnu.so │ ├── data_utils_fast.pyx │ ├── dictionary.py │ ├── encoders │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ 
├── __init__.cpython-36.pyc │ │ │ ├── fastbpe.cpython-36.pyc │ │ │ ├── gpt2_bpe.cpython-36.pyc │ │ │ ├── gpt2_bpe_utils.cpython-36.pyc │ │ │ ├── hf_bert_bpe.cpython-36.pyc │ │ │ ├── moses_tokenizer.cpython-36.pyc │ │ │ ├── nltk_tokenizer.cpython-36.pyc │ │ │ ├── sentencepiece_bpe.cpython-36.pyc │ │ │ ├── space_tokenizer.cpython-36.pyc │ │ │ └── subword_nmt_bpe.cpython-36.pyc │ │ ├── fastbpe.py │ │ ├── gpt2_bpe.py │ │ ├── gpt2_bpe_utils.py │ │ ├── hf_bert_bpe.py │ │ ├── moses_tokenizer.py │ │ ├── nltk_tokenizer.py │ │ ├── sentencepiece_bpe.py │ │ ├── space_tokenizer.py │ │ └── subword_nmt_bpe.py │ ├── fairseq_dataset.py │ ├── fake_numel_dataset.py │ ├── id_dataset.py │ ├── indexed_dataset.py │ ├── iterators.py │ ├── ke_dataset.py │ ├── ke_negative_dataset.py │ ├── language_pair_dataset.py │ ├── legacy │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── block_pair_dataset.cpython-36.pyc │ │ │ ├── masked_lm_dataset.cpython-36.pyc │ │ │ └── masked_lm_dictionary.cpython-36.pyc │ │ ├── block_pair_dataset.py │ │ ├── masked_lm_dataset.py │ │ └── masked_lm_dictionary.py │ ├── list_dataset.py │ ├── lm_context_window_dataset.py │ ├── lru_cache_dataset.py │ ├── mask_tokens_dataset.py │ ├── monolingual_dataset.py │ ├── multi_corpus_sampled_dataset.py │ ├── nested_dictionary_dataset.py │ ├── noising.py │ ├── num_samples_dataset.py │ ├── numel_dataset.py │ ├── offset_tokens_dataset.py │ ├── pad_dataset.py │ ├── plasma_utils.py │ ├── prepend_dataset.py │ ├── prepend_token_dataset.py │ ├── raw_label_dataset.py │ ├── replace_dataset.py │ ├── round_robin_zip_datasets.py │ ├── sharded_dataset.py │ ├── sort_dataset.py │ ├── strip_token_dataset.py │ ├── subsample_dataset.py │ ├── token_block_dataset.py │ ├── token_block_utils_fast.c │ ├── token_block_utils_fast.cpython-36m-x86_64-linux-gnu.so │ ├── token_block_utils_fast.pyx │ ├── transform_eos_dataset.py │ ├── transform_eos_lang_pair_dataset.py │ └── truncate_dataset.py ├── distributed_utils.py ├── 
file_utils.py ├── hub_utils.py ├── legacy_distributed_data_parallel.py ├── libbleu.cpython-36m-x86_64-linux-gnu.so ├── meters.py ├── models │ ├── .DS_Store │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── composite_encoder.cpython-36.pyc │ │ ├── distributed_fairseq_model.cpython-36.pyc │ │ ├── fairseq_decoder.cpython-36.pyc │ │ ├── fairseq_encoder.cpython-36.pyc │ │ ├── fairseq_incremental_decoder.cpython-36.pyc │ │ ├── fairseq_model.cpython-36.pyc │ │ ├── fconv.cpython-36.pyc │ │ ├── fconv_lm.cpython-36.pyc │ │ ├── fconv_self_att.cpython-36.pyc │ │ ├── lightconv.cpython-36.pyc │ │ ├── lightconv_lm.cpython-36.pyc │ │ ├── lstm.cpython-36.pyc │ │ ├── masked_lm.cpython-36.pyc │ │ ├── multilingual_transformer.cpython-36.pyc │ │ ├── transformer.cpython-36.pyc │ │ ├── transformer_from_pretrained_xlm.cpython-36.pyc │ │ ├── transformer_lm.cpython-36.pyc │ │ └── wav2vec.cpython-36.pyc │ ├── composite_encoder.py │ ├── distributed_fairseq_model.py │ ├── fairseq_decoder.py │ ├── fairseq_encoder.py │ ├── fairseq_incremental_decoder.py │ ├── fairseq_model.py │ ├── fconv.py │ ├── fconv_lm.py │ ├── fconv_self_att.py │ ├── lightconv.py │ ├── lightconv_lm.py │ ├── lstm.py │ ├── masked_lm.py │ ├── multilingual_transformer.py │ ├── roberta │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── hub_interface.cpython-36.pyc │ │ │ └── model.cpython-36.pyc │ │ ├── alignment_utils.py │ │ ├── hub_interface.py │ │ └── model.py │ ├── transformer.py │ ├── transformer_from_pretrained_xlm.py │ ├── transformer_lm.py │ └── wav2vec.py ├── modules │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── adaptive_input.cpython-36.pyc │ │ ├── adaptive_softmax.cpython-36.pyc │ │ ├── beamable_mm.cpython-36.pyc │ │ ├── character_token_embedder.cpython-36.pyc │ │ ├── conv_tbc.cpython-36.pyc │ │ ├── downsampled_multihead_attention.cpython-36.pyc │ │ ├── dynamic_convolution.cpython-36.pyc │ │ ├── gelu.cpython-36.pyc │ │ ├── 
grad_multiply.cpython-36.pyc │ │ ├── highway.cpython-36.pyc │ │ ├── layer_norm.cpython-36.pyc │ │ ├── learned_positional_embedding.cpython-36.pyc │ │ ├── lightweight_convolution.cpython-36.pyc │ │ ├── linearized_convolution.cpython-36.pyc │ │ ├── logsumexp_moe.cpython-36.pyc │ │ ├── mean_pool_gating_network.cpython-36.pyc │ │ ├── multihead_attention.cpython-36.pyc │ │ ├── positional_embedding.cpython-36.pyc │ │ ├── scalar_bias.cpython-36.pyc │ │ ├── sinusoidal_positional_embedding.cpython-36.pyc │ │ ├── transformer_layer.cpython-36.pyc │ │ ├── transformer_sentence_encoder.cpython-36.pyc │ │ ├── transformer_sentence_encoder_layer.cpython-36.pyc │ │ ├── unfold.cpython-36.pyc │ │ └── vggblock.cpython-36.pyc │ ├── adaptive_input.py │ ├── adaptive_softmax.py │ ├── beamable_mm.py │ ├── character_token_embedder.py │ ├── conv_tbc.py │ ├── cuda_utils.cu │ ├── downsampled_multihead_attention.py │ ├── dynamic_convolution.py │ ├── dynamicconv_layer │ │ ├── __init__.py │ │ ├── cuda_function_gen.py │ │ ├── dynamicconv_cuda.cpp │ │ ├── dynamicconv_cuda.cuh │ │ ├── dynamicconv_cuda_kernel.cu │ │ ├── dynamicconv_layer.py │ │ ├── dynamiconv_cpu.cpp │ │ └── setup.py │ ├── gelu.py │ ├── grad_multiply.py │ ├── highway.py │ ├── layer_norm.py │ ├── learned_positional_embedding.py │ ├── lightconv_layer │ │ ├── __init__.py │ │ ├── cuda_function_gen.py │ │ ├── lightconv_cuda.cpp │ │ ├── lightconv_cuda.cuh │ │ ├── lightconv_cuda_kernel.cu │ │ ├── lightconv_layer.py │ │ └── setup.py │ ├── lightweight_convolution.py │ ├── linearized_convolution.py │ ├── logsumexp_moe.py │ ├── mean_pool_gating_network.py │ ├── multihead_attention.py │ ├── positional_embedding.py │ ├── scalar_bias.py │ ├── sinusoidal_positional_embedding.py │ ├── sparse_multihead_attention.py │ ├── sparse_transformer_sentence_encoder.py │ ├── sparse_transformer_sentence_encoder_layer.py │ ├── transformer_layer.py │ ├── transformer_sentence_encoder.py │ ├── transformer_sentence_encoder_layer.py │ ├── unfold.py │ └── vggblock.py 
├── optim │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── adadelta.cpython-36.pyc │ │ ├── adafactor.cpython-36.pyc │ │ ├── adagrad.cpython-36.pyc │ │ ├── adam.cpython-36.pyc │ │ ├── adamax.cpython-36.pyc │ │ ├── bmuf.cpython-36.pyc │ │ ├── fairseq_optimizer.cpython-36.pyc │ │ ├── fp16_optimizer.cpython-36.pyc │ │ ├── nag.cpython-36.pyc │ │ └── sgd.cpython-36.pyc │ ├── adadelta.py │ ├── adafactor.py │ ├── adagrad.py │ ├── adam.py │ ├── adamax.py │ ├── bmuf.py │ ├── fairseq_optimizer.py │ ├── fp16_optimizer.py │ ├── lr_scheduler │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── cosine_lr_scheduler.cpython-36.pyc │ │ │ ├── fairseq_lr_scheduler.cpython-36.pyc │ │ │ ├── fixed_schedule.cpython-36.pyc │ │ │ ├── inverse_square_root_schedule.cpython-36.pyc │ │ │ ├── polynomial_decay_schedule.cpython-36.pyc │ │ │ ├── reduce_lr_on_plateau.cpython-36.pyc │ │ │ ├── tri_stage_lr_scheduler.cpython-36.pyc │ │ │ └── triangular_lr_scheduler.cpython-36.pyc │ │ ├── cosine_lr_scheduler.py │ │ ├── fairseq_lr_scheduler.py │ │ ├── fixed_schedule.py │ │ ├── inverse_square_root_schedule.py │ │ ├── polynomial_decay_schedule.py │ │ ├── reduce_lr_on_plateau.py │ │ ├── tri_stage_lr_scheduler.py │ │ └── triangular_lr_scheduler.py │ ├── nag.py │ └── sgd.py ├── options.py ├── pdb.py ├── progress_bar.py ├── registry.py ├── search.py ├── sequence_generator.py ├── sequence_scorer.py ├── tasks │ ├── MLMetKE.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── MLMetKE.cpython-36.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── audio_pretraining.cpython-36.pyc │ │ ├── cross_lingual_lm.cpython-36.pyc │ │ ├── fairseq_task.cpython-36.pyc │ │ ├── language_modeling.cpython-36.pyc │ │ ├── legacy_masked_lm.cpython-36.pyc │ │ ├── masked_lm.cpython-36.pyc │ │ ├── multilingual_translation.cpython-36.pyc │ │ ├── semisupervised_translation.cpython-36.pyc │ │ ├── sentence_prediction.cpython-36.pyc │ │ ├── sentence_ranking.cpython-36.pyc │ │ ├── 
tacred_task.cpython-36.pyc │ │ ├── translation.cpython-36.pyc │ │ ├── translation_from_pretrained_xlm.cpython-36.pyc │ │ └── translation_moe.cpython-36.pyc │ ├── audio_pretraining.py │ ├── cross_lingual_lm.py │ ├── fairseq_task.py │ ├── language_modeling.py │ ├── legacy_masked_lm.py │ ├── masked_lm.py │ ├── multilingual_translation.py │ ├── semisupervised_translation.py │ ├── sentence_prediction.py │ ├── sentence_ranking.py │ ├── tacred_task.py │ ├── translation.py │ ├── translation_from_pretrained_xlm.py │ └── translation_moe.py ├── tokenizer.py ├── trainer.py └── utils.py ├── fairseqREADME.md ├── fairseq_cli ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── preprocess.cpython-36.pyc ├── eval_lm.py ├── generate.py ├── interactive.py ├── preprocess.py ├── score.py ├── setup.py └── train.py ├── fairseq_logo.png ├── generate.py ├── graphvite ├── .gitignore ├── CHANGELOG.md ├── CMakeLists.txt ├── LICENSE ├── README.md ├── asset │ ├── graph.png │ ├── knowledge_graph.png │ ├── logo │ │ ├── favicon.ico │ │ └── logo.png │ ├── visualization.png │ └── visualization │ │ ├── imagenet_hierarchy.gif │ │ └── mnist_3d.gif ├── cmake │ ├── FindGFlags.cmake │ ├── FindGlog.cmake │ └── FindPythonLibsNew.cmake ├── conda │ ├── conda_build_config.yaml │ ├── graphvite-mini │ │ ├── build.sh │ │ └── meta.yaml │ ├── graphvite │ │ ├── build.sh │ │ └── meta.yaml │ └── requirements.txt ├── config │ ├── demo │ │ ├── math.yaml │ │ └── quick_start.yaml │ ├── graph │ │ ├── deepwalk_flickr.yaml │ │ ├── deepwalk_friendster-small.yaml │ │ ├── deepwalk_friendster.yaml │ │ ├── deepwalk_hyperlink-pld.yaml │ │ ├── deepwalk_youtube.yaml │ │ ├── line_flickr.yaml │ │ ├── line_friendster-small.yaml │ │ ├── line_friendster.yaml │ │ ├── line_hyperlink-pld.yaml │ │ ├── line_youtube.yaml │ │ └── node2vec_youtube.yaml │ ├── knowledge_graph │ │ ├── complex_fb15k-237.yaml │ │ ├── complex_fb15k.yaml │ │ ├── complex_wikidata5m.yaml │ │ ├── complex_wn18.yaml │ │ ├── complex_wn18rr.yaml │ │ ├── 
distmult_fb15k-237.yaml │ │ ├── distmult_fb15k.yaml │ │ ├── distmult_wikidata5m.yaml │ │ ├── distmult_wn18.yaml │ │ ├── distmult_wn18rr.yaml │ │ ├── rotate_fb15k-237.yaml │ │ ├── rotate_fb15k.yaml │ │ ├── rotate_wikidata5m.yaml │ │ ├── rotate_wn18.yaml │ │ ├── rotate_wn18rr.yaml │ │ ├── simple_fb15k-237.yaml │ │ ├── simple_fb15k.yaml │ │ ├── simple_wikidata5m.yaml │ │ ├── simple_wn18.yaml │ │ ├── simple_wn18rr.yaml │ │ ├── transe_fb15k-237.yaml │ │ ├── transe_fb15k.yaml │ │ ├── transe_wikidata5m.yaml │ │ ├── transe_wn18.yaml │ │ └── transe_wn18rr.yaml │ ├── template │ │ ├── graph.yaml │ │ ├── knowledge_graph.yaml │ │ ├── visualization.yaml │ │ └── word_graph.yaml │ ├── visualization │ │ ├── largevis_imagenet.yaml │ │ ├── largevis_mnist_2d.yaml │ │ └── largevis_mnist_3d.yaml │ └── word_graph │ │ └── line_wikipedia.yaml ├── doc │ ├── Makefile │ └── source │ │ ├── api │ │ ├── application.rst │ │ ├── dataset.rst │ │ ├── graph.rst │ │ ├── optimizer.rst │ │ └── solver.rst │ │ ├── benchmark.rst │ │ ├── conf.py │ │ ├── developer │ │ ├── framework.rst │ │ ├── model.rst │ │ ├── routine.rst │ │ └── solver.rst │ │ ├── faq.rst │ │ ├── index.rst │ │ ├── install.rst │ │ ├── introduction.rst │ │ ├── link.rst │ │ ├── overview.rst │ │ ├── pretrained_model.rst │ │ ├── quick_start.rst │ │ └── user │ │ ├── auto.rst │ │ ├── command_line.rst │ │ ├── configuration.rst │ │ ├── format.rst │ │ └── python.rst ├── external │ └── .gitignore ├── include │ ├── base │ │ ├── alias_table.cuh │ │ ├── memory.h │ │ └── vector.h │ ├── bind.h │ ├── core │ │ ├── graph.h │ │ ├── optimizer.h │ │ └── solver.h │ └── util │ │ ├── common.h │ │ ├── debug.h │ │ ├── gpu.cuh │ │ ├── io.h │ │ ├── math.h │ │ └── time.h ├── python │ ├── graphvite │ │ ├── __init__.py │ │ ├── application │ │ │ ├── __init__.py │ │ │ ├── application.py │ │ │ └── network.py │ │ ├── base.py │ │ ├── cmd.py │ │ ├── dataset.py │ │ ├── graph.py │ │ ├── helper.py │ │ ├── optimizer.py │ │ ├── solver.py │ │ └── util.py │ └── setup.py └── src │ ├── 
CMakeLists.txt │ └── graphvite.cu ├── hubconf.py ├── interactive.py ├── ke_tool ├── evaluate_transe_inductive.py └── evaluate_transe_transductive.py ├── preprocess.py ├── score.py ├── scripts ├── __init__.py ├── average_checkpoints.py ├── build_sym_alignment.py ├── compare_namespaces.py ├── compound_split_bleu.sh ├── convert_dictionary.lua ├── convert_model.lua ├── count_docs.py ├── read_binarized.py ├── rm_pt.py ├── sacrebleu_pregen.sh ├── shard_docs.py ├── split_train_valid_docs.py ├── spm_decode.py ├── spm_encode.py ├── spm_train.py ├── wav2vec_featurize.py └── wav2vec_manifest.py ├── setup.py ├── tests ├── __init__.py ├── speech_recognition │ ├── __init__.py │ ├── asr_test_base.py │ ├── test_collaters.py │ ├── test_cross_entropy.py │ └── test_vggtransformer.py ├── test_average_checkpoints.py ├── test_backtranslation_dataset.py ├── test_binaries.py ├── test_character_token_embedder.py ├── test_concat_dataset.py ├── test_convtbc.py ├── test_dictionary.py ├── test_iterators.py ├── test_label_smoothing.py ├── test_memory_efficient_fp16.py ├── test_multi_corpus_sampled_dataset.py ├── test_noising.py ├── test_reproducibility.py ├── test_sequence_generator.py ├── test_sequence_scorer.py ├── test_sparse_multihead_attention.py ├── test_token_block_dataset.py ├── test_train.py ├── test_utils.py └── utils.py ├── train.py └── validate.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. Please [read the full text](https://code.fb.com/codeofconduct) so that you can understand what actions will and will not be tolerated. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Facebook, Inc. and its affiliates. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /__pycache__/preprocess.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/__pycache__/preprocess.cpython-36.pyc -------------------------------------------------------------------------------- /build/lib.linux-x86_64-3.6/fairseq/data/data_utils_fast.cpython-36m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/build/lib.linux-x86_64-3.6/fairseq/data/data_utils_fast.cpython-36m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /build/lib.linux-x86_64-3.6/fairseq/data/token_block_utils_fast.cpython-36m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/build/lib.linux-x86_64-3.6/fairseq/data/token_block_utils_fast.cpython-36m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /build/lib.linux-x86_64-3.6/fairseq/libbleu.cpython-36m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/build/lib.linux-x86_64-3.6/fairseq/libbleu.cpython-36m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /build/temp.linux-x86_64-3.6/fairseq/clib/libbleu/libbleu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/build/temp.linux-x86_64-3.6/fairseq/clib/libbleu/libbleu.o 
-------------------------------------------------------------------------------- /build/temp.linux-x86_64-3.6/fairseq/clib/libbleu/module.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/build/temp.linux-x86_64-3.6/fairseq/clib/libbleu/module.o -------------------------------------------------------------------------------- /build/temp.linux-x86_64-3.6/fairseq/data/data_utils_fast.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/build/temp.linux-x86_64-3.6/fairseq/data/data_utils_fast.o -------------------------------------------------------------------------------- /build/temp.linux-x86_64-3.6/fairseq/data/token_block_utils_fast.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/build/temp.linux-x86_64-3.6/fairseq/data/token_block_utils_fast.o -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = python -msphinx 7 | SPHINXPROJ = fairseq 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/_static/theme_overrides.css: -------------------------------------------------------------------------------- 1 | .wy-table-responsive table td kbd { 2 | white-space: nowrap; 3 | } 4 | .wy-table-responsive table td { 5 | white-space: normal !important; 6 | } 7 | .wy-table-responsive { 8 | overflow: visible !important; 9 | } 10 | -------------------------------------------------------------------------------- /docs/criterions.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | .. _Criterions: 5 | 6 | Criterions 7 | ========== 8 | 9 | Criterions compute the loss function given the model and batch, roughly:: 10 | 11 | loss = criterion(model, batch) 12 | 13 | .. automodule:: fairseq.criterions 14 | :members: 15 | 16 | .. autoclass:: fairseq.criterions.FairseqCriterion 17 | :members: 18 | :undoc-members: 19 | 20 | .. autoclass:: fairseq.criterions.adaptive_loss.AdaptiveLoss 21 | :members: 22 | :undoc-members: 23 | .. autoclass:: fairseq.criterions.composite_loss.CompositeLoss 24 | :members: 25 | :undoc-members: 26 | .. autoclass:: fairseq.criterions.cross_entropy.CrossEntropyCriterion 27 | :members: 28 | :undoc-members: 29 | .. autoclass:: fairseq.criterions.label_smoothed_cross_entropy.LabelSmoothedCrossEntropyCriterion 30 | :members: 31 | :undoc-members: 32 | -------------------------------------------------------------------------------- /docs/data.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | .. module:: fairseq.data 5 | 6 | Data Loading and Utilities 7 | ========================== 8 | 9 | .. 
_datasets: 10 | 11 | Datasets 12 | -------- 13 | 14 | **Datasets** define the data format and provide helpers for creating 15 | mini-batches. 16 | 17 | .. autoclass:: fairseq.data.FairseqDataset 18 | :members: 19 | .. autoclass:: fairseq.data.LanguagePairDataset 20 | :members: 21 | .. autoclass:: fairseq.data.MonolingualDataset 22 | :members: 23 | 24 | **Helper Datasets** 25 | 26 | These datasets wrap other :class:`fairseq.data.FairseqDataset` instances and 27 | provide additional functionality: 28 | 29 | .. autoclass:: fairseq.data.BacktranslationDataset 30 | :members: 31 | .. autoclass:: fairseq.data.ConcatDataset 32 | :members: 33 | .. autoclass:: fairseq.data.RoundRobinZipDatasets 34 | :members: 35 | .. autoclass:: fairseq.data.TransformEosDataset 36 | :members: 37 | 38 | 39 | Dictionary 40 | ---------- 41 | 42 | .. autoclass:: fairseq.data.Dictionary 43 | :members: 44 | 45 | 46 | Iterators 47 | --------- 48 | 49 | .. autoclass:: fairseq.data.CountingIterator 50 | :members: 51 | .. autoclass:: fairseq.data.EpochBatchIterator 52 | :members: 53 | .. autoclass:: fairseq.data.GroupedIterator 54 | :members: 55 | .. autoclass:: fairseq.data.ShardedIterator 56 | :members: 57 | -------------------------------------------------------------------------------- /docs/docutils.conf: -------------------------------------------------------------------------------- 1 | [writers] 2 | option-limit=0 3 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. fairseq documentation master file, created by 2 | sphinx-quickstart on Fri Aug 17 21:45:30 2018. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 
5 | 6 | :github_url: https://github.com/pytorch/fairseq 7 | 8 | 9 | fairseq documentation 10 | ===================== 11 | 12 | Fairseq is a sequence modeling toolkit written in `PyTorch 13 | <http://pytorch.org/>`_ that allows researchers and developers to 14 | train custom models for translation, summarization, language modeling and other 15 | text generation tasks. 16 | 17 | .. toctree:: 18 | :maxdepth: 1 19 | :caption: Getting Started 20 | 21 | getting_started 22 | command_line_tools 23 | 24 | .. toctree:: 25 | :maxdepth: 1 26 | :caption: Extending Fairseq 27 | 28 | overview 29 | tutorial_simple_lstm 30 | tutorial_classifying_names 31 | 32 | .. toctree:: 33 | :maxdepth: 2 34 | :caption: Library Reference 35 | 36 | tasks 37 | models 38 | criterions 39 | optim 40 | lr_scheduler 41 | data 42 | modules 43 | 44 | 45 | Indices and tables 46 | ================== 47 | 48 | * :ref:`genindex` 49 | * :ref:`search` 50 | -------------------------------------------------------------------------------- /docs/lr_scheduler.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | .. _Learning Rate Schedulers: 5 | 6 | Learning Rate Schedulers 7 | ======================== 8 | 9 | Learning Rate Schedulers update the learning rate over the course of training. 10 | Learning rates can be updated after each update via :func:`step_update` or at 11 | epoch boundaries via :func:`step`. 12 | 13 | .. automodule:: fairseq.optim.lr_scheduler 14 | :members: 15 | 16 | .. autoclass:: fairseq.optim.lr_scheduler.FairseqLRScheduler 17 | :members: 18 | :undoc-members: 19 | 20 | .. autoclass:: fairseq.optim.lr_scheduler.cosine_lr_scheduler.CosineSchedule 21 | :members: 22 | :undoc-members: 23 | .. autoclass:: fairseq.optim.lr_scheduler.fixed_schedule.FixedSchedule 24 | :members: 25 | :undoc-members: 26 | ..
autoclass:: fairseq.optim.lr_scheduler.inverse_square_root_schedule.InverseSquareRootSchedule 27 | :members: 28 | :undoc-members: 29 | .. autoclass:: fairseq.optim.lr_scheduler.reduce_lr_on_plateau.ReduceLROnPlateau 30 | :members: 31 | :undoc-members: 32 | .. autoclass:: fairseq.optim.lr_scheduler.triangular_lr_scheduler.TriangularSchedule 33 | :members: 34 | :undoc-members: 35 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=python -msphinx 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=fairseq 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed, 20 | echo.then set the SPHINXBUILD environment variable to point to the full 21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the 22 | echo.Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/modules.rst: -------------------------------------------------------------------------------- 1 | Modules 2 | ======= 3 | 4 | Fairseq provides several stand-alone :class:`torch.nn.Module` classes that may 5 | be helpful when implementing a new :class:`~fairseq.models.BaseFairseqModel`. 6 | 7 | .. 
automodule:: fairseq.modules 8 | :members: 9 | :undoc-members: 10 | -------------------------------------------------------------------------------- /docs/optim.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | .. _optimizers: 5 | 6 | Optimizers 7 | ========== 8 | 9 | Optimizers update the Model parameters based on the gradients. 10 | 11 | .. automodule:: fairseq.optim 12 | :members: 13 | 14 | .. autoclass:: fairseq.optim.FairseqOptimizer 15 | :members: 16 | :undoc-members: 17 | 18 | .. autoclass:: fairseq.optim.adadelta.Adadelta 19 | :members: 20 | :undoc-members: 21 | .. autoclass:: fairseq.optim.adagrad.Adagrad 22 | :members: 23 | :undoc-members: 24 | .. autoclass:: fairseq.optim.adafactor.FairseqAdafactor 25 | :members: 26 | :undoc-members: 27 | .. autoclass:: fairseq.optim.adam.FairseqAdam 28 | :members: 29 | :undoc-members: 30 | .. autoclass:: fairseq.optim.fp16_optimizer.FP16Optimizer 31 | :members: 32 | :undoc-members: 33 | .. autoclass:: fairseq.optim.nag.FairseqNAG 34 | :members: 35 | :undoc-members: 36 | .. 
autoclass:: fairseq.optim.sgd.SGD 37 | :members: 38 | :undoc-members: 39 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx<2.0 2 | sphinx-argparse 3 | -------------------------------------------------------------------------------- /examples/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/examples/.DS_Store -------------------------------------------------------------------------------- /examples/.gitignore: -------------------------------------------------------------------------------- 1 | !*/*.sh 2 | !*/*.md 3 | -------------------------------------------------------------------------------- /examples/KEPLER/GLUE/CoLA.sh: -------------------------------------------------------------------------------- 1 | TOTAL_NUM_UPDATES=10336 2 | WARMUP_UPDATES=520 3 | LR=5e-05 4 | NUM_CLASSES=2 5 | MAX_SENTENCES=64 # Batch size. 
6 | ROBERTA_PATH=path_to_KEPLER_original_checkpoint 7 | 8 | fairseq-train CoLA-bin/ \ 9 | --restore-file $ROBERTA_PATH \ 10 | --max-positions 512 \ 11 | --save-dir CoLA-ckpt \ 12 | --max-sentences $MAX_SENTENCES \ 13 | --max-tokens 8800 \ 14 | --task sentence_prediction \ 15 | --reset-optimizer --reset-dataloader --reset-meters \ 16 | --required-batch-size-multiple 1 \ 17 | --init-token 0 --separator-token 2 \ 18 | --arch roberta_base \ 19 | --criterion sentence_prediction \ 20 | --num-classes $NUM_CLASSES \ 21 | --dropout 0.1 --attention-dropout 0.1 \ 22 | --weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \ 23 | --clip-norm 0.0 \ 24 | --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \ 25 | --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \ 26 | --max-epoch 20 \ 27 | --find-unused-parameters \ 28 | --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric; 29 | -------------------------------------------------------------------------------- /examples/KEPLER/GLUE/MNLI.sh: -------------------------------------------------------------------------------- 1 | TOTAL_NUM_UPDATES=123873 2 | WARMUP_UPDATES=7432 3 | LR=1e-05 4 | NUM_CLASSES=3 5 | MAX_SENTENCES=32 # Batch size. 
6 | ROBERTA_PATH=path_to_KEPLER_original_checkpoint 7 | 8 | 9 | fairseq-train MNLI-bin/ \ 10 | --restore-file $ROBERTA_PATH \ 11 | --max-positions 512 \ 12 | --save-dir MNLI-ckpt-ori \ 13 | --max-sentences $MAX_SENTENCES \ 14 | --max-tokens 8800 \ 15 | --task sentence_prediction \ 16 | --reset-optimizer --reset-dataloader --reset-meters \ 17 | --required-batch-size-multiple 1 \ 18 | --init-token 0 --separator-token 2 \ 19 | --arch roberta_base \ 20 | --criterion sentence_prediction \ 21 | --num-classes $NUM_CLASSES \ 22 | --dropout 0.1 --attention-dropout 0.1 \ 23 | --weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \ 24 | --clip-norm 0.0 \ 25 | --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \ 26 | --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \ 27 | --max-epoch 10 \ 28 | --find-unused-parameters \ 29 | --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric; 30 | -------------------------------------------------------------------------------- /examples/KEPLER/GLUE/MRPC.sh: -------------------------------------------------------------------------------- 1 | TOTAL_NUM_UPDATES=2148 2 | WARMUP_UPDATES=264 3 | LR=1e-05 4 | NUM_CLASSES=2 5 | MAX_SENTENCES=16 # Batch size. 
6 | ROBERTA_PATH=MNLI-ckpt/checkpoint_best.pt #Starting from MNLI checkpoint 7 | 8 | fairseq-train MRPC-bin/ \ 9 | --restore-file $ROBERTA_PATH \ 10 | --max-positions 512 \ 11 | --max-sentences $MAX_SENTENCES \ 12 | --max-tokens 8800 \ 13 | --save-dir ./MRPC-ckpt\ 14 | --task sentence_prediction \ 15 | --reset-optimizer --reset-dataloader --reset-meters \ 16 | --required-batch-size-multiple 1 \ 17 | --init-token 0 --separator-token 2 \ 18 | --arch roberta_base \ 19 | --criterion sentence_prediction \ 20 | --num-classes $NUM_CLASSES \ 21 | --dropout 0.1 --attention-dropout 0.1 \ 22 | --weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \ 23 | --clip-norm 0.0 \ 24 | --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \ 25 | --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \ 26 | --max-epoch 10 \ 27 | --find-unused-parameters \ 28 | --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric; 29 | -------------------------------------------------------------------------------- /examples/KEPLER/GLUE/QNLI.sh: -------------------------------------------------------------------------------- 1 | TOTAL_NUM_UPDATES=33112 2 | WARMUP_UPDATES=1986 3 | LR=1e-05 4 | NUM_CLASSES=2 5 | MAX_SENTENCES=32 # Batch size. 
6 | ROBERTA_PATH=path_to_KEPLER_original_checkpoint 7 | 8 | fairseq-train QNLI-bin/ \ 9 | --restore-file $ROBERTA_PATH \ 10 | --max-positions 512 \ 11 | --save-dir QNLI-ckpt \ 12 | --max-sentences $MAX_SENTENCES \ 13 | --max-tokens 4400 \ 14 | --task sentence_prediction \ 15 | --reset-optimizer --reset-dataloader --reset-meters \ 16 | --required-batch-size-multiple 1 \ 17 | --init-token 0 --separator-token 2 \ 18 | --arch roberta_base \ 19 | --criterion sentence_prediction \ 20 | --num-classes $NUM_CLASSES \ 21 | --dropout 0.1 --attention-dropout 0.1 \ 22 | --weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \ 23 | --clip-norm 0.0 \ 24 | --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \ 25 | --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \ 26 | --max-epoch 10 \ 27 | --find-unused-parameters \ 28 | --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric; 29 | -------------------------------------------------------------------------------- /examples/KEPLER/GLUE/QQP.sh: -------------------------------------------------------------------------------- 1 | TOTAL_NUM_UPDATES=113272 2 | WARMUP_UPDATES=28318 3 | LR=1e-05 4 | NUM_CLASSES=2 5 | MAX_SENTENCES=32 # Batch size. 
6 | ROBERTA_PATH=path_to_KEPLER_original_checkpoint 7 | 8 | fairseq-train QQP-bin/ \ 9 | --restore-file $ROBERTA_PATH \ 10 | --max-positions 512 \ 11 | --save-dir QQP-ckpt \ 12 | --max-sentences $MAX_SENTENCES \ 13 | --max-tokens 4400 \ 14 | --task sentence_prediction \ 15 | --reset-optimizer --reset-dataloader --reset-meters \ 16 | --required-batch-size-multiple 1 \ 17 | --init-token 0 --separator-token 2 \ 18 | --arch roberta_base \ 19 | --criterion sentence_prediction \ 20 | --num-classes $NUM_CLASSES \ 21 | --dropout 0.1 --attention-dropout 0.1 \ 22 | --weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \ 23 | --clip-norm 0.0 \ 24 | --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \ 25 | --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \ 26 | --max-epoch 10 \ 27 | --find-unused-parameters \ 28 | --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric; 29 | -------------------------------------------------------------------------------- /examples/KEPLER/GLUE/RTE.sh: -------------------------------------------------------------------------------- 1 | TOTAL_NUM_UPDATES=4036 2 | WARMUP_UPDATES=492 3 | LR=8e-04 4 | NUM_CLASSES=2 5 | MAX_SENTENCES=64 # Batch size. 
6 | ROBERTA_PATH=MNLI-ckpt/checkpoint_best.pt #Starting from MNLI checkpoint 7 | 8 | fairseq-train RTE-bin/ \ 9 | --restore-file $ROBERTA_PATH \ 10 | --max-positions 512 \ 11 | --max-sentences $MAX_SENTENCES \ 12 | --max-tokens 8800 \ 13 | --save-dir ./RTE-ckpt\ 14 | --task sentence_prediction \ 15 | --reset-optimizer --reset-dataloader --reset-meters \ 16 | --required-batch-size-multiple 1 \ 17 | --init-token 0 --separator-token 2 \ 18 | --arch roberta_base \ 19 | --criterion sentence_prediction \ 20 | --num-classes $NUM_CLASSES \ 21 | --dropout 0.1 --attention-dropout 0.1 \ 22 | --weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \ 23 | --clip-norm 0.0 \ 24 | --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \ 25 | --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \ 26 | --max-epoch 20 \ 27 | --find-unused-parameters \ 28 | --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric; 29 | -------------------------------------------------------------------------------- /examples/KEPLER/GLUE/SST-2.sh: -------------------------------------------------------------------------------- 1 | TOTAL_NUM_UPDATES=20935 2 | WARMUP_UPDATES=1256 3 | LR=1e-05 4 | NUM_CLASSES=2 5 | MAX_SENTENCES=32 # Batch size.
6 | ROBERTA_PATH=path_to_KEPLER_original_checkpoint 7 | 8 | fairseq-train SST-2-bin/ \ 9 | --restore-file $ROBERTA_PATH \ 10 | --max-positions 512 \ 11 | --save-dir SST-2-ckpt \ 12 | --max-sentences $MAX_SENTENCES \ 13 | --max-tokens 4400 \ 14 | --task sentence_prediction \ 15 | --reset-optimizer --reset-dataloader --reset-meters \ 16 | --required-batch-size-multiple 1 \ 17 | --init-token 0 --separator-token 2 \ 18 | --arch roberta_base \ 19 | --criterion sentence_prediction \ 20 | --num-classes $NUM_CLASSES \ 21 | --dropout 0.1 --attention-dropout 0.1 \ 22 | --weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \ 23 | --clip-norm 0.0 \ 24 | --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \ 25 | --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \ 26 | --max-epoch 10 \ 27 | --find-unused-parameters \ 28 | --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric; 29 | -------------------------------------------------------------------------------- /examples/KEPLER/GLUE/STS-B.sh: -------------------------------------------------------------------------------- 1 | TOTAL_NUM_UPDATES=3598 2 | WARMUP_UPDATES=214 3 | LR=2e-05 4 | NUM_CLASSES=1 5 | MAX_SENTENCES=16 # Batch size. 
6 | ROBERTA_PATH=MNLI-ckpt/checkpoint_best.pt #Starting from MNLI checkpoint 7 | 8 | fairseq-train STS-B-bin/ \ 9 | --restore-file $ROBERTA_PATH \ 10 | --max-positions 512 \ 11 | --save-dir STS-B-ckpt \ 12 | --max-sentences $MAX_SENTENCES \ 13 | --max-tokens 8800 \ 14 | --task sentence_prediction \ 15 | --reset-optimizer --reset-dataloader --reset-meters \ 16 | --required-batch-size-multiple 1 \ 17 | --init-token 0 --separator-token 2 \ 18 | --arch roberta_base \ 19 | --criterion sentence_prediction \ 20 | --num-classes $NUM_CLASSES \ 21 | --dropout 0.1 --attention-dropout 0.1 \ 22 | --weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \ 23 | --clip-norm 0.0 \ 24 | --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \ 25 | --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \ 26 | --max-epoch 40 \ 27 | --find-unused-parameters \ 28 | --best-checkpoint-metric loss --regression-target; 29 | -------------------------------------------------------------------------------- /examples/KEPLER/OpenEntity/pytorch_transformers/configuration_roberta.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ RoBERTa configuration """ 17 | 18 | from __future__ import (absolute_import, division, print_function, 19 | unicode_literals) 20 | 21 | import logging 22 | 23 | from .configuration_bert import BertConfig 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = { 28 | 'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json", 29 | 'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-config.json", 30 | 'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-config.json", 31 | } 32 | 33 | 34 | class RobertaConfig(BertConfig): 35 | pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP 36 | -------------------------------------------------------------------------------- /examples/KEPLER/OpenEntity/pytorch_transformers/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/examples/KEPLER/OpenEntity/pytorch_transformers/tests/__init__.py -------------------------------------------------------------------------------- /examples/KEPLER/OpenEntity/pytorch_transformers/tests/conftest.py: -------------------------------------------------------------------------------- 1 | # content of conftest.py 2 | 3 | import pytest 4 | 5 | 6 | def pytest_addoption(parser): 7 | parser.addoption( 8 | "--runslow", action="store_true", default=False, help="run slow tests" 9 | ) 10 | 11 | 12 | def pytest_collection_modifyitems(config, items): 13 | if config.getoption("--runslow"): 14 | # --runslow given in cli: do not skip slow tests 15 | return 16 | skip_slow = pytest.mark.skip(reason="need --runslow option to run") 17 | for item in items: 18 | if "slow" in item.keywords: 19 | 
item.add_marker(skip_slow) 20 | -------------------------------------------------------------------------------- /examples/KEPLER/OpenEntity/pytorch_transformers/tests/fixtures/input.txt: -------------------------------------------------------------------------------- 1 | Who was Jim Henson ? ||| Jim Henson was a puppeteer 2 | -------------------------------------------------------------------------------- /examples/KEPLER/OpenEntity/pytorch_transformers/tests/fixtures/test_sentencepiece.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/examples/KEPLER/OpenEntity/pytorch_transformers/tests/fixtures/test_sentencepiece.model -------------------------------------------------------------------------------- /examples/KEPLER/OpenEntity/run_openentity.sh: -------------------------------------------------------------------------------- 1 | export DATA_DIR=path_to_OpenEntity 2 | 3 | python run_typing.py \ 4 | --model_type roberta \ 5 | --model_name_or_path path_to_converted_KEPLER \ 6 | --task_name typing \ 7 | --do_train \ 8 | --do_eval \ 9 | --do_lower_case \ 10 | --data_dir $DATA_DIR \ 11 | --max_seq_length 128 \ 12 | --per_gpu_train_batch_size 32 \ 13 | --learning_rate 3e-5 \ 14 | --num_train_epochs 40 \ 15 | --output_dir path_to_output_checkpoint \ 16 | --evaluate_during_training \ 17 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | __version__ = '0.8.0' 7 | 8 | import examples.noisychannel # noqa 9 | -------------------------------------------------------------------------------- /examples/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/examples/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /examples/language_model/conv_lm/README.md: -------------------------------------------------------------------------------- 1 | # Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017) 2 | 3 | ## Example usage 4 | 5 | First download and preprocess the data following the main [language modeling 6 | README](../README.md). 7 | 8 | Then to train a convolutional LM using the `fconv_lm_dauphin_wikitext103` 9 | architecture: 10 | ```bash 11 | fairseq-train --task language_modeling \ 12 | data-bin/wikitext-103 \ 13 | --save-dir checkpoints/fconv_wikitext-103 \ 14 | --arch fconv_lm_dauphin_wikitext103 \ 15 | --max-epoch 35 --optimizer nag \ 16 | --lr 1.0 --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \ 17 | --clip-norm 0.1 --dropout 0.2 --weight-decay 5e-06 --criterion adaptive_loss \ 18 | --adaptive-softmax-cutoff 10000,20000,200000 --max-tokens 1024 --tokens-per-sample 1024 \ 19 | --ddp-backend=no_c10d 20 | ``` 21 | 22 | And evaluate with: 23 | ```bash 24 | fairseq-eval-lm data-bin/wikitext-103 --path checkpoints/fconv_wikitext-103/checkpoint_best.pt 25 | ``` 26 | 27 | ## Citation 28 | 29 | ```bibtex 30 | @inproceedings{dauphin2017language, 31 | title={Language Modeling with Gated Convolutional Networks}, 32 | author={Dauphin, Yann N and Fan, Angela and Auli, Michael and Grangier, David}, 33 | booktitle={Proceedings of the 34th International Conference on Machine Learning-Volume 70}, 34 | pages={933--941}, 35 | year={2017}, 36 | 
organization={JMLR} 37 | }
38 | ``` 39 | -------------------------------------------------------------------------------- /examples/language_model/transformer_lm/README.md: -------------------------------------------------------------------------------- 1 | # Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018) 2 | 3 | ## Pre-trained models 4 | 5 | Description | Parameters | Dataset | Model and Test set(s) 6 | ---|---:|---|--- 7 | Adaptive Inputs
([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 1026M | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.bz2) 8 | Adaptive Inputs
([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 247M | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.bz2) 9 | 10 | ## Example usage 11 | 12 | See the [language modeling README](../README.md) for instructions on reproducing results for WikiText-103 13 | using the `transformer_lm_wiki103` model architecture. 14 | 15 | ## Citation 16 | 17 | ```bibtex 18 | @inproceedings{ 19 | baevski2018adaptive, 20 | title={Adaptive Input Representations for Neural Language Modeling}, 21 | author={Alexei Baevski and Michael Auli}, 22 | booktitle={International Conference on Learning Representations}, 23 | year={2019}, 24 | url={https://openreview.net/forum?id=ByxZX20qFQ}, 25 | } 26 | ``` 27 | -------------------------------------------------------------------------------- /examples/noisychannel/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from .rerank_options import * # noqa 7 | -------------------------------------------------------------------------------- /examples/noisychannel/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/examples/noisychannel/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /examples/noisychannel/__pycache__/rerank_options.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/examples/noisychannel/__pycache__/rerank_options.cpython-36.pyc -------------------------------------------------------------------------------- /examples/roberta/__pycache__/multiprocessing_bpe_encoder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/examples/roberta/__pycache__/multiprocessing_bpe_encoder.cpython-36.pyc -------------------------------------------------------------------------------- /examples/roberta/commonsense_qa/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . import commonsense_qa_task # noqa 7 | -------------------------------------------------------------------------------- /examples/roberta/commonsense_qa/download_cqa_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | OUTDIR=data/CommonsenseQA 8 | 9 | mkdir -p $OUTDIR 10 | 11 | wget -O $OUTDIR/train.jsonl https://s3.amazonaws.com/commensenseqa/train_rand_split.jsonl 12 | wget -O $OUTDIR/valid.jsonl https://s3.amazonaws.com/commensenseqa/dev_rand_split.jsonl 13 | wget -O $OUTDIR/test.jsonl https://s3.amazonaws.com/commensenseqa/test_rand_split_no_answers.jsonl 14 | wget -O $OUTDIR/dict.txt https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt 15 | -------------------------------------------------------------------------------- /examples/roberta/wsc/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . import wsc_criterion # noqa 7 | from . import wsc_task # noqa 8 | -------------------------------------------------------------------------------- /examples/speech_recognition/__init__.py: -------------------------------------------------------------------------------- 1 | from . import tasks, criterions, models # noqa 2 | -------------------------------------------------------------------------------- /examples/speech_recognition/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | for file in os.listdir(os.path.dirname(__file__)): 5 | if file.endswith('.py') and not file.startswith('_'): 6 | criterion_name = file[:file.find('.py')] 7 | importlib.import_module('examples.speech_recognition.criterions.' 
+ criterion_name) 8 | -------------------------------------------------------------------------------- /examples/speech_recognition/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .asr_dataset import AsrDataset 7 | 8 | __all__ = [ 9 | 'AsrDataset', 10 | ] 11 | -------------------------------------------------------------------------------- /examples/speech_recognition/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | for file in os.listdir(os.path.dirname(__file__)): 5 | if file.endswith('.py') and not file.startswith('_'): 6 | model_name = file[:file.find('.py')] 7 | importlib.import_module('examples.speech_recognition.models.' + model_name) 8 | -------------------------------------------------------------------------------- /examples/speech_recognition/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | for file in os.listdir(os.path.dirname(__file__)): 5 | if file.endswith('.py') and not file.startswith('_'): 6 | task_name = file[:file.find('.py')] 7 | importlib.import_module('examples.speech_recognition.tasks.' 
+ task_name) 8 | -------------------------------------------------------------------------------- /fairseq.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /fairseq.egg-info/entry_points.txt: -------------------------------------------------------------------------------- 1 | [console_scripts] 2 | fairseq-eval-lm = fairseq_cli.eval_lm:cli_main 3 | fairseq-generate = fairseq_cli.generate:cli_main 4 | fairseq-interactive = fairseq_cli.interactive:cli_main 5 | fairseq-preprocess = fairseq_cli.preprocess:cli_main 6 | fairseq-score = fairseq_cli.score:main 7 | fairseq-train = fairseq_cli.train:cli_main 8 | fairseq-validate = fairseq_cli.validate:cli_main 9 | 10 | -------------------------------------------------------------------------------- /fairseq.egg-info/not-zip-safe: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /fairseq.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | cffi 2 | fastBPE 3 | numpy 4 | regex 5 | torch 6 | tqdm 7 | -------------------------------------------------------------------------------- /fairseq.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | examples 2 | fairseq 3 | fairseq_cli 4 | tests 5 | -------------------------------------------------------------------------------- /fairseq.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq.gif -------------------------------------------------------------------------------- /fairseq/.DS_Store: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/.DS_Store -------------------------------------------------------------------------------- /fairseq/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | __all__ = ['pdb'] 7 | __version__ = '0.8.0' 8 | 9 | import fairseq.criterions # noqa 10 | import fairseq.models # noqa 11 | import fairseq.modules # noqa 12 | import fairseq.optim # noqa 13 | import fairseq.optim.lr_scheduler # noqa 14 | import fairseq.pdb # noqa 15 | import fairseq.tasks # noqa 16 | -------------------------------------------------------------------------------- /fairseq/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/__pycache__/binarizer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/__pycache__/binarizer.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/__pycache__/checkpoint_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/__pycache__/checkpoint_utils.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/__pycache__/distributed_utils.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/__pycache__/distributed_utils.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/__pycache__/file_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/__pycache__/file_utils.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/__pycache__/hub_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/__pycache__/hub_utils.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/__pycache__/legacy_distributed_data_parallel.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/__pycache__/legacy_distributed_data_parallel.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/__pycache__/meters.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/__pycache__/meters.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/__pycache__/options.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/__pycache__/options.cpython-36.pyc 
-------------------------------------------------------------------------------- /fairseq/__pycache__/pdb.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/__pycache__/pdb.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/__pycache__/registry.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/__pycache__/registry.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/__pycache__/search.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/__pycache__/search.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/__pycache__/sequence_generator.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/__pycache__/sequence_generator.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/__pycache__/tokenizer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/__pycache__/tokenizer.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/clib/libbleu/module.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2017-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include 10 | 11 | 12 | static PyMethodDef method_def[] = { 13 | {NULL, NULL, 0, NULL} 14 | }; 15 | 16 | static struct PyModuleDef module_def = { 17 | PyModuleDef_HEAD_INIT, 18 | "libbleu", /* name of module */ 19 | NULL, /* module documentation, may be NULL */ 20 | -1, /* size of per-interpreter state of the module, 21 | or -1 if the module keeps state in global variables. */ 22 | method_def 23 | }; 24 | 25 | 26 | #if PY_MAJOR_VERSION == 2 27 | PyMODINIT_FUNC init_libbleu() 28 | #else 29 | PyMODINIT_FUNC PyInit_libbleu() 30 | #endif 31 | { 32 | PyObject *m = PyModule_Create(&module_def); 33 | if (!m) { 34 | return NULL; 35 | } 36 | return m; 37 | } 38 | -------------------------------------------------------------------------------- /fairseq/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import importlib 7 | import os 8 | 9 | from fairseq import registry 10 | from fairseq.criterions.fairseq_criterion import FairseqCriterion 11 | 12 | 13 | build_criterion, register_criterion, CRITERION_REGISTRY = registry.setup_registry( 14 | '--criterion', 15 | base_class=FairseqCriterion, 16 | default='cross_entropy', 17 | ) 18 | 19 | 20 | # automatically import any Python files in the criterions/ directory 21 | for file in os.listdir(os.path.dirname(__file__)): 22 | if file.endswith('.py') and not file.startswith('_'): 23 | module = file[:file.find('.py')] 24 | importlib.import_module('fairseq.criterions.' + module) 25 | -------------------------------------------------------------------------------- /fairseq/criterions/__pycache__/MLMetKE.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/criterions/__pycache__/MLMetKE.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/criterions/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/criterions/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/criterions/__pycache__/adaptive_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/criterions/__pycache__/adaptive_loss.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/criterions/__pycache__/binary_cross_entropy.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/criterions/__pycache__/binary_cross_entropy.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/criterions/__pycache__/composite_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/criterions/__pycache__/composite_loss.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/criterions/__pycache__/cross_entropy.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/criterions/__pycache__/cross_entropy.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/criterions/__pycache__/fairseq_criterion.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/criterions/__pycache__/fairseq_criterion.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/criterions/__pycache__/label_smoothed_cross_entropy.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/criterions/__pycache__/label_smoothed_cross_entropy.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/criterions/__pycache__/legacy_masked_lm.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/criterions/__pycache__/legacy_masked_lm.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/criterions/__pycache__/masked_lm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/criterions/__pycache__/masked_lm.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/criterions/__pycache__/only_ke.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/criterions/__pycache__/only_ke.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/criterions/__pycache__/relation_extraction.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/criterions/__pycache__/relation_extraction.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/criterions/__pycache__/sentence_prediction.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/criterions/__pycache__/sentence_prediction.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/criterions/__pycache__/sentence_prediction_debug.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/criterions/__pycache__/sentence_prediction_debug.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/criterions/__pycache__/sentence_ranking.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/criterions/__pycache__/sentence_ranking.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/criterions/fairseq_criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from torch.nn.modules.loss import _Loss 7 | 8 | 9 | class FairseqCriterion(_Loss): 10 | 11 | def __init__(self, args, task): 12 | super().__init__() 13 | self.args = args 14 | self.task = task 15 | self.padding_idx = task.target_dictionary.pad() if task.target_dictionary is not None else -100 16 | 17 | @staticmethod 18 | def add_args(parser): 19 | """Add criterion-specific arguments to the parser.""" 20 | pass 21 | 22 | @classmethod 23 | def build_criterion(cls, args, task): 24 | return cls(args, task) 25 | 26 | def forward(self, model, sample, reduce=True): 27 | """Compute the loss for the given sample. 
28 | 29 | Returns a tuple with three elements: 30 | 1) the loss 31 | 2) the sample size, which is used as the denominator for the gradient 32 | 3) logging outputs to display while training 33 | """ 34 | raise NotImplementedError 35 | 36 | @staticmethod 37 | def aggregate_logging_outputs(logging_outputs): 38 | """Aggregate logging outputs from data parallel training.""" 39 | raise NotImplementedError 40 | 41 | @staticmethod 42 | def grad_denom(sample_sizes): 43 | """Compute the gradient denominator for a set of sample sizes.""" 44 | return sum(sample_sizes) 45 | -------------------------------------------------------------------------------- /fairseq/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/backtranslation_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/backtranslation_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/base_wrapper_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/base_wrapper_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/concat_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/concat_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/concat_sentences_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/concat_sentences_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/data_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/data_utils.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/dictionary.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/dictionary.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/fairseq_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/fairseq_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/fake_numel_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/fake_numel_dataset.cpython-36.pyc 
-------------------------------------------------------------------------------- /fairseq/data/__pycache__/id_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/id_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/indexed_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/indexed_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/iterators.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/iterators.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/ke_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/ke_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/ke_negative_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/ke_negative_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/language_pair_dataset.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/language_pair_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/list_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/list_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/lm_context_window_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/lm_context_window_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/lru_cache_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/lru_cache_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/mask_tokens_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/mask_tokens_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/monolingual_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/monolingual_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/multi_corpus_sampled_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/multi_corpus_sampled_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/nested_dictionary_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/nested_dictionary_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/noising.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/noising.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/num_samples_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/num_samples_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/numel_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/numel_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/offset_tokens_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/offset_tokens_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/pad_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/pad_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/plasma_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/plasma_utils.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/prepend_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/prepend_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/prepend_token_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/prepend_token_dataset.cpython-36.pyc 
-------------------------------------------------------------------------------- /fairseq/data/__pycache__/raw_label_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/raw_label_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/replace_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/replace_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/round_robin_zip_datasets.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/round_robin_zip_datasets.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/sharded_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/sharded_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/sort_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/sort_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/strip_token_dataset.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/strip_token_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/subsample_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/subsample_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/token_block_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/token_block_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/transform_eos_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/transform_eos_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/transform_eos_lang_pair_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/transform_eos_lang_pair_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/__pycache__/truncate_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/__pycache__/truncate_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/audio/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/audio/__init__.py -------------------------------------------------------------------------------- /fairseq/data/audio/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/audio/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/audio/__pycache__/raw_audio_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/audio/__pycache__/raw_audio_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/base_wrapper_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from torch.utils.data.dataloader import default_collate 7 | 8 | from . 
import FairseqDataset 9 | 10 | 11 | class BaseWrapperDataset(FairseqDataset): 12 | 13 | def __init__(self, dataset): 14 | super().__init__() 15 | self.dataset = dataset 16 | 17 | def __getitem__(self, index): 18 | return self.dataset[index] 19 | 20 | def __len__(self): 21 | return len(self.dataset) 22 | 23 | def collater(self, samples): 24 | if hasattr(self.dataset, 'collater'): 25 | return self.dataset.collater(samples) 26 | else: 27 | return default_collate(samples) 28 | 29 | @property 30 | def sizes(self): 31 | return self.dataset.sizes 32 | 33 | def num_tokens(self, index): 34 | return self.dataset.num_tokens(index) 35 | 36 | def size(self, index): 37 | return self.dataset.size(index) 38 | 39 | def ordered_indices(self): 40 | return self.dataset.ordered_indices() 41 | 42 | @property 43 | def supports_prefetch(self): 44 | return getattr(self.dataset, 'supports_prefetch', False) 45 | 46 | def prefetch(self, indices): 47 | self.dataset.prefetch(indices) 48 | 49 | def set_epoch(self, epoch): 50 | super().set_epoch(epoch) 51 | if hasattr(self.dataset, 'set_epoch'): 52 | self.dataset.set_epoch(epoch) 53 | -------------------------------------------------------------------------------- /fairseq/data/concat_sentences_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | from . 
import FairseqDataset 9 | 10 | 11 | class ConcatSentencesDataset(FairseqDataset): 12 | 13 | def __init__(self, *datasets): 14 | super().__init__() 15 | self.datasets = datasets 16 | assert all(len(ds) == len(datasets[0]) for ds in datasets), \ 17 | 'datasets must have the same length' 18 | 19 | def __getitem__(self, index): 20 | return torch.cat([ds[index] for ds in self.datasets]) 21 | 22 | def __len__(self): 23 | return len(self.datasets[0]) 24 | 25 | def collater(self, samples): 26 | return self.datasets[0].collater(samples) 27 | 28 | @property 29 | def sizes(self): 30 | return sum(ds.sizes for ds in self.datasets) 31 | 32 | def num_tokens(self, index): 33 | return sum(ds.num_tokens(index) for ds in self.datasets) 34 | 35 | def size(self, index): 36 | return sum(ds.size(index) for ds in self.datasets) 37 | 38 | def ordered_indices(self): 39 | return self.datasets[0].ordered_indices() 40 | 41 | @property 42 | def supports_prefetch(self): 43 | return any( 44 | getattr(ds, 'supports_prefetch', False) for ds in self.datasets 45 | ) 46 | 47 | def prefetch(self, indices): 48 | for ds in self.datasets: 49 | if getattr(ds, 'supports_prefetch', False): 50 | ds.prefetch(indices) 51 | -------------------------------------------------------------------------------- /fairseq/data/data_utils_fast.cpython-36m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/data_utils_fast.cpython-36m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /fairseq/data/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | import importlib 8 | import os 9 | 10 | from fairseq import registry 11 | 12 | 13 | build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY = registry.setup_registry( 14 | '--tokenizer', 15 | default=None, 16 | ) 17 | 18 | 19 | build_bpe, register_bpe, BPE_REGISTRY = registry.setup_registry( 20 | '--bpe', 21 | default=None, 22 | ) 23 | 24 | 25 | # automatically import any Python files in the encoders/ directory 26 | for file in os.listdir(os.path.dirname(__file__)): 27 | if file.endswith('.py') and not file.startswith('_'): 28 | module = file[:file.find('.py')] 29 | importlib.import_module('fairseq.data.encoders.' + module) 30 | -------------------------------------------------------------------------------- /fairseq/data/encoders/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/encoders/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/encoders/__pycache__/fastbpe.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/encoders/__pycache__/fastbpe.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/encoders/__pycache__/gpt2_bpe.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/encoders/__pycache__/gpt2_bpe.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/encoders/__pycache__/gpt2_bpe_utils.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/encoders/__pycache__/gpt2_bpe_utils.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/encoders/__pycache__/hf_bert_bpe.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/encoders/__pycache__/hf_bert_bpe.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/encoders/__pycache__/moses_tokenizer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/encoders/__pycache__/moses_tokenizer.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/encoders/__pycache__/nltk_tokenizer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/encoders/__pycache__/nltk_tokenizer.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/encoders/__pycache__/sentencepiece_bpe.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/encoders/__pycache__/sentencepiece_bpe.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/encoders/__pycache__/space_tokenizer.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/encoders/__pycache__/space_tokenizer.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/encoders/__pycache__/subword_nmt_bpe.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/encoders/__pycache__/subword_nmt_bpe.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/encoders/fastbpe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from fairseq import file_utils 7 | from fairseq.data.encoders import register_bpe 8 | 9 | 10 | @register_bpe('fastbpe') 11 | class fastBPE(object): 12 | 13 | @staticmethod 14 | def add_args(parser): 15 | # fmt: off 16 | parser.add_argument('--bpe-codes', type=str, 17 | help='path to fastBPE BPE') 18 | # fmt: on 19 | 20 | def __init__(self, args): 21 | if args.bpe_codes is None: 22 | raise ValueError('--bpe-codes is required for --bpe=subword_nmt') 23 | codes = file_utils.cached_path(args.bpe_codes) 24 | try: 25 | import fastBPE 26 | self.bpe = fastBPE.fastBPE(codes) 27 | self.bpe_symbol = "@@ " 28 | except ImportError: 29 | raise ImportError('Please install fastBPE with: pip install fastBPE') 30 | 31 | def encode(self, x: str) -> str: 32 | return self.bpe.apply([x])[0] 33 | 34 | def decode(self, x: str) -> str: 35 | return (x + ' ').replace(self.bpe_symbol, '').rstrip() 36 | -------------------------------------------------------------------------------- /fairseq/data/encoders/nltk_tokenizer.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from fairseq.data.encoders import register_tokenizer 7 | 8 | 9 | @register_tokenizer('nltk') 10 | class NLTKTokenizer(object): 11 | 12 | def __init__(self, source_lang=None, target_lang=None): 13 | try: 14 | from nltk.tokenize import word_tokenize 15 | self.word_tokenize = word_tokenize 16 | except ImportError: 17 | raise ImportError('Please install nltk with: pip install nltk') 18 | 19 | def encode(self, x: str) -> str: 20 | return ' '.join(self.word_tokenize(x)) 21 | 22 | def decode(self, x: str) -> str: 23 | return x 24 | -------------------------------------------------------------------------------- /fairseq/data/encoders/sentencepiece_bpe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from fairseq import file_utils 7 | from fairseq.data.encoders import register_bpe 8 | 9 | 10 | @register_bpe('sentencepiece') 11 | class SentencepieceBPE(object): 12 | 13 | @staticmethod 14 | def add_args(parser): 15 | # fmt: off 16 | parser.add_argument('--sentencepiece-vocab', type=str, 17 | help='path to sentencepiece vocab') 18 | # fmt: on 19 | 20 | def __init__(self, args): 21 | vocab = file_utils.cached_path(args.sentencepiece_vocab) 22 | try: 23 | import sentencepiece as spm 24 | self.sp = spm.SentencePieceProcessor() 25 | self.sp.Load(vocab) 26 | except ImportError: 27 | raise ImportError('Please install sentencepiece with: pip install sentencepiece') 28 | 29 | def encode(self, x: str) -> str: 30 | return ' '.join(self.sp.EncodeAsPieces(x)) 31 | 32 | def decode(self, x: str) -> str: 33 | return x.replace(' ', '').replace('\u2581', ' ').strip() 34 | -------------------------------------------------------------------------------- /fairseq/data/encoders/space_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import re 7 | 8 | from fairseq.data.encoders import register_tokenizer 9 | 10 | 11 | @register_tokenizer('space') 12 | class SpaceTokenizer(object): 13 | 14 | def __init__(self, source_lang=None, target_lang=None): 15 | self.space_tok = re.compile(r"\s+") 16 | 17 | def encode(self, x: str) -> str: 18 | return self.space_tok.sub(' ', x) 19 | 20 | def decode(self, x: str) -> str: 21 | return x 22 | -------------------------------------------------------------------------------- /fairseq/data/fake_numel_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from . import FairseqDataset 10 | 11 | 12 | class FakeNumelDataset(FairseqDataset): 13 | 14 | def __init__(self, cnt, reduce=False): 15 | super().__init__() 16 | self.cnt = cnt 17 | self.reduce = reduce 18 | 19 | def __getitem__(self, index): 20 | return self.cnt[index] 21 | 22 | def __len__(self): 23 | return len(self.cnt) 24 | 25 | def collater(self, samples): 26 | if self.reduce: 27 | return sum(samples) 28 | else: 29 | #print(samples) 30 | #print("________________") 31 | return torch.tensor(samples) 32 | -------------------------------------------------------------------------------- /fairseq/data/id_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | from . import FairseqDataset 9 | 10 | 11 | class IdDataset(FairseqDataset): 12 | 13 | def __getitem__(self, index): 14 | return index 15 | 16 | def __len__(self): 17 | return 0 18 | 19 | def collater(self, samples): 20 | return torch.tensor(samples) 21 | -------------------------------------------------------------------------------- /fairseq/data/ke_negative_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright Xiaozhi Wang 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from collections import OrderedDict 7 | 8 | import numpy as np 9 | import torch 10 | from . 
import FairseqDataset, BaseWrapperDataset 11 | 12 | 13 | class KeNegDataset(BaseWrapperDataset): 14 | 15 | def __init__(self, dataset ,args): 16 | super().__init__(dataset) 17 | self.ns=args.negative_sample_size 18 | 19 | def _map_indices(self, indices): 20 | tmp=[] 21 | for index in indices: 22 | tmp=tmp+list(range(index*self.ns,(index+1)*self.ns)) 23 | return tmp 24 | 25 | def __getitem__(self, index): 26 | tmp=self._map_indices([index]) 27 | return [self.dataset[x] for x in tmp] 28 | 29 | def collater(self,samples): 30 | return self.dataset.collater([y for x in samples for y in x]) 31 | 32 | def __len__(self): 33 | return len(self.dataset)//self.ns 34 | -------------------------------------------------------------------------------- /fairseq/data/legacy/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from .masked_lm_dictionary import BertDictionary, MaskedLMDictionary 7 | from .block_pair_dataset import BlockPairDataset 8 | from .masked_lm_dataset import MaskedLMDataset 9 | 10 | __all__ = [ 11 | 'BertDictionary', 12 | 'BlockPairDataset', 13 | 'MaskedLMDataset', 14 | 'MaskedLMDictionary', 15 | ] 16 | -------------------------------------------------------------------------------- /fairseq/data/legacy/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/legacy/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/legacy/__pycache__/block_pair_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/legacy/__pycache__/block_pair_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/legacy/__pycache__/masked_lm_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/legacy/__pycache__/masked_lm_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/legacy/__pycache__/masked_lm_dictionary.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/legacy/__pycache__/masked_lm_dictionary.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/data/list_dataset.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . import BaseWrapperDataset 7 | 8 | 9 | class ListDataset(BaseWrapperDataset): 10 | 11 | def __init__(self, dataset, sizes=None): 12 | super().__init__(dataset) 13 | self._sizes = sizes 14 | 15 | def collater(self, samples): 16 | return samples 17 | 18 | @property 19 | def sizes(self): 20 | return self._sizes 21 | 22 | def num_tokens(self, index): 23 | return self.sizes[index] 24 | 25 | def size(self, index): 26 | return self.sizes[index] 27 | 28 | def set_epoch(self, epoch): 29 | pass 30 | -------------------------------------------------------------------------------- /fairseq/data/lru_cache_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from functools import lru_cache 7 | 8 | from . import BaseWrapperDataset 9 | 10 | 11 | class LRUCacheDataset(BaseWrapperDataset): 12 | 13 | def __init__(self, dataset, token=None): 14 | super().__init__(dataset) 15 | 16 | @lru_cache(maxsize=8) 17 | def __getitem__(self, index): 18 | return self.dataset[index] 19 | 20 | @lru_cache(maxsize=8) 21 | def collater(self, samples): 22 | return self.dataset.collater(samples) 23 | -------------------------------------------------------------------------------- /fairseq/data/num_samples_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from . import FairseqDataset 7 | 8 | 9 | class NumSamplesDataset(FairseqDataset): 10 | 11 | def __getitem__(self, index): 12 | return 1 13 | 14 | def __len__(self): 15 | return 0 16 | 17 | def collater(self, samples): 18 | return sum(samples) 19 | -------------------------------------------------------------------------------- /fairseq/data/numel_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from . import BaseWrapperDataset 10 | 11 | 12 | class NumelDataset(BaseWrapperDataset): 13 | 14 | def __init__(self, dataset, reduce=False): 15 | super().__init__(dataset) 16 | self.reduce = reduce 17 | 18 | def __getitem__(self, index): 19 | item = self.dataset[index] 20 | if torch.is_tensor(item): 21 | return torch.numel(item) 22 | else: 23 | return np.size(item) 24 | 25 | def __len__(self): 26 | return len(self.dataset) 27 | 28 | def collater(self, samples): 29 | if self.reduce: 30 | return sum(samples) 31 | else: 32 | return torch.tensor(samples) 33 | -------------------------------------------------------------------------------- /fairseq/data/offset_tokens_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . 
import BaseWrapperDataset 7 | 8 | 9 | class OffsetTokensDataset(BaseWrapperDataset): 10 | 11 | def __init__(self, dataset, offset): 12 | super().__init__(dataset) 13 | self.offset = offset 14 | 15 | def __getitem__(self, idx): 16 | return self.dataset[idx] + self.offset 17 | -------------------------------------------------------------------------------- /fairseq/data/pad_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from fairseq.data import data_utils 7 | 8 | from . import BaseWrapperDataset 9 | 10 | 11 | class PadDataset(BaseWrapperDataset): 12 | 13 | def __init__(self, dataset, pad_idx, left_pad): 14 | super().__init__(dataset) 15 | self.pad_idx = pad_idx 16 | self.left_pad = left_pad 17 | 18 | def collater(self, samples): 19 | return data_utils.collate_tokens(samples, self.pad_idx, left_pad=self.left_pad) 20 | 21 | 22 | class LeftPadDataset(PadDataset): 23 | 24 | def __init__(self, dataset, pad_idx): 25 | super().__init__(dataset, pad_idx, left_pad=True) 26 | 27 | 28 | class RightPadDataset(PadDataset): 29 | 30 | def __init__(self, dataset, pad_idx): 31 | super().__init__(dataset, pad_idx, left_pad=False) 32 | -------------------------------------------------------------------------------- /fairseq/data/prepend_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from . 
import BaseWrapperDataset 10 | 11 | 12 | class PrependDataset(BaseWrapperDataset): 13 | def __init__(self, dataset, prepend_getter, ensure_first_token_is=None): 14 | super().__init__(dataset) 15 | self.prepend_getter = prepend_getter 16 | self.ensure_first_token = ensure_first_token_is 17 | 18 | def __getitem__(self, idx): 19 | item = self.dataset[idx] 20 | is_tuple = isinstance(item, tuple) 21 | src = item[0] if is_tuple else item 22 | 23 | assert self.ensure_first_token is None or src[0] == self.ensure_first_token 24 | prepend_idx = self.prepend_getter(self.dataset, idx) 25 | assert isinstance(prepend_idx, int) 26 | src[0] = prepend_idx 27 | item = tuple((src,) + item[1:]) if is_tuple else src 28 | return item 29 | -------------------------------------------------------------------------------- /fairseq/data/prepend_token_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from . 
import BaseWrapperDataset 10 | 11 | 12 | class PrependTokenDataset(BaseWrapperDataset): 13 | 14 | def __init__(self, dataset, token=None): 15 | super().__init__(dataset) 16 | self.token = token 17 | if token is not None: 18 | self._sizes = np.array(dataset.sizes) + 1 19 | else: 20 | self._sizes = dataset.sizes 21 | 22 | def __getitem__(self, idx): 23 | item = self.dataset[idx] 24 | if self.token is not None: 25 | item = torch.cat([item.new([self.token]), item]) 26 | return item 27 | 28 | @property 29 | def sizes(self): 30 | return self._sizes 31 | 32 | def num_tokens(self, index): 33 | n = self.dataset.num_tokens(index) 34 | if self.token is not None: 35 | n += 1 36 | return n 37 | 38 | def size(self, index): 39 | n = self.dataset.size(index) 40 | if self.token is not None: 41 | n += 1 42 | return n 43 | -------------------------------------------------------------------------------- /fairseq/data/raw_label_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | from . import FairseqDataset 9 | 10 | 11 | class RawLabelDataset(FairseqDataset): 12 | 13 | def __init__(self, labels): 14 | super().__init__() 15 | self.labels = labels 16 | 17 | def __getitem__(self, index): 18 | return self.labels[index] 19 | 20 | def __len__(self): 21 | return len(self.labels) 22 | 23 | def collater(self, samples): 24 | return torch.tensor(samples) 25 | -------------------------------------------------------------------------------- /fairseq/data/replace_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . import BaseWrapperDataset 7 | 8 | 9 | class ReplaceDataset(BaseWrapperDataset): 10 | def __init__(self, dataset, replace_map, offset=0): 11 | super().__init__(dataset) 12 | assert len(replace_map) > 0 13 | self.replace_map = replace_map 14 | self.offset = offset 15 | 16 | def __getitem__(self, index): 17 | item = self.dataset[index] 18 | is_tuple = isinstance(item, tuple) 19 | src = item[0] if is_tuple else item 20 | 21 | for k, v in self.replace_map.items(): 22 | src_off = src[self.offset:] 23 | src_off.masked_fill_(src_off == k, v) 24 | 25 | item = tuple((src,) + item[1:]) if is_tuple else src 26 | return item 27 | -------------------------------------------------------------------------------- /fairseq/data/sort_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | 8 | from . import BaseWrapperDataset 9 | 10 | 11 | class SortDataset(BaseWrapperDataset): 12 | 13 | def __init__(self, dataset, sort_order): 14 | super().__init__(dataset) 15 | if not isinstance(sort_order, (list, tuple)): 16 | sort_order = [sort_order] 17 | self.sort_order = sort_order 18 | 19 | assert all(len(so) == len(dataset) for so in sort_order) 20 | 21 | def ordered_indices(self): 22 | return np.lexsort(self.sort_order) 23 | -------------------------------------------------------------------------------- /fairseq/data/strip_token_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . import BaseWrapperDataset 7 | 8 | 9 | class StripTokenDataset(BaseWrapperDataset): 10 | 11 | def __init__(self, dataset, id_to_strip): 12 | super().__init__(dataset) 13 | self.id_to_strip = id_to_strip 14 | 15 | def __getitem__(self, index): 16 | item = self.dataset[index] 17 | return item[item.ne(self.id_to_strip)] 18 | -------------------------------------------------------------------------------- /fairseq/data/token_block_utils_fast.cpython-36m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/data/token_block_utils_fast.cpython-36m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /fairseq/data/truncate_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | 8 | from . 
import BaseWrapperDataset 9 | 10 | 11 | class TruncateDataset(BaseWrapperDataset): 12 | 13 | def __init__(self, dataset, truncation_length): 14 | super().__init__(dataset) 15 | assert truncation_length is not None 16 | self.truncation_length = truncation_length 17 | self.dataset = dataset 18 | 19 | def __getitem__(self, index): 20 | item = self.dataset[index] 21 | item_len = item.size(0) 22 | if item_len > self.truncation_length: 23 | item = item[:self.truncation_length] 24 | return item 25 | 26 | @property 27 | def sizes(self): 28 | return np.minimum(self.dataset.sizes, self.truncation_length) 29 | 30 | def __len__(self): 31 | return len(self.dataset) 32 | -------------------------------------------------------------------------------- /fairseq/libbleu.cpython-36m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/libbleu.cpython-36m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /fairseq/models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/models/.DS_Store -------------------------------------------------------------------------------- /fairseq/models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/models/__pycache__/composite_encoder.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/models/__pycache__/composite_encoder.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/models/__pycache__/distributed_fairseq_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/models/__pycache__/distributed_fairseq_model.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/models/__pycache__/fairseq_decoder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/models/__pycache__/fairseq_decoder.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/models/__pycache__/fairseq_encoder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/models/__pycache__/fairseq_encoder.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/models/__pycache__/fairseq_incremental_decoder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/models/__pycache__/fairseq_incremental_decoder.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/models/__pycache__/fairseq_model.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/models/__pycache__/fairseq_model.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/models/__pycache__/fconv.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/models/__pycache__/fconv.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/models/__pycache__/fconv_lm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/models/__pycache__/fconv_lm.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/models/__pycache__/fconv_self_att.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/models/__pycache__/fconv_self_att.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/models/__pycache__/lightconv.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/models/__pycache__/lightconv.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/models/__pycache__/lightconv_lm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/models/__pycache__/lightconv_lm.cpython-36.pyc 
-------------------------------------------------------------------------------- /fairseq/models/__pycache__/lstm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/models/__pycache__/lstm.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/models/__pycache__/masked_lm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/models/__pycache__/masked_lm.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/models/__pycache__/multilingual_transformer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/models/__pycache__/multilingual_transformer.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/models/__pycache__/transformer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/models/__pycache__/transformer.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/models/__pycache__/transformer_from_pretrained_xlm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/models/__pycache__/transformer_from_pretrained_xlm.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/models/__pycache__/transformer_lm.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/models/__pycache__/transformer_lm.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/models/__pycache__/wav2vec.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/models/__pycache__/wav2vec.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/models/fairseq_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch.nn as nn 7 | 8 | 9 | class FairseqEncoder(nn.Module): 10 | """Base class for encoders.""" 11 | 12 | def __init__(self, dictionary): 13 | super().__init__() 14 | self.dictionary = dictionary 15 | 16 | def forward(self, src_tokens, src_lengths=None, **kwargs): 17 | """ 18 | Args: 19 | src_tokens (LongTensor): tokens in the source language of shape 20 | `(batch, src_len)` 21 | src_lengths (LongTensor): lengths of each source sentence of shape 22 | `(batch)` 23 | """ 24 | raise NotImplementedError 25 | 26 | def reorder_encoder_out(self, encoder_out, new_order): 27 | """ 28 | Reorder encoder output according to `new_order`. 
29 | 30 | Args: 31 | encoder_out: output from the ``forward()`` method 32 | new_order (LongTensor): desired order 33 | 34 | Returns: 35 | `encoder_out` rearranged according to `new_order` 36 | """ 37 | raise NotImplementedError 38 | 39 | def max_positions(self): 40 | """Maximum input length supported by the encoder.""" 41 | return 1e6 # an arbitrary large number 42 | 43 | def upgrade_state_dict(self, state_dict): 44 | """Upgrade a (possibly old) state dict for new versions of fairseq.""" 45 | return state_dict 46 | -------------------------------------------------------------------------------- /fairseq/models/roberta/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .hub_interface import * # noqa 7 | from .model import * # noqa 8 | -------------------------------------------------------------------------------- /fairseq/models/roberta/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/models/roberta/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/models/roberta/__pycache__/hub_interface.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/models/roberta/__pycache__/hub_interface.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/models/roberta/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/models/roberta/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/modules/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/modules/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/modules/__pycache__/adaptive_input.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/modules/__pycache__/adaptive_input.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/modules/__pycache__/adaptive_softmax.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/modules/__pycache__/adaptive_softmax.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/modules/__pycache__/beamable_mm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/modules/__pycache__/beamable_mm.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/modules/__pycache__/character_token_embedder.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/modules/__pycache__/character_token_embedder.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/modules/__pycache__/conv_tbc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/modules/__pycache__/conv_tbc.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/modules/__pycache__/downsampled_multihead_attention.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/modules/__pycache__/downsampled_multihead_attention.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/modules/__pycache__/dynamic_convolution.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/modules/__pycache__/dynamic_convolution.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/modules/__pycache__/gelu.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/modules/__pycache__/gelu.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/modules/__pycache__/grad_multiply.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/modules/__pycache__/grad_multiply.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/modules/__pycache__/highway.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/modules/__pycache__/highway.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/modules/__pycache__/layer_norm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/modules/__pycache__/layer_norm.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/modules/__pycache__/learned_positional_embedding.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/modules/__pycache__/learned_positional_embedding.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/modules/__pycache__/lightweight_convolution.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/modules/__pycache__/lightweight_convolution.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/modules/__pycache__/linearized_convolution.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/modules/__pycache__/linearized_convolution.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/modules/__pycache__/logsumexp_moe.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/modules/__pycache__/logsumexp_moe.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/modules/__pycache__/mean_pool_gating_network.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/modules/__pycache__/mean_pool_gating_network.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/modules/__pycache__/multihead_attention.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/modules/__pycache__/multihead_attention.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/modules/__pycache__/positional_embedding.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/modules/__pycache__/positional_embedding.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/modules/__pycache__/scalar_bias.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/modules/__pycache__/scalar_bias.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/modules/__pycache__/sinusoidal_positional_embedding.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/modules/__pycache__/sinusoidal_positional_embedding.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/modules/__pycache__/transformer_layer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/modules/__pycache__/transformer_layer.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/modules/__pycache__/transformer_sentence_encoder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/modules/__pycache__/transformer_sentence_encoder.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/modules/__pycache__/transformer_sentence_encoder_layer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/modules/__pycache__/transformer_sentence_encoder_layer.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/modules/__pycache__/unfold.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/modules/__pycache__/unfold.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/modules/__pycache__/vggblock.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/modules/__pycache__/vggblock.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/modules/conv_tbc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | from torch.nn.modules.utils import _single 8 | 9 | 10 | class ConvTBC(torch.nn.Module): 11 | """1D convolution over an input of shape (time x batch x channel) 12 | 13 | The implementation uses gemm to perform the convolution. This implementation 14 | is faster than cuDNN for small kernel sizes. 
15 | """ 16 | def __init__(self, in_channels, out_channels, kernel_size, padding=0): 17 | super(ConvTBC, self).__init__() 18 | self.in_channels = in_channels 19 | self.out_channels = out_channels 20 | self.kernel_size = _single(kernel_size) 21 | self.padding = _single(padding) 22 | 23 | self.weight = torch.nn.Parameter(torch.Tensor( 24 | self.kernel_size[0], in_channels, out_channels)) 25 | self.bias = torch.nn.Parameter(torch.Tensor(out_channels)) 26 | 27 | def forward(self, input): 28 | return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding[0]) 29 | 30 | def __repr__(self): 31 | s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}' 32 | ', padding={padding}') 33 | if self.bias is None: 34 | s += ', bias=False' 35 | s += ')' 36 | return s.format(name=self.__class__.__name__, **self.__dict__) 37 | -------------------------------------------------------------------------------- /fairseq/modules/dynamicconv_layer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from .dynamicconv_layer import DynamicconvLayer # noqa 7 | -------------------------------------------------------------------------------- /fairseq/modules/dynamicconv_layer/dynamiconv_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | std::vector dynamicconv_cpu_forward( 5 | float* input, 6 | float* filters, 7 | int padding_l); 8 | 9 | std::vector dynamicconv_cpu_backward( 10 | float* gradOutput, 11 | int padding_l, 12 | float* input, 13 | float* filters); 14 | 15 | std::vector dynamicconv_forward( 16 | float* input, 17 | float* filters, 18 | int padding_l) { 19 | 20 | return dynamicconv_cpu_forward(input, filters, padding_l); 21 | } 22 | 23 | std::vector dynamicconv_backward( 24 | float* gradOutput, 25 | int padding_l, 26 | float* input, 27 | float* filters) { 28 | 29 | return dynamicconv_cpu_backward(gradOutput, padding_l, input, filters); 30 | } 31 | 32 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 33 | m.def("forward", &dynamicconv_forward, "dynamicconv forward (CPU)"); 34 | m.def("backward", &dynamicconv_backward, "dynamicconv backward (CPU)"); 35 | } 36 | -------------------------------------------------------------------------------- /fairseq/modules/dynamicconv_layer/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from setuptools import setup 8 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 9 | 10 | setup( 11 | name='dynamicconv_layer', 12 | ext_modules=[ 13 | CUDAExtension( 14 | name='dynamicconv_cuda', 15 | sources=[ 16 | 'dynamicconv_cuda.cpp', 17 | 'dynamicconv_cuda_kernel.cu', 18 | ], 19 | ), 20 | ], 21 | cmdclass={ 22 | 'build_ext': BuildExtension 23 | }) 24 | -------------------------------------------------------------------------------- /fairseq/modules/gelu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | """ 6 | See "Gaussian Error Linear Units (GELUs)" by Dan Hendrycks and Kevin Gimpel with 7 | the corresponding GitHub repo: https://github.com/hendrycks/GELUs 8 | """ 9 | 10 | import math 11 | 12 | import torch 13 | 14 | 15 | def gelu_accurate(x): 16 | if not hasattr(gelu_accurate, "_a"): 17 | gelu_accurate._a = math.sqrt(2 / math.pi) 18 | return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) 19 | 20 | 21 | def gelu(x: torch.Tensor) -> torch.Tensor: 22 | if hasattr(torch.nn.functional, 'gelu'): 23 | return torch.nn.functional.gelu(x.float()).type_as(x) 24 | else: 25 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 26 | -------------------------------------------------------------------------------- /fairseq/modules/grad_multiply.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import torch 7 | 8 | 9 | class GradMultiply(torch.autograd.Function): 10 | @staticmethod 11 | def forward(ctx, x, scale): 12 | ctx.scale = scale 13 | res = x.new(x) 14 | return res 15 | 16 | @staticmethod 17 | def backward(ctx, grad): 18 | return grad * ctx.scale, None 19 | -------------------------------------------------------------------------------- /fairseq/modules/layer_norm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | #Xiaozhi Wang: original eps: 1e-5 9 | def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False): 10 | if not export and torch.cuda.is_available(): 11 | try: 12 | from apex.normalization import FusedLayerNorm 13 | return FusedLayerNorm(normalized_shape, eps, elementwise_affine) 14 | except ImportError: 15 | pass 16 | return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) 17 | -------------------------------------------------------------------------------- /fairseq/modules/lightconv_layer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .lightconv_layer import LightconvLayer # noqa 7 | -------------------------------------------------------------------------------- /fairseq/modules/lightconv_layer/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from setuptools import setup 8 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 9 | 10 | setup( 11 | name='lightconv_layer', 12 | ext_modules=[ 13 | CUDAExtension('lightconv_cuda', [ 14 | 'lightconv_cuda.cpp', 15 | 'lightconv_cuda_kernel.cu', 16 | ]), 17 | ], 18 | cmdclass={ 19 | 'build_ext': BuildExtension 20 | }) 21 | -------------------------------------------------------------------------------- /fairseq/modules/logsumexp_moe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | 9 | class LogSumExpMoE(torch.autograd.Function): 10 | """Standard LogSumExp forward pass, but use *posterior* for the backward. 11 | 12 | See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade" 13 | (Shen et al., 2019) `_. 14 | """ 15 | 16 | @staticmethod 17 | def forward(ctx, logp, posterior, dim=-1): 18 | ctx.save_for_backward(posterior) 19 | ctx.dim = dim 20 | return torch.logsumexp(logp, dim=dim) 21 | 22 | @staticmethod 23 | def backward(ctx, grad_output): 24 | posterior, = ctx.saved_tensors 25 | grad_logp = grad_output.unsqueeze(ctx.dim) * posterior 26 | return grad_logp, None, None 27 | -------------------------------------------------------------------------------- /fairseq/modules/scalar_bias.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | # 6 | 7 | import torch 8 | 9 | 10 | class ScalarBias(torch.autograd.Function): 11 | """ 12 | Adds a vector of scalars, used in self-attention mechanism to allow 13 | the model to optionally attend to this vector instead of the past 14 | """ 15 | 16 | @staticmethod 17 | def forward(ctx, input, dim, bias_init): 18 | size = list(input.size()) 19 | size[dim] += 1 20 | output = input.new(*size).fill_(bias_init) 21 | output.narrow(dim, 1, size[dim] - 1).copy_(input) 22 | ctx.dim = dim 23 | return output 24 | 25 | @staticmethod 26 | def backward(ctx, grad): 27 | return grad.narrow(ctx.dim, 1, grad.size(ctx.dim) - 1), None, None 28 | 29 | 30 | def scalar_bias(input, dim, bias_init=0): 31 | return ScalarBias.apply(input, dim, bias_init) 32 | -------------------------------------------------------------------------------- /fairseq/modules/unfold.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch.nn.functional as F 7 | 8 | 9 | def unfold1d(x, kernel_size, padding_l, pad_value=0): 10 | '''unfold T x B x C to T x B x C x K''' 11 | if kernel_size > 1: 12 | T, B, C = x.size() 13 | x = F.pad(x, (0, 0, 0, 0, padding_l, kernel_size - 1 - padding_l), value=pad_value) 14 | x = x.as_strided((T, B, C, kernel_size), (B*C, C, 1, B*C)) 15 | else: 16 | x = x.unsqueeze(3) 17 | return x 18 | -------------------------------------------------------------------------------- /fairseq/optim/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import importlib 7 | import os 8 | 9 | from fairseq import registry 10 | from fairseq.optim.fairseq_optimizer import FairseqOptimizer 11 | from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer 12 | from fairseq.optim.bmuf import FairseqBMUF # noqa 13 | 14 | 15 | __all__ = [ 16 | 'FairseqOptimizer', 17 | 'FP16Optimizer', 18 | 'MemoryEfficientFP16Optimizer', 19 | ] 20 | 21 | 22 | build_optimizer, register_optimizer, OPTIMIZER_REGISTRY = registry.setup_registry( 23 | '--optimizer', 24 | base_class=FairseqOptimizer, 25 | default='nag', 26 | ) 27 | 28 | 29 | # automatically import any Python files in the optim/ directory 30 | for file in os.listdir(os.path.dirname(__file__)): 31 | if file.endswith('.py') and not file.startswith('_'): 32 | module = file[:file.find('.py')] 33 | importlib.import_module('fairseq.optim.' + module) 34 | -------------------------------------------------------------------------------- /fairseq/optim/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/optim/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/optim/__pycache__/adadelta.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/optim/__pycache__/adadelta.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/optim/__pycache__/adafactor.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/optim/__pycache__/adafactor.cpython-36.pyc 
-------------------------------------------------------------------------------- /fairseq/optim/__pycache__/adagrad.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/optim/__pycache__/adagrad.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/optim/__pycache__/adam.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/optim/__pycache__/adam.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/optim/__pycache__/adamax.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/optim/__pycache__/adamax.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/optim/__pycache__/bmuf.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/optim/__pycache__/bmuf.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/optim/__pycache__/fairseq_optimizer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/optim/__pycache__/fairseq_optimizer.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/optim/__pycache__/fp16_optimizer.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/optim/__pycache__/fp16_optimizer.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/optim/__pycache__/nag.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/optim/__pycache__/nag.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/optim/__pycache__/sgd.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/optim/__pycache__/sgd.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/optim/adagrad.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch.optim 7 | 8 | from . import FairseqOptimizer, register_optimizer 9 | 10 | 11 | @register_optimizer('adagrad') 12 | class Adagrad(FairseqOptimizer): 13 | def __init__(self, args, params): 14 | super().__init__(args) 15 | self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config) 16 | 17 | @staticmethod 18 | def add_args(parser): 19 | """Add optimizer-specific arguments to the parser.""" 20 | # fmt: off 21 | parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', 22 | help='weight decay') 23 | # fmt: on 24 | 25 | @property 26 | def optimizer_config(self): 27 | """ 28 | Return a kwarg dictionary that will be used to override optimizer 29 | args stored in checkpoints. 
This allows us to load a checkpoint and 30 | resume training using a different set of optimizer args, e.g., with a 31 | different learning rate. 32 | """ 33 | return { 34 | 'lr': self.args.lr[0], 35 | 'weight_decay': self.args.weight_decay, 36 | } 37 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import importlib 7 | import os 8 | 9 | from fairseq import registry 10 | from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import FairseqLRScheduler 11 | 12 | 13 | build_lr_scheduler, register_lr_scheduler, LR_SCHEDULER_REGISTRY = registry.setup_registry( 14 | '--lr-scheduler', 15 | base_class=FairseqLRScheduler, 16 | default='fixed', 17 | ) 18 | 19 | # automatically import any Python files in the optim/lr_scheduler/ directory 20 | for file in os.listdir(os.path.dirname(__file__)): 21 | if file.endswith('.py') and not file.startswith('_'): 22 | module = file[:file.find('.py')] 23 | importlib.import_module('fairseq.optim.lr_scheduler.' 
+ module) 24 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/optim/lr_scheduler/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/__pycache__/cosine_lr_scheduler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/optim/lr_scheduler/__pycache__/cosine_lr_scheduler.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/__pycache__/fairseq_lr_scheduler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/optim/lr_scheduler/__pycache__/fairseq_lr_scheduler.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/__pycache__/fixed_schedule.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/optim/lr_scheduler/__pycache__/fixed_schedule.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/__pycache__/inverse_square_root_schedule.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/optim/lr_scheduler/__pycache__/inverse_square_root_schedule.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/__pycache__/polynomial_decay_schedule.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/optim/lr_scheduler/__pycache__/polynomial_decay_schedule.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/__pycache__/reduce_lr_on_plateau.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/optim/lr_scheduler/__pycache__/reduce_lr_on_plateau.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/__pycache__/tri_stage_lr_scheduler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/optim/lr_scheduler/__pycache__/tri_stage_lr_scheduler.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/__pycache__/triangular_lr_scheduler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/optim/lr_scheduler/__pycache__/triangular_lr_scheduler.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .. import FairseqOptimizer 7 | 8 | 9 | class FairseqLRScheduler(object): 10 | 11 | def __init__(self, args, optimizer): 12 | super().__init__() 13 | if not isinstance(optimizer, FairseqOptimizer): 14 | raise ValueError('optimizer must be an instance of FairseqOptimizer') 15 | self.args = args 16 | self.optimizer = optimizer 17 | self.best = None 18 | 19 | @staticmethod 20 | def add_args(parser): 21 | """Add arguments to the parser for this LR scheduler.""" 22 | pass 23 | 24 | def state_dict(self): 25 | """Return the LR scheduler state dict.""" 26 | return {'best': self.best} 27 | 28 | def load_state_dict(self, state_dict): 29 | """Load an LR scheduler state dict.""" 30 | self.best = state_dict['best'] 31 | 32 | def step(self, epoch, val_loss=None): 33 | """Update the learning rate at the end of the given epoch.""" 34 | if val_loss is not None: 35 | if self.best is None: 36 | self.best = val_loss 37 | else: 38 | self.best = min(self.best, val_loss) 39 | 40 | def step_update(self, num_updates): 41 | """Update the learning rate after each update.""" 42 | return self.optimizer.get_lr() 43 | -------------------------------------------------------------------------------- /fairseq/optim/sgd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch.optim 7 | 8 | from . 
import FairseqOptimizer, register_optimizer 9 | 10 | 11 | @register_optimizer('sgd') 12 | class SGD(FairseqOptimizer): 13 | def __init__(self, args, params): 14 | super().__init__(args) 15 | self._optimizer = torch.optim.SGD(params, **self.optimizer_config) 16 | 17 | @staticmethod 18 | def add_args(parser): 19 | """Add optimizer-specific arguments to the parser.""" 20 | # fmt: off 21 | parser.add_argument('--momentum', default=0.0, type=float, metavar='M', 22 | help='momentum factor') 23 | parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', 24 | help='weight decay') 25 | # fmt: on 26 | 27 | @property 28 | def optimizer_config(self): 29 | """ 30 | Return a kwarg dictionary that will be used to override optimizer 31 | args stored in checkpoints. This allows us to load a checkpoint and 32 | resume training using a different set of optimizer args, e.g., with a 33 | different learning rate. 34 | """ 35 | return { 36 | 'lr': self.args.lr[0], 37 | 'momentum': self.args.momentum, 38 | 'weight_decay': self.args.weight_decay, 39 | } 40 | -------------------------------------------------------------------------------- /fairseq/pdb.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import multiprocessing 7 | import os 8 | import pdb 9 | import sys 10 | 11 | 12 | __all__ = ['set_trace'] 13 | 14 | 15 | _stdin = [None] 16 | _stdin_lock = multiprocessing.Lock() 17 | try: 18 | _stdin_fd = sys.stdin.fileno() 19 | except Exception: 20 | _stdin_fd = None 21 | 22 | 23 | class MultiprocessingPdb(pdb.Pdb): 24 | """A Pdb wrapper that works in a multiprocessing environment. 
25 | 26 | Usage: `from fairseq import pdb; pdb.set_trace()` 27 | """ 28 | 29 | def __init__(self): 30 | pdb.Pdb.__init__(self, nosigint=True) 31 | 32 | def _cmdloop(self): 33 | stdin_bak = sys.stdin 34 | with _stdin_lock: 35 | try: 36 | if _stdin_fd is not None: 37 | if not _stdin[0]: 38 | _stdin[0] = os.fdopen(_stdin_fd) 39 | sys.stdin = _stdin[0] 40 | self.cmdloop() 41 | finally: 42 | sys.stdin = stdin_bak 43 | 44 | 45 | def set_trace(): 46 | pdb = MultiprocessingPdb() 47 | pdb.set_trace(sys._getframe().f_back) 48 | -------------------------------------------------------------------------------- /fairseq/tasks/__pycache__/MLMetKE.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/tasks/__pycache__/MLMetKE.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/tasks/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/tasks/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/tasks/__pycache__/audio_pretraining.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/tasks/__pycache__/audio_pretraining.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/tasks/__pycache__/cross_lingual_lm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/tasks/__pycache__/cross_lingual_lm.cpython-36.pyc 
-------------------------------------------------------------------------------- /fairseq/tasks/__pycache__/fairseq_task.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/tasks/__pycache__/fairseq_task.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/tasks/__pycache__/language_modeling.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/tasks/__pycache__/language_modeling.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/tasks/__pycache__/legacy_masked_lm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/tasks/__pycache__/legacy_masked_lm.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/tasks/__pycache__/masked_lm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/tasks/__pycache__/masked_lm.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/tasks/__pycache__/multilingual_translation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/tasks/__pycache__/multilingual_translation.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/tasks/__pycache__/semisupervised_translation.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/tasks/__pycache__/semisupervised_translation.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/tasks/__pycache__/sentence_prediction.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/tasks/__pycache__/sentence_prediction.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/tasks/__pycache__/sentence_ranking.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/tasks/__pycache__/sentence_ranking.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/tasks/__pycache__/tacred_task.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/tasks/__pycache__/tacred_task.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/tasks/__pycache__/translation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/tasks/__pycache__/translation.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/tasks/__pycache__/translation_from_pretrained_xlm.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/tasks/__pycache__/translation_from_pretrained_xlm.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/tasks/__pycache__/translation_moe.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq/tasks/__pycache__/translation_moe.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq/tasks/translation_from_pretrained_xlm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary 7 | from fairseq.tasks.translation import TranslationTask 8 | 9 | from . import register_task 10 | 11 | 12 | @register_task("translation_from_pretrained_xlm") 13 | class TranslationFromPretrainedXLMTask(TranslationTask): 14 | """ 15 | Same as TranslationTask except use the MaskedLMDictionary class so that 16 | we can load data that was binarized with the MaskedLMDictionary class. 17 | 18 | This task should be used for the entire training pipeline when we want to 19 | train an NMT model from a pretrained XLM checkpoint: binarizing NMT data, 20 | training NMT with the pretrained XLM checkpoint, and subsequent evaluation 21 | of that trained model. 
22 | """ 23 | 24 | @classmethod 25 | def load_dictionary(cls, filename): 26 | """Load the masked LM dictionary from the filename 27 | 28 | Args: 29 | filename (str): the filename 30 | """ 31 | return MaskedLMDictionary.load(filename) 32 | -------------------------------------------------------------------------------- /fairseq/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import re 7 | 8 | SPACE_NORMALIZER = re.compile(r"\s+") 9 | 10 | 11 | def tokenize_line(line): 12 | line = SPACE_NORMALIZER.sub(" ", line) 13 | line = line.strip() 14 | return line.split() 15 | -------------------------------------------------------------------------------- /fairseq_cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq_cli/__init__.py -------------------------------------------------------------------------------- /fairseq_cli/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq_cli/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq_cli/__pycache__/preprocess.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq_cli/__pycache__/preprocess.cpython-36.pyc -------------------------------------------------------------------------------- /fairseq_cli/eval_lm.py: 
-------------------------------------------------------------------------------- 1 | ../eval_lm.py -------------------------------------------------------------------------------- /fairseq_cli/generate.py: -------------------------------------------------------------------------------- 1 | ../generate.py -------------------------------------------------------------------------------- /fairseq_cli/interactive.py: -------------------------------------------------------------------------------- 1 | ../interactive.py -------------------------------------------------------------------------------- /fairseq_cli/preprocess.py: -------------------------------------------------------------------------------- 1 | ../preprocess.py -------------------------------------------------------------------------------- /fairseq_cli/score.py: -------------------------------------------------------------------------------- 1 | ../score.py -------------------------------------------------------------------------------- /fairseq_cli/setup.py: -------------------------------------------------------------------------------- 1 | ../setup.py -------------------------------------------------------------------------------- /fairseq_cli/train.py: -------------------------------------------------------------------------------- 1 | ../train.py -------------------------------------------------------------------------------- /fairseq_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/fairseq_logo.png -------------------------------------------------------------------------------- /graphvite/asset/graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/graphvite/asset/graph.png 
-------------------------------------------------------------------------------- /graphvite/asset/knowledge_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/graphvite/asset/knowledge_graph.png -------------------------------------------------------------------------------- /graphvite/asset/logo/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/graphvite/asset/logo/favicon.ico -------------------------------------------------------------------------------- /graphvite/asset/logo/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/graphvite/asset/logo/logo.png -------------------------------------------------------------------------------- /graphvite/asset/visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/graphvite/asset/visualization.png -------------------------------------------------------------------------------- /graphvite/asset/visualization/imagenet_hierarchy.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/graphvite/asset/visualization/imagenet_hierarchy.gif -------------------------------------------------------------------------------- /graphvite/asset/visualization/mnist_3d.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/graphvite/asset/visualization/mnist_3d.gif 
-------------------------------------------------------------------------------- /graphvite/cmake/FindGlog.cmake: -------------------------------------------------------------------------------- 1 | # - Try to find Glog 2 | # 3 | # The following variables are optionally searched for defaults 4 | # GLOG_ROOT_DIR: Base directory where all GLOG components are found 5 | # 6 | # The following are set after configuration is done: 7 | # GLOG_FOUND 8 | # GLOG_INCLUDE_DIRS 9 | # GLOG_LIBRARIES 10 | 11 | include(FindPackageHandleStandardArgs) 12 | 13 | set(GLOG_ROOT_DIR "" CACHE PATH "Folder contains Google glog") 14 | 15 | if(WIN32) 16 | find_path(GLOG_INCLUDE_DIR glog/logging.h 17 | PATHS ${GLOG_ROOT_DIR}/src/windows) 18 | else() 19 | find_path(GLOG_INCLUDE_DIR glog/logging.h 20 | PATHS ${GLOG_ROOT_DIR}) 21 | endif() 22 | 23 | if(MSVC) 24 | find_library(GLOG_LIBRARY_RELEASE libglog_static 25 | PATHS ${GLOG_ROOT_DIR} 26 | PATH_SUFFIXES Release) 27 | 28 | find_library(GLOG_LIBRARY_DEBUG libglog_static 29 | PATHS ${GLOG_ROOT_DIR} 30 | PATH_SUFFIXES Debug) 31 | 32 | set(GLOG_LIBRARY optimized ${GLOG_LIBRARY_RELEASE} debug ${GLOG_LIBRARY_DEBUG}) 33 | else() 34 | find_library(GLOG_LIBRARY glog 35 | PATHS ${GLOG_ROOT_DIR} 36 | PATH_SUFFIXES lib lib64) 37 | endif() 38 | 39 | find_package_handle_standard_args(Glog DEFAULT_MSG GLOG_INCLUDE_DIR GLOG_LIBRARY) 40 | 41 | if(GLOG_FOUND) 42 | set(GLOG_INCLUDE_DIRS ${GLOG_INCLUDE_DIR}) 43 | set(GLOG_LIBRARIES ${GLOG_LIBRARY}) 44 | message(STATUS "Found glog (include: ${GLOG_INCLUDE_DIR}, library: ${GLOG_LIBRARY})") 45 | mark_as_advanced(GLOG_ROOT_DIR GLOG_LIBRARY_RELEASE GLOG_LIBRARY_DEBUG 46 | GLOG_LIBRARY GLOG_INCLUDE_DIR) 47 | endif() -------------------------------------------------------------------------------- /graphvite/conda/conda_build_config.yaml: -------------------------------------------------------------------------------- 1 | cxx_compiler_version: 2 | - 5.4 3 | 4 | python: 5 | - 2.7 6 | - 3.6 7 | - 3.7 8 | 9 | numpy: 10 | - 
1.11 11 | 12 | cudatoolkit: 13 | - 9.2 14 | - 10.0 15 | 16 | pin_run_as_build: 17 | cudatoolkit: 18 | max_pin: x.x -------------------------------------------------------------------------------- /graphvite/conda/graphvite-mini/build.sh: -------------------------------------------------------------------------------- 1 | set -e 2 | 3 | mkdir -p build 4 | 5 | cd build 6 | cmake .. -DALL_ARCH=True 7 | make 8 | cd - 9 | 10 | cd python 11 | $PYTHON setup.py install 12 | cd - -------------------------------------------------------------------------------- /graphvite/conda/graphvite-mini/meta.yaml: -------------------------------------------------------------------------------- 1 | package: 2 | name: graphvite-mini 3 | version: 0.2.1 4 | 5 | source: 6 | path: ../.. 7 | 8 | requirements: 9 | build: 10 | # cmake 11 | - cmake >=3.12 12 | - {{ compiler("cxx") }} 13 | - glog 14 | - gflags 15 | - cudatoolkit {{ cudatoolkit }} 16 | - python {{ python }} 17 | - pybind11 18 | host: 19 | # make 20 | - glog 21 | - gflags 22 | - cudatoolkit {{ cudatoolkit }} 23 | - python {{ python }} 24 | - pybind11 25 | - numpy {{ numpy }} 26 | - mkl >=2018 27 | # setup 28 | - pyyaml 29 | - easydict 30 | - six 31 | run: 32 | - glog 33 | - gflags 34 | - cudatoolkit 35 | - python {{ python }} 36 | - mkl >=2018 37 | - numpy >=1.11 38 | - pyyaml 39 | - easydict 40 | - six 41 | - future 42 | - psutil 43 | 44 | build: 45 | string: 46 | "py{{ python|replace('.', '') }}\ 47 | cuda{{ cudatoolkit|replace('.', '') }}\ 48 | h{{ environ.get('GIT_FULL_HASH')|string|truncate(7, True, '', 0) }}" 49 | 50 | test: 51 | imports: 52 | - graphvite 53 | 54 | about: 55 | home: https://graphvite.io 56 | license: Apache-2.0 57 | summary: "A general and high-performance graph embedding system for various applications" -------------------------------------------------------------------------------- /graphvite/conda/graphvite/build.sh: -------------------------------------------------------------------------------- 1 | set -e 
2 | 3 | mkdir -p build 4 | 5 | cd build 6 | cmake .. -DALL_ARCH=True 7 | make 8 | cd - 9 | 10 | cd python 11 | $PYTHON setup.py install 12 | cd - -------------------------------------------------------------------------------- /graphvite/conda/graphvite/meta.yaml: -------------------------------------------------------------------------------- 1 | package: 2 | name: graphvite 3 | version: 0.2.1 4 | 5 | source: 6 | path: ../.. 7 | 8 | requirements: 9 | build: 10 | # cmake 11 | - cmake >=3.12 12 | - {{ compiler("cxx") }} 13 | - glog 14 | - gflags 15 | - cudatoolkit {{ cudatoolkit }} 16 | - python {{ python }} 17 | - pybind11 18 | host: 19 | # make 20 | - glog 21 | - gflags 22 | - cudatoolkit {{ cudatoolkit }} 23 | - python {{ python }} 24 | - pybind11 25 | - numpy {{ numpy }} 26 | - mkl >=2018 27 | # setup 28 | - pyyaml 29 | - easydict 30 | - six 31 | run: 32 | - glog 33 | - gflags 34 | - cudatoolkit 35 | - python {{ python }} 36 | - mkl >=2018 37 | - numpy >=1.11 38 | - pyyaml 39 | - easydict 40 | - six 41 | - future 42 | - imageio 43 | - psutil 44 | - scipy 45 | - matplotlib 46 | - pytorch 47 | - torchvision 48 | - nltk 49 | 50 | build: 51 | string: 52 | "py{{ python|replace('.', '') }}\ 53 | cuda{{ cudatoolkit|replace('.', '') }}\ 54 | h{{ environ.get('GIT_FULL_HASH')|string|truncate(7, True, '', 0) }}" 55 | 56 | test: 57 | imports: 58 | - graphvite 59 | 60 | about: 61 | home: https://graphvite.io 62 | license: Apache-2.0 63 | summary: "A general and high-performance graph embedding system for various applications" -------------------------------------------------------------------------------- /graphvite/conda/requirements.txt: -------------------------------------------------------------------------------- 1 | # cmake 2 | cmake >=3.12 3 | gxx_linux-64 >=5.4 4 | glog 5 | gflags 6 | cudatoolkit >=9.2 7 | python 8 | pybind11 9 | 10 | # make 11 | mkl >=2018 12 | 13 | # run 14 | numpy >=1.11 15 | pyyaml 16 | conda-forge::easydict 17 | six 18 | future 19 | imageio 20 
| psutil 21 | scipy 22 | matplotlib 23 | pytorch 24 | torchvision 25 | nltk -------------------------------------------------------------------------------- /graphvite/config/demo/math.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | knowledge graph 3 | 4 | resource: 5 | gpus: [0] 6 | cpu_per_gpu: 8 7 | dim: 512 8 | 9 | graph: 10 | file_name: 11 | 12 | build: 13 | optimizer: 14 | type: Adam 15 | lr: 5.0e-3 16 | weight_decay: 0 17 | num_partition: auto 18 | num_negative: 8 19 | batch_size: 100000 20 | episode_size: 100 21 | 22 | train: 23 | model: RotatE 24 | num_epoch: 2000 25 | margin: 9 26 | sample_batch_size: 2000 27 | adversarial_temperature: 2 28 | log_frequency: 100 29 | 30 | evaluate: 31 | task: link prediction 32 | file_name: 33 | filter_files: 34 | - 35 | - 36 | - 37 | target: tail 38 | 39 | save: 40 | file_name: rotate_math.pkl -------------------------------------------------------------------------------- /graphvite/config/demo/quick_start.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | graph 3 | 4 | resource: 5 | gpus: [0] 6 | cpu_per_gpu: 8 7 | dim: 128 8 | 9 | format: 10 | delimiters: " \t\r\n" 11 | comment: "#" 12 | 13 | graph: 14 | file_name: 15 | as_undirected: true 16 | 17 | build: 18 | optimizer: 19 | type: SGD 20 | lr: 0.025 21 | weight_decay: 0.005 22 | num_partition: auto 23 | num_negative: 1 24 | batch_size: 100000 25 | episode_size: 500 26 | 27 | train: 28 | model: LINE 29 | num_epoch: 2000 30 | negative_weight: 5 31 | augmentation_step: 2 32 | random_walk_length: 40 33 | random_walk_batch_size: 100 34 | log_frequency: 1000 35 | 36 | evaluate: 37 | - task: link prediction 38 | file_name: 39 | filter_file: 40 | - task: node classification 41 | file_name: 42 | portions: [0.2] 43 | times: 1 44 | 45 | save: 46 | file_name: line_blogcatalog.pkl -------------------------------------------------------------------------------- 
/graphvite/config/graph/deepwalk_flickr.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 128 8 | 9 | graph: 10 | file_name: 11 | as_undirected: true 12 | 13 | build: 14 | optimizer: 15 | type: SGD 16 | lr: 0.025 17 | weight_decay: 0.005 18 | num_partition: auto 19 | num_negative: 1 20 | batch_size: 100000 21 | episode_size: 1000 22 | 23 | train: 24 | # here the best setting uses no augmentation 25 | # in this case, DeepWalk is equal to LINE 26 | model: DeepWalk 27 | num_epoch: 2000 28 | negative_weight: 5 29 | augmentation_step: 1 30 | random_walk_length: 40 31 | random_walk_batch_size: 100 32 | log_frequency: 1000 33 | 34 | evaluate: 35 | task: node classification 36 | file_name: 37 | portions: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] 38 | times: 5 39 | 40 | save: 41 | file_name: deepwalk_flickr.pkl -------------------------------------------------------------------------------- /graphvite/config/graph/deepwalk_friendster-small.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 128 8 | 9 | graph: 10 | file_name: 11 | as_undirected: true 12 | 13 | build: 14 | optimizer: 15 | type: SGD 16 | lr: 0.025 17 | weight_decay: 0.005 18 | num_partition: auto 19 | num_negative: 1 20 | batch_size: 100000 21 | episode_size: 3500 22 | 23 | train: 24 | # here the best setting uses no augmentation 25 | # in this case, DeepWalk is equal to LINE 26 | model: DeepWalk 27 | num_epoch: 2000 28 | negative_weight: 5 29 | augmentation_step: 1 30 | random_walk_length: 40 31 | random_walk_batch_size: 100 32 | log_frequency: 1000 33 | 34 | evaluate: 35 | task: node classification 36 | file_name: 37 | portions: [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10] 38 | times: 5 39 | 40 | save: 41 | file_name: 
deepwalk_friendster-small.pkl -------------------------------------------------------------------------------- /graphvite/config/graph/deepwalk_friendster.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 96 8 | 9 | graph: 10 | file_name: 11 | as_undirected: true 12 | 13 | build: 14 | optimizer: 15 | type: SGD 16 | lr: 0.025 17 | weight_decay: 0.005 18 | num_partition: auto 19 | num_negative: 1 20 | batch_size: 100000 21 | episode_size: 2500 22 | 23 | train: 24 | model: DeepWalk 25 | num_epoch: 2000 26 | negative_weight: 5 27 | augmentation_step: 2 28 | random_walk_length: 40 29 | random_walk_batch_size: 100 30 | log_frequency: 1000 31 | 32 | evaluate: 33 | task: node classification 34 | file_name: 35 | portions: [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10] 36 | times: 5 37 | 38 | save: 39 | file_name: deepwalk_friendster.pkl -------------------------------------------------------------------------------- /graphvite/config/graph/deepwalk_hyperlink-pld.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 128 8 | 9 | graph: 10 | file_name: 11 | as_undirected: true 12 | 13 | build: 14 | optimizer: 15 | type: SGD 16 | lr: 0.025 17 | weight_decay: 0.005 18 | num_partition: auto 19 | num_negative: 1 20 | batch_size: 100000 21 | episode_size: 5000 22 | 23 | train: 24 | model: DeepWalk 25 | num_epoch: 2000 26 | negative_weight: 5 27 | augmentation_step: 2 28 | random_walk_length: 40 29 | random_walk_batch_size: 100 30 | log_frequency: 1000 31 | 32 | evaluate: 33 | task: link prediction 34 | file_name: 35 | filter_file: 36 | 37 | save: 38 | file_name: deepwalk_hyperlink-pld.pkl -------------------------------------------------------------------------------- /graphvite/config/graph/deepwalk_youtube.yaml: 
-------------------------------------------------------------------------------- 1 | application: 2 | graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 128 8 | 9 | graph: 10 | file_name: 11 | as_undirected: true 12 | 13 | build: 14 | optimizer: 15 | type: SGD 16 | lr: 0.025 17 | weight_decay: 0.005 18 | num_partition: auto 19 | num_negative: 1 20 | batch_size: 100000 21 | episode_size: 500 22 | 23 | train: 24 | model: DeepWalk 25 | num_epoch: 4000 26 | negative_weight: 5 27 | augmentation_step: 5 28 | random_walk_length: 40 29 | random_walk_batch_size: 100 30 | log_frequency: 1000 31 | 32 | evaluate: 33 | task: node classification 34 | file_name: 35 | portions: [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10] 36 | times: 5 37 | 38 | save: 39 | file_name: deepwalk_youtube.pkl -------------------------------------------------------------------------------- /graphvite/config/graph/line_flickr.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 128 8 | 9 | graph: 10 | file_name: 11 | as_undirected: true 12 | 13 | build: 14 | optimizer: 15 | type: SGD 16 | lr: 0.025 17 | weight_decay: 0.005 18 | num_partition: auto 19 | num_negative: 1 20 | batch_size: 100000 21 | episode_size: 1000 22 | 23 | train: 24 | model: LINE 25 | num_epoch: 2000 26 | negative_weight: 5 27 | augmentation_step: 1 28 | random_walk_length: 40 29 | random_walk_batch_size: 100 30 | log_frequency: 1000 31 | 32 | evaluate: 33 | task: node classification 34 | file_name: 35 | portions: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] 36 | times: 5 37 | 38 | save: 39 | file_name: line_flickr.pkl -------------------------------------------------------------------------------- /graphvite/config/graph/line_friendster-small.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | graph 3 | 4 | resource: 5 | 
gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 128 8 | 9 | graph: 10 | file_name: 11 | as_undirected: true 12 | 13 | build: 14 | optimizer: 15 | type: SGD 16 | lr: 0.025 17 | weight_decay: 0.005 18 | num_partition: auto 19 | num_negative: 1 20 | batch_size: 100000 21 | episode_size: 3500 22 | 23 | train: 24 | model: LINE 25 | num_epoch: 2000 26 | negative_weight: 5 27 | augmentation_step: 1 28 | random_walk_length: 40 29 | random_walk_batch_size: 100 30 | log_frequency: 1000 31 | 32 | evaluate: 33 | task: node classification 34 | file_name: 35 | portions: [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10] 36 | times: 5 37 | 38 | save: 39 | file_name: line_friendster-small.pkl -------------------------------------------------------------------------------- /graphvite/config/graph/line_friendster.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 96 8 | 9 | graph: 10 | file_name: 11 | as_undirected: true 12 | 13 | build: 14 | optimizer: 15 | type: SGD 16 | lr: 0.025 17 | weight_decay: 0.005 18 | num_partition: auto 19 | num_negative: 1 20 | batch_size: 100000 21 | episode_size: 2500 22 | 23 | train: 24 | model: LINE 25 | num_epoch: 2000 26 | negative_weight: 5 27 | augmentation_step: 2 28 | random_walk_length: 40 29 | random_walk_batch_size: 100 30 | log_frequency: 1000 31 | 32 | evaluate: 33 | task: node classification 34 | file_name: 35 | portions: [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10] 36 | times: 5 37 | 38 | save: 39 | file_name: line_friendster.pkl -------------------------------------------------------------------------------- /graphvite/config/graph/line_hyperlink-pld.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 128 8 | 9 | graph: 10 | file_name: 11 | as_undirected: true 12 | 13 | 
build: 14 | optimizer: 15 | type: SGD 16 | lr: 0.025 17 | weight_decay: 0.005 18 | num_partition: auto 19 | num_negative: 1 20 | batch_size: 100000 21 | episode_size: 5000 22 | 23 | train: 24 | model: LINE 25 | num_epoch: 2000 26 | negative_weight: 5 27 | augmentation_step: 2 28 | random_walk_length: 40 29 | random_walk_batch_size: 100 30 | log_frequency: 1000 31 | 32 | evaluate: 33 | task: link prediction 34 | file_name: 35 | filter_file: 36 | 37 | save: 38 | file_name: line_hyperlink-pld.pkl -------------------------------------------------------------------------------- /graphvite/config/graph/line_youtube.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 128 8 | 9 | graph: 10 | file_name: 11 | as_undirected: true 12 | 13 | build: 14 | optimizer: 15 | type: SGD 16 | lr: 0.025 17 | weight_decay: 0.005 18 | num_partition: auto 19 | num_negative: 1 20 | batch_size: 100000 21 | episode_size: 500 22 | 23 | train: 24 | model: LINE 25 | num_epoch: 4000 26 | negative_weight: 5 27 | augmentation_step: 5 28 | random_walk_length: 40 29 | random_walk_batch_size: 100 30 | log_frequency: 1000 31 | 32 | evaluate: 33 | task: node classification 34 | file_name: 35 | portions: [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10] 36 | times: 5 37 | 38 | save: 39 | file_name: line_youtube.pkl -------------------------------------------------------------------------------- /graphvite/config/graph/node2vec_youtube.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 128 8 | 9 | graph: 10 | file_name: 11 | as_undirected: true 12 | 13 | build: 14 | optimizer: 15 | type: SGD 16 | lr: 0.025 17 | weight_decay: 0.005 18 | num_partition: auto 19 | num_negative: 1 20 | batch_size: 100000 21 | episode_size: 500 22 | 23 | train: 24 | model: 
node2vec 25 | num_epoch: 4000 26 | negative_weight: 5 27 | augmentation_step: 5 28 | random_walk_length: 40 29 | random_walk_batch_size: 100 30 | p: 4 31 | q: 2 32 | log_frequency: 1000 33 | 34 | evaluate: 35 | task: node classification 36 | file_name: 37 | portions: [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10] 38 | times: 5 39 | 40 | save: 41 | file_name: node2vec_youtube.pkl -------------------------------------------------------------------------------- /graphvite/config/knowledge_graph/complex_fb15k-237.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | knowledge graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 2048 8 | 9 | graph: 10 | file_name: 11 | 12 | build: 13 | optimizer: 14 | type: Adam 15 | lr: 2.0e-5 16 | weight_decay: 0 17 | num_partition: auto 18 | num_negative: 64 19 | batch_size: 100000 20 | episode_size: 1 21 | 22 | train: 23 | model: ComplEx 24 | num_epoch: 1000 25 | l3_regularization: 5.0e-3 26 | sample_batch_size: 2000 27 | adversarial_temperature: 2 28 | log_frequency: 100 29 | 30 | evaluate: 31 | task: link prediction 32 | file_name: 33 | filter_files: 34 | - 35 | - 36 | - 37 | # fast_mode: 3000 38 | 39 | save: 40 | file_name: complex_fb15k-237.pkl -------------------------------------------------------------------------------- /graphvite/config/knowledge_graph/complex_fb15k.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | knowledge graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 2048 8 | 9 | graph: 10 | file_name: 11 | 12 | build: 13 | optimizer: 14 | type: Adam 15 | lr: 2.0e-4 16 | weight_decay: 0 17 | num_partition: auto 18 | num_negative: 64 19 | batch_size: 100000 20 | episode_size: 1 21 | 22 | train: 23 | model: ComplEx 24 | num_epoch: 1000 25 | l3_regularization: 1.0e-3 26 | sample_batch_size: 2000 27 | adversarial_temperature: 2 28 | log_frequency: 100 29 | 30 | 
evaluate: 31 | task: link prediction 32 | file_name: 33 | filter_files: 34 | - 35 | - 36 | - 37 | # fast_mode: 3000 38 | 39 | save: 40 | file_name: complex_fb15k.pkl -------------------------------------------------------------------------------- /graphvite/config/knowledge_graph/complex_wikidata5m.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | knowledge graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 512 8 | 9 | graph: 10 | file_name: 11 | normalization: true 12 | 13 | build: 14 | optimizer: 15 | type: SGD 16 | lr: 1.0e-1 17 | weight_decay: 0 18 | num_partition: auto 19 | num_negative: 64 20 | batch_size: 100000 21 | episode_size: 200 22 | 23 | train: 24 | model: ComplEx 25 | num_epoch: 1000 26 | l3_regularization: 2.0e-3 27 | sample_batch_size: 2000 28 | adversarial_temperature: 0.2 29 | relation_lr_multiplier: 1.0e-3 30 | log_frequency: 500 31 | 32 | evaluate: 33 | task: link prediction 34 | file_name: 35 | filter_files: 36 | - 37 | - 38 | - 39 | # fast_mode: 1000 40 | 41 | save: 42 | file_name: complex_wikidata5m.pkl -------------------------------------------------------------------------------- /graphvite/config/knowledge_graph/complex_wn18.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | knowledge graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 1024 8 | 9 | graph: 10 | file_name: 11 | 12 | build: 13 | optimizer: 14 | type: Adam 15 | lr: 1.0e-5 16 | weight_decay: 0 17 | num_partition: auto 18 | num_negative: 64 19 | batch_size: 100000 20 | episode_size: 1 21 | 22 | train: 23 | model: ComplEx 24 | num_epoch: 4000 25 | l3_regularization: 5.0e-5 26 | sample_batch_size: 2000 27 | adversarial_temperature: 2 28 | log_frequency: 100 29 | 30 | evaluate: 31 | task: link prediction 32 | file_name: 33 | filter_files: 34 | - 35 | - 36 | - 37 | # fast_mode: 3000 38 | 39 | save: 40 | file_name: complex_wn18.pkl 
-------------------------------------------------------------------------------- /graphvite/config/knowledge_graph/complex_wn18rr.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | knowledge graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 1024 8 | 9 | graph: 10 | file_name: 11 | 12 | build: 13 | optimizer: 14 | type: Adam 15 | lr: 1.0e-5 16 | weight_decay: 0 17 | num_partition: auto 18 | num_negative: 64 19 | batch_size: 100000 20 | episode_size: 1 21 | 22 | train: 23 | model: ComplEx 24 | num_epoch: 6000 25 | l3_regularization: 5.0e-6 26 | sample_batch_size: 2000 27 | adversarial_temperature: 2 28 | log_frequency: 100 29 | 30 | evaluate: 31 | task: link prediction 32 | file_name: 33 | filter_files: 34 | - 35 | - 36 | - 37 | # fast_mode: 3000 38 | 39 | save: 40 | file_name: complex_wn18rr.pkl -------------------------------------------------------------------------------- /graphvite/config/knowledge_graph/distmult_fb15k-237.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | knowledge graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 2048 8 | 9 | graph: 10 | file_name: 11 | 12 | build: 13 | optimizer: 14 | type: Adam 15 | lr: 2.0e-5 16 | weight_decay: 0 17 | num_partition: auto 18 | num_negative: 64 19 | batch_size: 100000 20 | episode_size: 1 21 | 22 | train: 23 | model: DistMult 24 | num_epoch: 1000 25 | l3_regularization: 5.0e-3 26 | sample_batch_size: 2000 27 | adversarial_temperature: 2 28 | log_frequency: 100 29 | 30 | evaluate: 31 | task: link prediction 32 | file_name: 33 | filter_files: 34 | - 35 | - 36 | - 37 | # fast_mode: 3000 38 | 39 | save: 40 | file_name: distmult_fb15k-237.pkl -------------------------------------------------------------------------------- /graphvite/config/knowledge_graph/distmult_fb15k.yaml: -------------------------------------------------------------------------------- 1 | 
application: 2 | knowledge graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 2048 8 | 9 | graph: 10 | file_name: 11 | 12 | build: 13 | optimizer: 14 | type: Adam 15 | lr: 5.0e-5 16 | weight_decay: 0 17 | num_partition: auto 18 | num_negative: 64 19 | batch_size: 100000 20 | episode_size: 1 21 | 22 | train: 23 | model: DistMult 24 | num_epoch: 1000 25 | l3_regularization: 1.0e-3 26 | sample_batch_size: 2000 27 | adversarial_temperature: 2 28 | log_frequency: 100 29 | 30 | evaluate: 31 | task: link prediction 32 | file_name: 33 | filter_files: 34 | - 35 | - 36 | - 37 | # fast_mode: 3000 38 | 39 | save: 40 | file_name: distmult_fb15k.pkl -------------------------------------------------------------------------------- /graphvite/config/knowledge_graph/distmult_wikidata5m.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | knowledge graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 512 8 | 9 | graph: 10 | file_name: 11 | normalization: true 12 | 13 | build: 14 | optimizer: 15 | type: SGD 16 | lr: 1.0e-1 17 | weight_decay: 0 18 | num_partition: auto 19 | num_negative: 64 20 | batch_size: 100000 21 | episode_size: 200 22 | 23 | train: 24 | model: DistMult 25 | num_epoch: 2000 26 | l3_regularization: 2.0e-3 27 | sample_batch_size: 2000 28 | adversarial_temperature: 2 29 | relation_lr_multiplier: 1.0e-4 30 | log_frequency: 500 31 | 32 | evaluate: 33 | task: link prediction 34 | file_name: 35 | filter_files: 36 | - 37 | - 38 | - 39 | # fast_mode: 1000 40 | 41 | save: 42 | file_name: distmult_wikidata5m.pkl -------------------------------------------------------------------------------- /graphvite/config/knowledge_graph/distmult_wn18.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | knowledge graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 1024 8 | 9 | graph: 10 | file_name: 11 | 12 | build: 13 | optimizer: 
14 | type: Adam 15 | lr: 1.0e-4 16 | weight_decay: 0 17 | num_partition: auto 18 | num_negative: 64 19 | batch_size: 100000 20 | episode_size: 1 21 | 22 | train: 23 | model: DistMult 24 | num_epoch: 4000 25 | l3_regularization: 1.0e-3 26 | sample_batch_size: 2000 27 | adversarial_temperature: 2 28 | log_frequency: 100 29 | 30 | evaluate: 31 | task: link prediction 32 | file_name: 33 | filter_files: 34 | - 35 | - 36 | - 37 | # fast_mode: 3000 38 | 39 | save: 40 | file_name: distmult_wn18.pkl -------------------------------------------------------------------------------- /graphvite/config/knowledge_graph/distmult_wn18rr.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | knowledge graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 1024 8 | 9 | graph: 10 | file_name: 11 | 12 | build: 13 | optimizer: 14 | type: Adam 15 | lr: 2.0e-5 16 | weight_decay: 0 17 | num_partition: auto 18 | num_negative: 64 19 | batch_size: 100000 20 | episode_size: 1 21 | 22 | train: 23 | model: DistMult 24 | num_epoch: 6000 25 | l3_regularization: 1.0e-2 26 | sample_batch_size: 2000 27 | adversarial_temperature: 2 28 | log_frequency: 100 29 | 30 | evaluate: 31 | task: link prediction 32 | file_name: 33 | filter_files: 34 | - 35 | - 36 | - 37 | # fast_mode: 3000 38 | 39 | save: 40 | file_name: distmult_wn18rr.pkl -------------------------------------------------------------------------------- /graphvite/config/knowledge_graph/rotate_fb15k-237.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | knowledge graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 2048 8 | 9 | graph: 10 | file_name: 11 | 12 | build: 13 | optimizer: 14 | type: Adam 15 | lr: 2.0e-6 16 | weight_decay: 0 17 | num_partition: auto 18 | num_negative: 64 19 | batch_size: 100000 20 | episode_size: 1 21 | 22 | train: 23 | model: RotatE 24 | num_epoch: 1000 25 | margin: 9 26 | 
sample_batch_size: 2000 27 | adversarial_temperature: 2 28 | log_frequency: 100 29 | 30 | evaluate: 31 | task: link prediction 32 | file_name: 33 | filter_files: 34 | - 35 | - 36 | - 37 | # fast_mode: 3000 38 | 39 | save: 40 | file_name: rotate_fb15k-237.pkl -------------------------------------------------------------------------------- /graphvite/config/knowledge_graph/rotate_fb15k.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | knowledge graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 2048 8 | 9 | graph: 10 | file_name: 11 | 12 | build: 13 | optimizer: 14 | type: Adam 15 | lr: 2.0e-4 16 | weight_decay: 0 17 | num_partition: auto 18 | num_negative: 64 19 | batch_size: 100000 20 | episode_size: 1 21 | 22 | train: 23 | model: RotatE 24 | num_epoch: 1000 25 | margin: 24 26 | sample_batch_size: 2000 27 | adversarial_temperature: 2 28 | log_frequency: 100 29 | 30 | evaluate: 31 | task: link prediction 32 | file_name: 33 | filter_files: 34 | - 35 | - 36 | - 37 | # fast_mode: 3000 38 | 39 | save: 40 | file_name: rotate_fb15k.pkl -------------------------------------------------------------------------------- /graphvite/config/knowledge_graph/rotate_wikidata5m.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | knowledge graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 512 8 | 9 | graph: 10 | file_name: 11 | normalization: true 12 | 13 | build: 14 | optimizer: 15 | type: SGD 16 | lr: 1.0e-2 17 | weight_decay: 0 18 | num_partition: auto 19 | num_negative: 64 20 | batch_size: 100000 21 | episode_size: 200 22 | 23 | train: 24 | model: RotatE 25 | num_epoch: 1000 26 | margin: 6 27 | sample_batch_size: 2000 28 | adversarial_temperature: 0.2 29 | relation_lr_multiplier: 1 30 | log_frequency: 500 31 | 32 | evaluate: 33 | task: link prediction 34 | file_name: 35 | filter_files: 36 | - 37 | - 38 | - 39 | # fast_mode: 1000 40 | 
41 | save: 42 | file_name: rotate_wikidata5m.pkl -------------------------------------------------------------------------------- /graphvite/config/knowledge_graph/rotate_wn18.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | knowledge graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 1024 8 | 9 | graph: 10 | file_name: 11 | 12 | build: 13 | optimizer: 14 | type: Adam 15 | lr: 5.0e-6 16 | weight_decay: 0 17 | num_partition: auto 18 | num_negative: 64 19 | batch_size: 100000 20 | episode_size: 1 21 | 22 | train: 23 | model: RotatE 24 | num_epoch: 4000 25 | margin: 9 26 | sample_batch_size: 2000 27 | adversarial_temperature: 2 28 | log_frequency: 100 29 | 30 | evaluate: 31 | task: link prediction 32 | file_name: 33 | filter_files: 34 | - 35 | - 36 | - 37 | # fast_mode: 3000 38 | 39 | save: 40 | file_name: rotate_wn18.pkl -------------------------------------------------------------------------------- /graphvite/config/knowledge_graph/rotate_wn18rr.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | knowledge graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 1024 8 | 9 | graph: 10 | file_name: 11 | 12 | build: 13 | optimizer: 14 | type: Adam 15 | lr: 5.0e-6 16 | weight_decay: 0 17 | num_partition: auto 18 | num_negative: 64 19 | batch_size: 100000 20 | episode_size: 1 21 | 22 | train: 23 | model: RotatE 24 | num_epoch: 6000 25 | margin: 6 26 | sample_batch_size: 2000 27 | adversarial_temperature: 2 28 | log_frequency: 100 29 | 30 | evaluate: 31 | task: link prediction 32 | file_name: 33 | filter_files: 34 | - 35 | - 36 | - 37 | # fast_mode: 3000 38 | 39 | save: 40 | file_name: rotate_wn18rr.pkl -------------------------------------------------------------------------------- /graphvite/config/knowledge_graph/simple_fb15k-237.yaml: -------------------------------------------------------------------------------- 1 | 
application: 2 | knowledge graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 2048 8 | 9 | graph: 10 | file_name: 11 | 12 | build: 13 | optimizer: 14 | type: Adam 15 | lr: 2.0e-5 16 | weight_decay: 0 17 | num_partition: auto 18 | num_negative: 64 19 | batch_size: 100000 20 | episode_size: 1 21 | 22 | train: 23 | model: SimplE 24 | num_epoch: 1000 25 | l3_regularization: 5.0e-3 26 | sample_batch_size: 2000 27 | adversarial_temperature: 2 28 | log_frequency: 100 29 | 30 | evaluate: 31 | task: link prediction 32 | file_name: 33 | filter_files: 34 | - 35 | - 36 | - 37 | # fast_mode: 3000 38 | 39 | save: 40 | file_name: simple_fb15k-237.pkl -------------------------------------------------------------------------------- /graphvite/config/knowledge_graph/simple_fb15k.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | knowledge graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 2048 8 | 9 | graph: 10 | file_name: 11 | 12 | build: 13 | optimizer: 14 | type: Adam 15 | lr: 2.0e-5 16 | weight_decay: 0 17 | num_partition: auto 18 | num_negative: 64 19 | batch_size: 100000 20 | episode_size: 1 21 | 22 | train: 23 | model: SimplE 24 | num_epoch: 1000 25 | l3_regularization: 1.0e-3 26 | sample_batch_size: 2000 27 | adversarial_temperature: 2 28 | log_frequency: 100 29 | 30 | evaluate: 31 | task: link prediction 32 | file_name: 33 | filter_files: 34 | - 35 | - 36 | - 37 | # fast_mode: 3000 38 | 39 | save: 40 | file_name: simple_fb15k.pkl -------------------------------------------------------------------------------- /graphvite/config/knowledge_graph/simple_wikidata5m.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | knowledge graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 512 8 | 9 | graph: 10 | file_name: 11 | normalization: true 12 | 13 | build: 14 | optimizer: 15 | type: SGD 16 | lr: 1.0 17 | weight_decay: 0 
18 | num_partition: auto 19 | num_negative: 64 20 | batch_size: 100000 21 | episode_size: 200 22 | 23 | train: 24 | model: SimplE 25 | num_epoch: 2000 26 | l3_regularization: 2.0e-3 27 | sample_batch_size: 2000 28 | adversarial_temperature: 2 29 | relation_lr_multiplier: 1.0e-4 30 | log_frequency: 500 31 | 32 | evaluate: 33 | task: link prediction 34 | file_name: 35 | filter_files: 36 | - 37 | - 38 | - 39 | # fast_mode: 1000 40 | 41 | save: 42 | file_name: simple_wikidata5m.pkl -------------------------------------------------------------------------------- /graphvite/config/knowledge_graph/simple_wn18.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | knowledge graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 1024 8 | 9 | graph: 10 | file_name: 11 | 12 | build: 13 | optimizer: 14 | type: Adam 15 | lr: 5.0e-5 16 | weight_decay: 0 17 | num_partition: auto 18 | num_negative: 64 19 | batch_size: 100000 20 | episode_size: 1 21 | 22 | train: 23 | model: SimplE 24 | num_epoch: 4000 25 | l3_regularization: 2.0e-3 26 | sample_batch_size: 2000 27 | adversarial_temperature: 2 28 | log_frequency: 100 29 | 30 | evaluate: 31 | task: link prediction 32 | file_name: 33 | filter_files: 34 | - 35 | - 36 | - 37 | # fast_mode: 3000 38 | 39 | save: 40 | file_name: simple_wn18.pkl -------------------------------------------------------------------------------- /graphvite/config/knowledge_graph/simple_wn18rr.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | knowledge graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 1024 8 | 9 | graph: 10 | file_name: 11 | 12 | build: 13 | optimizer: 14 | type: Adam 15 | lr: 1.0e-4 16 | weight_decay: 0 17 | num_partition: auto 18 | num_negative: 64 19 | batch_size: 100000 20 | episode_size: 1 21 | 22 | train: 23 | model: SimplE 24 | num_epoch: 6000 25 | l3_regularization: 2.0e-3 26 | sample_batch_size: 
2000 27 | adversarial_temperature: 2 28 | log_frequency: 100 29 | 30 | evaluate: 31 | task: link prediction 32 | file_name: 33 | filter_files: 34 | - 35 | - 36 | - 37 | # fast_mode: 3000 38 | 39 | save: 40 | file_name: simple_wn18rr.pkl -------------------------------------------------------------------------------- /graphvite/config/knowledge_graph/transe_fb15k-237.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | knowledge graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 1024 8 | 9 | graph: 10 | file_name: 11 | 12 | build: 13 | optimizer: 14 | type: Adam 15 | lr: 2.0e-6 16 | weight_decay: 0 17 | num_partition: auto 18 | num_negative: 64 19 | batch_size: 100000 20 | episode_size: 1 21 | 22 | train: 23 | model: TransE 24 | num_epoch: 1000 25 | margin: 9 26 | sample_batch_size: 2000 27 | adversarial_temperature: 2 28 | log_frequency: 100 29 | 30 | evaluate: 31 | task: link prediction 32 | file_name: 33 | filter_files: 34 | - 35 | - 36 | - 37 | # fast_mode: 3000 38 | 39 | save: 40 | file_name: transe_fb15k-237.pkl -------------------------------------------------------------------------------- /graphvite/config/knowledge_graph/transe_fb15k.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | knowledge graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 1024 8 | 9 | graph: 10 | file_name: 11 | 12 | build: 13 | optimizer: 14 | type: Adam 15 | lr: 1.0e-5 16 | weight_decay: 0 17 | num_partition: auto 18 | num_negative: 64 19 | batch_size: 100000 20 | episode_size: 1 21 | 22 | train: 23 | model: TransE 24 | num_epoch: 1000 25 | margin: 24 26 | sample_batch_size: 2000 27 | adversarial_temperature: 2 28 | log_frequency: 100 29 | 30 | evaluate: 31 | task: link prediction 32 | file_name: 33 | filter_files: 34 | - 35 | - 36 | - 37 | # fast_mode: 3000 38 | 39 | save: 40 | file_name: transe_fb15k.pkl 
-------------------------------------------------------------------------------- /graphvite/config/knowledge_graph/transe_wikidata5m.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | knowledge graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 512 8 | 9 | graph: 10 | file_name: 11 | normalization: true 12 | 13 | build: 14 | optimizer: 15 | type: SGD 16 | lr: 1.0e-3 17 | weight_decay: 0 18 | num_partition: auto 19 | num_negative: 64 20 | batch_size: 100000 21 | episode_size: 200 22 | 23 | train: 24 | model: TransE 25 | num_epoch: 1000 26 | margin: 12 27 | sample_batch_size: 2000 28 | adversarial_temperature: 0.5 29 | relation_lr_multiplier: 1.0e-2 30 | log_frequency: 500 31 | 32 | evaluate: 33 | task: link prediction 34 | file_name: 35 | filter_files: 36 | - 37 | - 38 | - 39 | # fast_mode: 1000 40 | 41 | save: 42 | file_name: transe_wikidata5m.pkl -------------------------------------------------------------------------------- /graphvite/config/knowledge_graph/transe_wn18.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | knowledge graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 512 8 | 9 | graph: 10 | file_name: 11 | 12 | build: 13 | optimizer: 14 | type: Adam 15 | lr: 5.0e-6 16 | weight_decay: 0 17 | num_partition: auto 18 | num_negative: 64 19 | batch_size: 100000 20 | episode_size: 1 21 | 22 | train: 23 | model: TransE 24 | num_epoch: 4000 25 | margin: 12 26 | sample_batch_size: 2000 27 | adversarial_temperature: 2 28 | log_frequency: 100 29 | 30 | evaluate: 31 | task: link prediction 32 | file_name: 33 | filter_files: 34 | - 35 | - 36 | - 37 | # fast_mode: 3000 38 | 39 | save: 40 | file_name: transe_wn18.pkl -------------------------------------------------------------------------------- /graphvite/config/knowledge_graph/transe_wn18rr.yaml: 
-------------------------------------------------------------------------------- 1 | application: 2 | knowledge graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 512 8 | 9 | graph: 10 | file_name: 11 | 12 | build: 13 | optimizer: 14 | type: Adam 15 | lr: 1.0e-6 16 | weight_decay: 0 17 | num_partition: auto 18 | num_negative: 64 19 | batch_size: 100000 20 | episode_size: 1 21 | 22 | train: 23 | model: TransE 24 | num_epoch: 6000 25 | margin: 6 26 | sample_batch_size: 2000 27 | adversarial_temperature: 2 28 | log_frequency: 100 29 | 30 | evaluate: 31 | task: link prediction 32 | file_name: 33 | filter_files: 34 | - 35 | - 36 | - 37 | # fast_mode: 3000 38 | 39 | save: 40 | file_name: transe_wn18rr.pkl -------------------------------------------------------------------------------- /graphvite/config/visualization/largevis_imagenet.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | visualization 3 | 4 | resource: 5 | gpus: [0] 6 | cpu_per_gpu: auto 7 | dim: 2 8 | 9 | graph: 10 | vectors: 11 | num_neighbor: 200 12 | perplexity: 50 13 | 14 | build: 15 | optimizer: 16 | type: Adam 17 | lr: 0.5 18 | weight_decay: 1.0e-5 19 | num_partition: auto 20 | num_negative: 5 21 | batch_size: 100000 22 | episode_size: 200 23 | 24 | train: 25 | model: LargeVis 26 | num_epoch: 50 27 | negative_weight: 3 28 | log_frequency: 1000 29 | 30 | evaluate: 31 | task: hierarchy 32 | file_name: 33 | target: english_setter 34 | save_file: imagenet_hierarchy.gif 35 | 36 | save: 37 | file_name: largevis_imagenet_2d.pkl -------------------------------------------------------------------------------- /graphvite/config/visualization/largevis_mnist_2d.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | visualization 3 | 4 | resource: 5 | gpus: [0] 6 | cpu_per_gpu: auto 7 | dim: 2 8 | 9 | graph: 10 | vectors: 11 | num_neighbor: 200 12 | perplexity: 20 13 | 14 | build: 15 | 
optimizer: 16 | type: Adam 17 | lr: 0.5 18 | weight_decay: 1.0e-5 19 | num_partition: auto 20 | num_negative: 5 21 | batch_size: 100000 22 | episode_size: 200 23 | 24 | train: 25 | model: LargeVis 26 | num_epoch: 50 27 | negative_weight: 3 28 | log_frequency: 1000 29 | 30 | evaluate: 31 | task: visualization 32 | Y: 33 | save_file: mnist_2d.png 34 | 35 | save: 36 | file_name: largevis_mnist_2d.pkl -------------------------------------------------------------------------------- /graphvite/config/visualization/largevis_mnist_3d.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | visualization 3 | 4 | resource: 5 | gpus: [0] 6 | cpu_per_gpu: auto 7 | dim: 3 8 | 9 | graph: 10 | vectors: 11 | num_neighbor: 200 12 | perplexity: 20 13 | 14 | build: 15 | optimizer: 16 | type: Adam 17 | lr: 0.5 18 | weight_decay: 1.0e-5 19 | num_partition: auto 20 | num_negative: 5 21 | batch_size: 100000 22 | episode_size: 200 23 | 24 | train: 25 | model: LargeVis 26 | num_epoch: 50 27 | negative_weight: 3 28 | log_frequency: 1000 29 | 30 | evaluate: 31 | task: animation 32 | Y: 33 | save_file: mnist_3d.gif 34 | 35 | save: 36 | file_name: largevis_mnist_3d.pkl -------------------------------------------------------------------------------- /graphvite/config/word_graph/line_wikipedia.yaml: -------------------------------------------------------------------------------- 1 | application: 2 | word graph 3 | 4 | resource: 5 | gpus: [] 6 | cpu_per_gpu: auto 7 | dim: 128 8 | 9 | graph: 10 | file_name: 11 | window: 5 12 | min_count: 5 13 | 14 | build: 15 | optimizer: 16 | type: SGD 17 | lr: 0.025 18 | weight_decay: 0.005 19 | num_partition: auto 20 | num_negative: 1 21 | batch_size: 100000 22 | episode_size: 1000 23 | 24 | train: 25 | model: LINE 26 | num_epoch: 80 27 | negative_weight: 5 28 | augmentation_step: 1 29 | random_walk_length: 40 30 | random_walk_batch_size: 100 31 | log_frequency: 1000 32 | 33 | save: 34 | file_name: 
line_wikipedia.pkl -------------------------------------------------------------------------------- /graphvite/doc/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = source 8 | BUILDDIR = build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /graphvite/doc/source/api/application.rst: -------------------------------------------------------------------------------- 1 | graphvite.application 2 | ===================== 3 | 4 | .. automodule:: graphvite.application 5 | :members: 6 | :inherited-members: 7 | -------------------------------------------------------------------------------- /graphvite/doc/source/api/dataset.rst: -------------------------------------------------------------------------------- 1 | graphvite.dataset 2 | ================= 3 | 4 | .. automodule:: graphvite.dataset 5 | :members: -------------------------------------------------------------------------------- /graphvite/doc/source/api/graph.rst: -------------------------------------------------------------------------------- 1 | graphvite.graph 2 | =============== 3 | 4 | .. 
automodule:: graphvite.graph 5 | :members: -------------------------------------------------------------------------------- /graphvite/doc/source/api/optimizer.rst: -------------------------------------------------------------------------------- 1 | graphvite.optimizer 2 | =================== 3 | 4 | .. automodule:: graphvite.optimizer 5 | :members: 6 | -------------------------------------------------------------------------------- /graphvite/doc/source/api/solver.rst: -------------------------------------------------------------------------------- 1 | graphvite.solver 2 | ================ 3 | 4 | .. automodule:: graphvite.solver 5 | :members: -------------------------------------------------------------------------------- /graphvite/doc/source/index.rst: -------------------------------------------------------------------------------- 1 | .. GraphVite documentation master file, created by 2 | sphinx-quickstart on Wed May 29 18:13:45 2019. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | GraphVite - graph embedding at high speed and large scale 7 | ========================================================= 8 | 9 | .. toctree:: 10 | :maxdepth: 1 11 | :caption: Get Started 12 | 13 | Introduction 14 | install 15 | quick_start 16 | overview 17 | benchmark 18 | pretrained_model 19 | 20 | .. toctree:: 21 | :maxdepth: 1 22 | :caption: User Guide 23 | 24 | user/command_line 25 | user/configuration 26 | user/format 27 | user/python 28 | user/auto 29 | 30 | .. toctree:: 31 | :maxdepth: 1 32 | :caption: Developer Guide 33 | 34 | developer/framework 35 | developer/model 36 | developer/routine 37 | developer/solver 38 | 39 | .. toctree:: 40 | :maxdepth: 1 41 | :caption: Package Reference 42 | 43 | Application 44 | Graph 45 | Solver 46 | Optimizer 47 | Dataset 48 | 49 | .. 
toctree:: 50 | :maxdepth: 1 51 | :caption: FAQ 52 | 53 | FAQ 54 | 55 | Indices and tables 56 | ================== 57 | 58 | * :ref:`genindex` 59 | * :ref:`search` -------------------------------------------------------------------------------- /graphvite/doc/source/user/auto.rst: -------------------------------------------------------------------------------- 1 | Magic of Auto 2 | ============= 3 | 4 | Hyperparameter tuning is usually painful for machine learning practioners. In order 5 | to help users focus on the most important part, GraphVite provides an auto deduction 6 | for many hyperparameters. Generally, auto deduction will maximize the speed of the 7 | system, while keep the performance loss as small as possible. 8 | 9 | To invoke auto deduction, we can simply leave hyperparameters to their default 10 | values. An explicit way is to use ``auto`` in configuration files, or value 11 | ``gv.auto`` in Python. 12 | 13 | Here lists hyperparameters that support auto deduction. 14 | 15 | .. code-block:: yaml 16 | 17 | resource: 18 | gpus: [] 19 | gpu_memory_limit: auto 20 | cpu_per_gpu: auto 21 | 22 | build: 23 | optimizer: auto 24 | num_partition: auto 25 | episode_size: auto 26 | 27 | train: 28 | # for node embedding 29 | augmentation_step: auto 30 | 31 | .. note:: 32 | The auto value for ``gpus`` is an empty list. -------------------------------------------------------------------------------- /graphvite/external/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /graphvite/include/util/common.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2019 MilaGraph. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | * @author Zhaocheng Zhu 17 | */ 18 | 19 | #pragma once 20 | 21 | #include "io.h" 22 | #include "math.h" 23 | 24 | namespace graphvite { 25 | 26 | #define DEPRECATED(reason) __attribute__ ((deprecated(reason))) 27 | 28 | const float kEpsilon = 1e-15; 29 | const int kAuto = 0; 30 | const size_t kMaxLineLength = 1 << 22; 31 | 32 | constexpr size_t KiB(size_t x) { 33 | return x << 10; 34 | } 35 | 36 | constexpr size_t MiB(size_t x) { 37 | return x << 20; 38 | } 39 | 40 | constexpr size_t GiB(size_t x) { 41 | return x << 30; 42 | } 43 | 44 | } // namespace graphvite -------------------------------------------------------------------------------- /graphvite/include/util/debug.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2019 MilaGraph. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * 16 | * @author Zhaocheng Zhu 17 | */ 18 | 19 | #pragma once 20 | 21 | #include 22 | #include 23 | #include 24 | 25 | namespace graphvite { 26 | 27 | #define CUDA_CHECK(error) CudaCheck((error), __FILE__, __LINE__) 28 | #define CURAND_CHECK(error) CurandCheck((error), __FILE__, __LINE__) 29 | 30 | inline void CudaCheck(cudaError_t error, const char *file_name, int line) { 31 | CHECK(error == cudaSuccess) 32 | << "CUDA error " << cudaGetErrorString(error) << " at " << file_name << ":" << line; 33 | } 34 | 35 | inline void CurandCheck(curandStatus_t error, const char *file_name, int line) { 36 | CHECK(error == CURAND_STATUS_SUCCESS) 37 | << "CURAND error " << error << " at " << file_name << ":" << line; 38 | } 39 | 40 | } // namespace graphvite -------------------------------------------------------------------------------- /graphvite/python/graphvite/application/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 MilaGraph. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # Author: Zhaocheng Zhu 16 | 17 | """Application module of GraphVite""" 18 | from __future__ import absolute_import 19 | 20 | from .application import Application, \ 21 | GraphApplication, WordGraphApplication, KnowledgeGraphApplication, VisualizationApplication 22 | 23 | __all__ = [ 24 | "Application", 25 | "GraphApplication", "WordGraphApplication", "KnowledgeGraphApplication", "VisualizationApplication" 26 | ] -------------------------------------------------------------------------------- /graphvite/python/graphvite/graph.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 MilaGraph. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Author: Zhaocheng Zhu 16 | 17 | """Graph module of GraphVite""" 18 | from __future__ import absolute_import 19 | 20 | import sys 21 | 22 | from . 
import lib, cfg 23 | from .helper import find_all_templates, make_helper_class 24 | 25 | module = sys.modules[__name__] 26 | 27 | for name in find_all_templates(lib.graph): 28 | module.__dict__[name] = make_helper_class(lib.graph, name, module, 29 | ["index_type"], [cfg.index_type]) 30 | 31 | __all__ = [ 32 | "Graph", "WordGraph", "KnowledgeGraph", "KNNGraph" 33 | ] -------------------------------------------------------------------------------- /graphvite/python/graphvite/solver.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 MilaGraph. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Author: Zhaocheng Zhu 16 | 17 | """Solver module of GraphVite""" 18 | from __future__ import absolute_import 19 | 20 | import sys 21 | 22 | from . 
import lib, cfg 23 | from .helper import find_all_templates, make_helper_class 24 | 25 | module = sys.modules[__name__] 26 | 27 | for name in find_all_templates(lib.solver): 28 | module.__dict__[name] = make_helper_class(lib.solver, name, module, 29 | ["dim", "float_type", "index_type"], 30 | [None, cfg.float_type, cfg.index_type]) 31 | 32 | __all__ = [ 33 | "GraphSolver", "KnowledgeGraphSolver", "VisualizationSolver" 34 | ] -------------------------------------------------------------------------------- /graphvite/src/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | if (WIN32) 2 | add_library(graphvite graphvite.cu) 3 | else () 4 | add_library(graphvite SHARED graphvite.cu) 5 | set_target_properties(graphvite PROPERTIES 6 | CXX_VISIBILITY_PRESET "hidden" 7 | CUDA_VISIBILITY_PRESET "hidden" 8 | LINK_FLAGS "-flto -Wl,-rpath=$ORIGIN" 9 | OUTPUT_NAME graphvite) 10 | 11 | target_link_libraries(graphvite pthread curand glog.so) 12 | target_compile_options(graphvite PRIVATE "-Xcompiler=-fno-fat-lto-objects") # -flto 13 | endif () 14 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import functools 7 | 8 | from fairseq.hub_utils import BPEHubInterface as bpe # noqa 9 | from fairseq.hub_utils import TokenizerHubInterface as tokenizer # noqa 10 | from fairseq.models import MODEL_REGISTRY 11 | 12 | 13 | dependencies = [ 14 | 'regex', 15 | 'requests', 16 | 'torch', 17 | ] 18 | 19 | 20 | for _model_type, _cls in MODEL_REGISTRY.items(): 21 | for model_name in _cls.hub_models().keys(): 22 | globals()[model_name] = functools.partial( 23 | _cls.from_pretrained, 24 | model_name, 25 | ) 26 | # to simplify the interface we only expose named models 27 | # globals()[_model_type] = _cls.from_pretrained 28 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/compare_namespaces.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Helper script to compare two argparse.Namespace objects.""" 3 | 4 | from argparse import Namespace # noqa 5 | 6 | 7 | def main(): 8 | 9 | ns1 = eval(input('Namespace 1: ')) 10 | ns2 = eval(input('Namespace 2: ')) 11 | 12 | def keys(ns): 13 | ks = set() 14 | for k in dir(ns): 15 | if not k.startswith('_'): 16 | ks.add(k) 17 | return ks 18 | 19 | k1 = keys(ns1) 20 | k2 = keys(ns2) 21 | 22 | def print_keys(ks, ns1, ns2=None): 23 | for k in ks: 24 | if ns2 is None: 25 | print('{}\t{}'.format(k, getattr(ns1, k, None))) 26 | else: 27 | print('{}\t{}\t{}'.format(k, getattr(ns1, k, None), getattr(ns2, k, None))) 28 | 29 | print('Keys unique to namespace 1:') 30 | print_keys(k1 - k2, ns1) 31 | print() 32 | 33 | print('Keys unique to namespace 2:') 34 | print_keys(k2 - k1, ns2) 35 | print() 36 | 37 | print('Overlapping keys with different 
values:') 38 | ks = [k for k in k1 & k2 if getattr(ns1, k, 'None') != getattr(ns2, k, 'None')] 39 | print_keys(ks, ns1, ns2) 40 | print() 41 | 42 | 43 | if __name__ == '__main__': 44 | main() 45 | -------------------------------------------------------------------------------- /scripts/compound_split_bleu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ $# -ne 1 ]; then 4 | echo "usage: $0 GENERATE_PY_OUTPUT" 5 | exit 1 6 | fi 7 | 8 | GEN=$1 9 | 10 | SYS=$GEN.sys 11 | REF=$GEN.ref 12 | 13 | if [ $(tail -n 1 $GEN | grep BLEU | wc -l) -ne 1 ]; then 14 | echo "not done generating" 15 | exit 16 | fi 17 | 18 | grep ^H $GEN | cut -f3- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $SYS 19 | grep ^T $GEN | cut -f2- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $REF 20 | fairseq-score --sys $SYS --ref $REF 21 | -------------------------------------------------------------------------------- /scripts/convert_dictionary.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) Facebook, Inc. and its affiliates. 2 | -- 3 | -- This source code is licensed under the MIT license found in the 4 | -- LICENSE file in the root directory of this source tree. 5 | -- 6 | -- Usage: convert_dictionary.lua 7 | require 'fairseq' 8 | require 'torch' 9 | require 'paths' 10 | 11 | if #arg < 1 then 12 | print('usage: convert_dictionary.lua ') 13 | os.exit(1) 14 | end 15 | if not paths.filep(arg[1]) then 16 | print('error: file does not exit: ' .. 
arg[1]) 17 | os.exit(1) 18 | end 19 | 20 | dict = torch.load(arg[1]) 21 | dst = paths.basename(arg[1]):gsub('.th7', '.txt') 22 | assert(dst:match('.txt$')) 23 | 24 | f = io.open(dst, 'w') 25 | for idx, symbol in ipairs(dict.index_to_symbol) do 26 | if idx > dict.cutoff then 27 | break 28 | end 29 | f:write(symbol) 30 | f:write(' ') 31 | f:write(dict.index_to_freq[idx]) 32 | f:write('\n') 33 | end 34 | f:close() 35 | -------------------------------------------------------------------------------- /scripts/read_binarized.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | 9 | from fairseq.data import data_utils, Dictionary, indexed_dataset 10 | 11 | 12 | def get_parser(): 13 | parser = argparse.ArgumentParser( 14 | description='writes text from binarized file to stdout') 15 | # fmt: off 16 | parser.add_argument('--dataset-impl', help='dataset implementation', 17 | choices=indexed_dataset.get_available_dataset_impl()) 18 | parser.add_argument('--dict', metavar='FP', help='dictionary containing known words', default=None) 19 | parser.add_argument('--input', metavar='FP', required=True, help='binarized file to read') 20 | # fmt: on 21 | 22 | return parser 23 | 24 | 25 | def main(): 26 | parser = get_parser() 27 | args = parser.parse_args() 28 | 29 | dictionary = Dictionary.load(args.dict) if args.dict is not None else None 30 | dataset = data_utils.load_indexed_dataset( 31 | args.input, 32 | dictionary, 33 | dataset_impl=args.dataset_impl, 34 | default='lazy', 35 | ) 36 | 37 | for tensor_line in dataset: 38 | if dictionary is None: 39 | line = ' '.join([str(int(x)) for x in tensor_line]) 40 | else: 41 | line = dictionary.string(tensor_line) 42 | 43 | print(line) 44 | 45 | 46 | if __name__ == 
'__main__': 47 | main() 48 | -------------------------------------------------------------------------------- /scripts/sacrebleu_pregen.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ $# -ne 4 ]; then 4 | echo "usage: $0 TESTSET SRCLANG TGTLANG GEN" 5 | exit 1 6 | fi 7 | 8 | TESTSET=$1 9 | SRCLANG=$2 10 | TGTLANG=$3 11 | 12 | GEN=$4 13 | 14 | echo 'Cloning Moses github repository (for tokenization scripts)...' 15 | git clone https://github.com/moses-smt/mosesdecoder.git 16 | 17 | SCRIPTS=mosesdecoder/scripts 18 | DETOKENIZER=$SCRIPTS/tokenizer/detokenizer.perl 19 | 20 | grep ^H $GEN \ 21 | | sed 's/^H\-//' \ 22 | | sort -n -k 1 \ 23 | | cut -f 3 \ 24 | | perl $DETOKENIZER -l $TGTLANG \ 25 | | sed "s/ - /-/g" \ 26 | > $GEN.sorted.detok 27 | 28 | sacrebleu --test-set $TESTSET --language-pair "${SRCLANG}-${TGTLANG}" < $GEN.sorted.detok 29 | -------------------------------------------------------------------------------- /scripts/spm_decode.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | from __future__ import absolute_import, division, print_function, unicode_literals 9 | 10 | import argparse 11 | 12 | import sentencepiece as spm 13 | 14 | 15 | def main(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--model", required=True, 18 | help="sentencepiece model to use for decoding") 19 | parser.add_argument("--input", required=True, help="input file to decode") 20 | parser.add_argument("--input_format", choices=["piece", "id"], default="piece") 21 | args = parser.parse_args() 22 | 23 | sp = spm.SentencePieceProcessor() 24 | sp.Load(args.model) 25 | 26 | if args.input_format == "piece": 27 | def decode(l): 28 | return "".join(sp.DecodePieces(l)) 29 | elif args.input_format == "id": 30 | def decode(l): 31 | return "".join(sp.DecodeIds(l)) 32 | else: 33 | raise NotImplementedError 34 | 35 | def tok2int(tok): 36 | # remap reference-side (represented as <>) to 0 37 | return int(tok) if tok != "<>" else 0 38 | 39 | with open(args.input, "r", encoding="utf-8") as h: 40 | for line in h: 41 | print(decode(list(map(tok2int, line.rstrip().split())))) 42 | 43 | 44 | if __name__ == "__main__": 45 | main() 46 | -------------------------------------------------------------------------------- /scripts/spm_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | from __future__ import absolute_import, division, print_function, unicode_literals 9 | 10 | import sys 11 | 12 | import sentencepiece as spm 13 | 14 | 15 | if __name__ == "__main__": 16 | spm.SentencePieceTrainer.Train(" ".join(sys.argv[1:])) 17 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/tests/__init__.py -------------------------------------------------------------------------------- /tests/speech_recognition/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-KEG/KEPLER/05304cc07cc4a904006ffe709688945d29725aac/tests/speech_recognition/__init__.py -------------------------------------------------------------------------------- /tests/speech_recognition/test_cross_entropy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from examples.speech_recognition.criterions.cross_entropy_acc import CrossEntropyWithAccCriterion 8 | from .asr_test_base import CrossEntropyCriterionTestBase 9 | 10 | 11 | class CrossEntropyWithAccCriterionTest(CrossEntropyCriterionTestBase): 12 | def setUp(self): 13 | self.criterion_cls = CrossEntropyWithAccCriterion 14 | super().setUp() 15 | 16 | def test_cross_entropy_all_correct(self): 17 | sample = self.get_test_sample(correct=True, soft_target=False, aggregate=False) 18 | loss, sample_size, logging_output = self.criterion( 19 | self.model, sample, "sum", log_probs=True 20 | ) 21 | assert logging_output["correct"] == 20 22 | assert logging_output["total"] == 20 23 | assert logging_output["sample_size"] == 20 24 | assert logging_output["ntokens"] == 20 25 | 26 | def test_cross_entropy_all_wrong(self): 27 | sample = self.get_test_sample(correct=False, soft_target=False, aggregate=False) 28 | loss, sample_size, logging_output = self.criterion( 29 | self.model, sample, "sum", log_probs=True 30 | ) 31 | assert logging_output["correct"] == 0 32 | assert logging_output["total"] == 20 33 | assert logging_output["sample_size"] == 20 34 | assert logging_output["ntokens"] == 20 35 | -------------------------------------------------------------------------------- /tests/test_iterators.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import unittest 7 | 8 | from fairseq.data import iterators 9 | 10 | 11 | class TestIterators(unittest.TestCase): 12 | 13 | def test_counting_iterator(self): 14 | x = list(range(10)) 15 | itr = iterators.CountingIterator(x) 16 | self.assertTrue(itr.has_next()) 17 | self.assertEqual(next(itr), 0) 18 | self.assertEqual(next(itr), 1) 19 | itr.skip(3) 20 | self.assertEqual(next(itr), 5) 21 | itr.skip(3) 22 | self.assertEqual(next(itr), 9) 23 | self.assertFalse(itr.has_next()) 24 | 25 | 26 | if __name__ == '__main__': 27 | unittest.main() 28 | --------------------------------------------------------------------------------