├── .gitattributes ├── .gitignore ├── README.md ├── annotator ├── .gitattributes ├── .gitignore ├── README.md ├── annotate.py ├── cosine_sim.py ├── data │ ├── ami │ │ ├── test.json │ │ ├── train.json │ │ └── valid.json │ └── samsum │ │ ├── test.json │ │ ├── train.json │ │ └── valid.json ├── get_loss.py ├── get_representation_ami.py ├── get_representation_samsum.py ├── recover_word_loss.py ├── requirements.txt └── utils.py ├── bart ├── .gitattributes ├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── config │ ├── config.yaml │ ├── config_eval_lm.yaml │ ├── criterion │ │ ├── adaptive_loss.yaml │ │ └── cross_entropy.yaml │ ├── lr_scheduler │ │ ├── cosine.yaml │ │ └── inverse_sqrt.yaml │ ├── model │ │ ├── transformer_lm.yaml │ │ ├── transformer_lm_baevski_gbw.yaml │ │ ├── transformer_lm_baevski_wiki103.yaml │ │ ├── transformer_lm_big.yaml │ │ ├── transformer_lm_gbw.yaml │ │ ├── transformer_lm_gpt.yaml │ │ ├── transformer_lm_gpt2_big.yaml │ │ ├── transformer_lm_gpt2_medium.yaml │ │ ├── transformer_lm_gpt2_small.yaml │ │ └── transformer_lm_wiki103.yaml │ ├── optimizer │ │ ├── adam.yaml │ │ └── nag.yaml │ ├── params │ │ ├── eval_lm_params.yaml │ │ └── training_params.yaml │ └── task │ │ └── language_modeling.yaml ├── data │ └── .gitignore ├── examples │ ├── .gitignore │ ├── __init__.py │ ├── backtranslation │ │ ├── README.md │ │ ├── deduplicate_lines.py │ │ ├── extract_bt_data.py │ │ ├── prepare-de-monolingual.sh │ │ ├── prepare-wmt18en2de.sh │ │ ├── sacrebleu.sh │ │ └── tokenized_bleu.sh │ ├── bart │ │ ├── README.glue.md │ │ ├── README.md │ │ └── README.summarization.md │ ├── byte_level_bpe │ │ ├── README.md │ │ ├── get_bitext.py │ │ ├── get_data.sh │ │ └── gru_transformer.py │ ├── camembert │ │ └── README.md │ ├── constrained_decoding │ │ ├── README.md │ │ ├── normalize.py │ │ └── tok.py │ ├── conv_seq2seq │ │ └── README.md │ ├── criss │ │ ├── README.md │ │ ├── download_and_preprocess_flores_test.sh │ │ ├── download_and_preprocess_tatoeba.sh │ │ ├── mining │ │ 
│ ├── mine.py │ │ │ └── mine_example.sh │ │ ├── save_encoder.py │ │ ├── sentence_retrieval │ │ │ ├── encoder_analysis.py │ │ │ └── sentence_retrieval_tatoeba.sh │ │ └── unsupervised_mt │ │ │ └── eval.sh │ ├── cross_lingual_language_model │ │ └── README.md │ ├── joint_alignment_translation │ │ ├── README.md │ │ └── prepare-wmt18en2de_no_norm_no_escape_no_agressive.sh │ ├── language_model │ │ ├── README.adaptive_inputs.md │ │ ├── README.conv.md │ │ ├── README.md │ │ └── prepare-wikitext-103.sh │ ├── latent_depth │ │ ├── README.md │ │ └── src │ │ │ ├── __init__.py │ │ │ ├── loss │ │ │ ├── __init__.py │ │ │ └── latent_depth.py │ │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── latent_multilingual_transformer.py │ │ │ └── latent_transformer.py │ │ │ ├── modules │ │ │ ├── __init__.py │ │ │ └── latent_layers.py │ │ │ └── multilingual_translation_latent_depth.py │ ├── layerdrop │ │ └── README.md │ ├── linformer │ │ ├── README.md │ │ └── src │ │ │ ├── __init__.py │ │ │ ├── models │ │ │ ├── __init__.py │ │ │ └── linformer_roberta.py │ │ │ └── modules │ │ │ ├── __init__.py │ │ │ ├── linformer_sentence_encoder.py │ │ │ ├── linformer_sentence_encoder_layer.py │ │ │ └── multihead_linear_attention.py │ ├── m2m_100 │ │ ├── README.md │ │ ├── install_dependecies.sh │ │ ├── process_data │ │ │ ├── clean_histogram.py │ │ │ ├── dedup_data.py │ │ │ └── remove_too_much_punc.py │ │ ├── tok.sh │ │ └── tokenizers │ │ │ ├── README.md │ │ │ ├── seg_ja.sh │ │ │ ├── seg_ko.sh │ │ │ ├── thirdparty │ │ │ └── .gitignore │ │ │ ├── tokenize_indic.py │ │ │ ├── tokenize_thai.py │ │ │ ├── tokenize_zh.py │ │ │ └── tokenizer_ar.sh │ ├── mbart │ │ └── README.md │ ├── megatron_11b │ │ ├── README.md │ │ └── detok.py │ ├── multilingual │ │ ├── README.md │ │ ├── finetune_multilingual_model.sh │ │ ├── multilingual_fairseq_gen.sh │ │ └── train_multilingual_model.sh │ ├── noisychannel │ │ ├── README.md │ │ ├── __init__.py │ │ ├── rerank.py │ │ ├── rerank_generate.py │ │ ├── rerank_options.py │ │ ├── 
rerank_score_bw.py │ │ ├── rerank_score_lm.py │ │ ├── rerank_tune.py │ │ └── rerank_utils.py │ ├── nonautoregressive_translation │ │ ├── README.md │ │ └── scripts.md │ ├── paraphraser │ │ ├── README.md │ │ └── paraphrase.py │ ├── pay_less_attention_paper │ │ └── README.md │ ├── pointer_generator │ │ ├── README.md │ │ ├── README.xsum.md │ │ ├── postprocess.py │ │ ├── preprocess.py │ │ └── src │ │ │ ├── __init__.py │ │ │ └── transformer_pg.py │ ├── quant_noise │ │ ├── README.md │ │ └── transformer_quantization_config.yaml │ ├── roberta │ │ ├── README.custom_classification.md │ │ ├── README.glue.md │ │ ├── README.md │ │ ├── README.pretraining.md │ │ ├── README.race.md │ │ ├── commonsense_qa │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── commonsense_qa_task.py │ │ │ └── download_cqa_data.sh │ │ ├── multiprocessing_bpe_encoder.py │ │ ├── preprocess_GLUE_tasks.sh │ │ ├── preprocess_RACE.py │ │ ├── preprocess_RACE.sh │ │ └── wsc │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── wsc_criterion.py │ │ │ ├── wsc_task.py │ │ │ └── wsc_utils.py │ ├── rxf │ │ ├── README.md │ │ ├── __init__.py │ │ └── src │ │ │ ├── __init__.py │ │ │ ├── label_smoothed_cross_entropy_r3f.py │ │ │ └── sentence_prediction_r3f.py │ ├── scaling_nmt │ │ └── README.md │ ├── simultaneous_translation │ │ ├── README.md │ │ ├── __init__.py │ │ ├── criterions │ │ │ ├── __init__.py │ │ │ └── label_smoothed_cross_entropy_latency_augmented.py │ │ ├── docs │ │ │ ├── baseline.md │ │ │ └── evaluation.md │ │ ├── eval │ │ │ ├── __init__.py │ │ │ ├── agents │ │ │ │ ├── __init__.py │ │ │ │ ├── agent.py │ │ │ │ ├── simul_trans_agent.py │ │ │ │ ├── simul_trans_text_agent.py │ │ │ │ └── word_splitter.py │ │ │ ├── client.py │ │ │ ├── eval_latency.py │ │ │ ├── evaluate.py │ │ │ ├── scorers │ │ │ │ ├── __init__.py │ │ │ │ ├── scorer.py │ │ │ │ └── text_scorer.py │ │ │ └── server.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ └── transformer_monotonic_attention.py │ │ ├── modules │ │ │ ├── __init__.py │ │ │ ├── 
monotonic_multihead_attention.py │ │ │ └── monotonic_transformer_layer.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── functions.py │ │ │ └── latency.py │ ├── speech_recognition │ │ ├── README.md │ │ ├── __init__.py │ │ ├── criterions │ │ │ ├── ASG_loss.py │ │ │ ├── __init__.py │ │ │ └── cross_entropy_acc.py │ │ ├── data │ │ │ ├── __init__.py │ │ │ ├── asr_dataset.py │ │ │ ├── collaters.py │ │ │ ├── data_utils.py │ │ │ └── replabels.py │ │ ├── datasets │ │ │ ├── asr_prep_json.py │ │ │ └── prepare-librispeech.sh │ │ ├── infer.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── vggtransformer.py │ │ │ └── w2l_conv_glu_enc.py │ │ ├── tasks │ │ │ ├── __init__.py │ │ │ └── speech_recognition.py │ │ ├── utils │ │ │ └── wer_utils.py │ │ └── w2l_decoder.py │ ├── speech_to_text │ │ ├── README.md │ │ ├── data_utils.py │ │ ├── prep_covost_data.py │ │ ├── prep_librispeech_data.py │ │ └── prep_mustc_data.py │ ├── stories │ │ └── README.md │ ├── translation │ │ ├── README.md │ │ ├── prepare-iwslt14.sh │ │ ├── prepare-iwslt17-multilingual.sh │ │ ├── prepare-wmt14en2de.sh │ │ └── prepare-wmt14en2fr.sh │ ├── translation_moe │ │ ├── README.md │ │ ├── score.py │ │ └── src │ │ │ ├── __init__.py │ │ │ ├── logsumexp_moe.py │ │ │ ├── mean_pool_gating_network.py │ │ │ └── translation_moe.py │ ├── unsupervised_quality_estimation │ │ ├── README.md │ │ ├── aggregate_scores.py │ │ ├── meteor.py │ │ └── repeat_lines.py │ ├── wav2vec │ │ ├── README.md │ │ ├── libri_labels.py │ │ ├── vq-wav2vec_featurize.py │ │ ├── wav2vec_featurize.py │ │ └── wav2vec_manifest.py │ ├── wmt19 │ │ └── README.md │ └── xlmr │ │ └── README.md ├── fairseq │ ├── __init__.py │ ├── __init__.pyc │ ├── benchmark │ │ ├── __init__.py │ │ ├── dummy_lm.py │ │ ├── dummy_masked_lm.py │ │ ├── dummy_model.py │ │ └── dummy_mt.py │ ├── binarizer.py │ ├── checkpoint_utils.py │ ├── clib │ │ ├── libbleu │ │ │ ├── libbleu.cpp │ │ │ └── module.cpp │ │ ├── libnat │ │ │ └── edit_dist.cpp │ │ └── libnat_cuda │ │ │ ├── binding.cpp │ │ │ ├── 
edit_dist.cu │ │ │ └── edit_dist.h │ ├── criterions │ │ ├── __init__.py │ │ ├── adaptive_loss.py │ │ ├── composite_loss.py │ │ ├── cross_entropy.py │ │ ├── ctc.py │ │ ├── fairseq_criterion.py │ │ ├── label_smoothed_cross_entropy.py │ │ ├── label_smoothed_cross_entropy_with_alignment.py │ │ ├── legacy_masked_lm.py │ │ ├── masked_lm.py │ │ ├── nat_loss.py │ │ ├── sentence_prediction.py │ │ ├── sentence_ranking.py │ │ └── wav2vec_criterion.py │ ├── data │ │ ├── __init__.py │ │ ├── add_target_dataset.py │ │ ├── append_token_dataset.py │ │ ├── audio │ │ │ ├── __init__.py │ │ │ └── raw_audio_dataset.py │ │ ├── backtranslation_dataset.py │ │ ├── base_wrapper_dataset.py │ │ ├── bucket_pad_length_dataset.py │ │ ├── colorize_dataset.py │ │ ├── concat_dataset.py │ │ ├── concat_sentences_dataset.py │ │ ├── data_utils.py │ │ ├── data_utils_fast.cpp │ │ ├── data_utils_fast.cpython-36m-darwin.so │ │ ├── data_utils_fast.pyx │ │ ├── denoising_dataset.py │ │ ├── dictionary.py │ │ ├── encoders │ │ │ ├── __init__.py │ │ │ ├── byte_bpe.py │ │ │ ├── byte_utils.py │ │ │ ├── bytes.py │ │ │ ├── characters.py │ │ │ ├── fastbpe.py │ │ │ ├── gpt2_bpe.py │ │ │ ├── gpt2_bpe_utils.py │ │ │ ├── hf_bert_bpe.py │ │ │ ├── hf_byte_bpe.py │ │ │ ├── moses_tokenizer.py │ │ │ ├── nltk_tokenizer.py │ │ │ ├── sentencepiece_bpe.py │ │ │ ├── space_tokenizer.py │ │ │ ├── subword_nmt_bpe.py │ │ │ └── utils.py │ │ ├── fairseq_dataset.py │ │ ├── fasta_dataset.py │ │ ├── id_dataset.py │ │ ├── indexed_dataset.py │ │ ├── iterators.py │ │ ├── language_pair_dataset.py │ │ ├── legacy │ │ │ ├── __init__.py │ │ │ ├── block_pair_dataset.py │ │ │ ├── masked_lm_dataset.py │ │ │ └── masked_lm_dictionary.py │ │ ├── list_dataset.py │ │ ├── lm_context_window_dataset.py │ │ ├── lru_cache_dataset.py │ │ ├── mask_tokens_dataset.py │ │ ├── monolingual_dataset.py │ │ ├── multi_corpus_dataset.py │ │ ├── multi_corpus_sampled_dataset.py │ │ ├── multilingual │ │ │ ├── __init__.py │ │ │ ├── multilingual_data_manager.py │ │ │ ├── 
multilingual_utils.py │ │ │ ├── sampled_multi_dataset.py │ │ │ ├── sampled_multi_epoch_dataset.py │ │ │ └── sampling_method.py │ │ ├── nested_dictionary_dataset.py │ │ ├── noising.py │ │ ├── num_samples_dataset.py │ │ ├── numel_dataset.py │ │ ├── offset_tokens_dataset.py │ │ ├── pad_dataset.py │ │ ├── plasma_utils.py │ │ ├── prepend_dataset.py │ │ ├── prepend_token_dataset.py │ │ ├── raw_label_dataset.py │ │ ├── replace_dataset.py │ │ ├── resampling_dataset.py │ │ ├── roll_dataset.py │ │ ├── round_robin_zip_datasets.py │ │ ├── shorten_dataset.py │ │ ├── sort_dataset.py │ │ ├── strip_token_dataset.py │ │ ├── subsample_dataset.py │ │ ├── token_block_dataset.py │ │ ├── token_block_utils_fast.cpp │ │ ├── token_block_utils_fast.cpython-36m-darwin.so │ │ ├── token_block_utils_fast.pyx │ │ ├── transform_eos_dataset.py │ │ └── transform_eos_lang_pair_dataset.py │ ├── dataclass │ │ ├── __init__.py │ │ ├── constants.py │ │ ├── data_class.py │ │ └── utils.py │ ├── distributed_utils.py │ ├── file_io.py │ ├── file_utils.py │ ├── hub_utils.py │ ├── incremental_decoding_utils.py │ ├── iterative_refinement_generator.py │ ├── legacy_distributed_data_parallel.py │ ├── libbleu.cpython-36m-darwin.so │ ├── libnat.cpython-36m-darwin.so │ ├── logging │ │ ├── __init__.py │ │ ├── __init__.pyc │ │ ├── meters.py │ │ ├── metrics.py │ │ └── progress_bar.py │ ├── model_parallel │ │ ├── __init__.py │ │ ├── criterions │ │ │ ├── __init__.py │ │ │ └── vocab_parallel_cross_entropy.py │ │ ├── megatron_trainer.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── pipeline_parallel_transformer │ │ │ │ ├── __init__.py │ │ │ │ ├── layers.py │ │ │ │ └── model.py │ │ │ ├── roberta │ │ │ │ ├── __init__.py │ │ │ │ └── model.py │ │ │ ├── transformer.py │ │ │ └── transformer_lm.py │ │ └── modules │ │ │ ├── __init__.py │ │ │ ├── multihead_attention.py │ │ │ ├── transformer_layer.py │ │ │ ├── transformer_sentence_encoder.py │ │ │ └── transformer_sentence_encoder_layer.py │ ├── models │ │ ├── __init__.py │ │ ├── bart 
│ │ │ ├── __init__.py │ │ │ ├── hub_interface.py │ │ │ └── model.py │ │ ├── composite_encoder.py │ │ ├── distributed_fairseq_model.py │ │ ├── fairseq_decoder.py │ │ ├── fairseq_encoder.py │ │ ├── fairseq_incremental_decoder.py │ │ ├── fairseq_model.py │ │ ├── fconv.py │ │ ├── fconv_lm.py │ │ ├── fconv_self_att.py │ │ ├── huggingface │ │ │ ├── __init__.py │ │ │ └── hf_gpt2.py │ │ ├── lightconv.py │ │ ├── lightconv_lm.py │ │ ├── lstm.py │ │ ├── lstm_lm.py │ │ ├── masked_lm.py │ │ ├── model_utils.py │ │ ├── multilingual_transformer.py │ │ ├── nat │ │ │ ├── __init__.py │ │ │ ├── cmlm_transformer.py │ │ │ ├── fairseq_nat_model.py │ │ │ ├── insertion_transformer.py │ │ │ ├── iterative_nonautoregressive_transformer.py │ │ │ ├── levenshtein_transformer.py │ │ │ ├── levenshtein_utils.py │ │ │ ├── nat_crf_transformer.py │ │ │ ├── nonautoregressive_ensembles.py │ │ │ └── nonautoregressive_transformer.py │ │ ├── roberta │ │ │ ├── __init__.py │ │ │ ├── alignment_utils.py │ │ │ ├── hub_interface.py │ │ │ ├── model.py │ │ │ ├── model_camembert.py │ │ │ └── model_xlmr.py │ │ ├── transformer.py │ │ ├── transformer_align.py │ │ ├── transformer_from_pretrained_xlm.py │ │ ├── transformer_lm.py │ │ └── wav2vec │ │ │ ├── __init__.py │ │ │ ├── wav2vec.py │ │ │ ├── wav2vec2.py │ │ │ └── wav2vec2_asr.py │ ├── modules │ │ ├── __init__.py │ │ ├── adaptive_input.py │ │ ├── adaptive_softmax.py │ │ ├── beamable_mm.py │ │ ├── character_token_embedder.py │ │ ├── conv_tbc.py │ │ ├── cross_entropy.py │ │ ├── cuda_utils.cu │ │ ├── downsampled_multihead_attention.py │ │ ├── dynamic_convolution.py │ │ ├── dynamic_crf_layer.py │ │ ├── dynamicconv_layer │ │ │ ├── __init__.py │ │ │ ├── cuda_function_gen.py │ │ │ ├── dynamicconv_cuda.cpp │ │ │ ├── dynamicconv_cuda.cuh │ │ │ ├── dynamicconv_cuda_kernel.cu │ │ │ ├── dynamicconv_layer.py │ │ │ ├── dynamiconv_cpu.cpp │ │ │ └── setup.py │ │ ├── fairseq_dropout.py │ │ ├── fp32_group_norm.py │ │ ├── gelu.py │ │ ├── grad_multiply.py │ │ ├── 
gumbel_vector_quantizer.py │ │ ├── kmeans_vector_quantizer.py │ │ ├── layer_drop.py │ │ ├── layer_norm.py │ │ ├── learned_positional_embedding.py │ │ ├── lightconv_layer │ │ │ ├── __init__.py │ │ │ ├── cuda_function_gen.py │ │ │ ├── lightconv_cuda.cpp │ │ │ ├── lightconv_cuda.cuh │ │ │ ├── lightconv_cuda_kernel.cu │ │ │ ├── lightconv_layer.py │ │ │ └── setup.py │ │ ├── lightweight_convolution.py │ │ ├── linearized_convolution.py │ │ ├── multihead_attention.py │ │ ├── positional_embedding.py │ │ ├── quant_noise.py │ │ ├── quantization │ │ │ ├── __init__.py │ │ │ ├── pq │ │ │ │ ├── __init__.py │ │ │ │ ├── em.py │ │ │ │ ├── modules │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── qconv.py │ │ │ │ │ ├── qemb.py │ │ │ │ │ └── qlinear.py │ │ │ │ ├── pq.py │ │ │ │ └── utils.py │ │ │ ├── quantization_options.py │ │ │ └── scalar │ │ │ │ ├── __init__.py │ │ │ │ ├── modules │ │ │ │ ├── __init__.py │ │ │ │ ├── qact.py │ │ │ │ ├── qconv.py │ │ │ │ ├── qemb.py │ │ │ │ └── qlinear.py │ │ │ │ ├── ops.py │ │ │ │ └── utils.py │ │ ├── same_pad.py │ │ ├── scalar_bias.py │ │ ├── sinusoidal_positional_embedding.py │ │ ├── sparse_multihead_attention.py │ │ ├── sparse_transformer_sentence_encoder.py │ │ ├── sparse_transformer_sentence_encoder_layer.py │ │ ├── transformer_layer.py │ │ ├── transformer_sentence_encoder.py │ │ ├── transformer_sentence_encoder_layer.py │ │ ├── transpose_last.py │ │ ├── unfold.py │ │ └── vggblock.py │ ├── nan_detector.py │ ├── optim │ │ ├── __init__.py │ │ ├── adadelta.py │ │ ├── adafactor.py │ │ ├── adagrad.py │ │ ├── adam.py │ │ ├── adamax.py │ │ ├── bmuf.py │ │ ├── dynamic_loss_scaler.py │ │ ├── fairseq_optimizer.py │ │ ├── fp16_optimizer.py │ │ ├── fused_adam.py │ │ ├── fused_lamb.py │ │ ├── lr_scheduler │ │ │ ├── __init__.py │ │ │ ├── cosine_lr_scheduler.py │ │ │ ├── fairseq_lr_scheduler.py │ │ │ ├── fixed_schedule.py │ │ │ ├── inverse_square_root_schedule.py │ │ │ ├── polynomial_decay_schedule.py │ │ │ ├── reduce_lr_on_plateau.py │ │ │ ├── 
tri_stage_lr_scheduler.py │ │ │ └── triangular_lr_scheduler.py │ │ ├── nag.py │ │ ├── sgd.py │ │ └── shard.py │ ├── options.py │ ├── pdb.py │ ├── quantization_utils.py │ ├── registry.py │ ├── scoring │ │ ├── __init__.py │ │ ├── bleu.py │ │ └── wer.py │ ├── search.py │ ├── sequence_generator.py │ ├── sequence_scorer.py │ ├── tasks │ │ ├── __init__.py │ │ ├── audio_pretraining.py │ │ ├── cross_lingual_lm.py │ │ ├── denoising.py │ │ ├── fairseq_task.py │ │ ├── language_modeling.py │ │ ├── legacy_masked_lm.py │ │ ├── masked_lm.py │ │ ├── multilingual_denoising.py │ │ ├── multilingual_masked_lm.py │ │ ├── multilingual_translation.py │ │ ├── semisupervised_translation.py │ │ ├── sentence_prediction.py │ │ ├── sentence_ranking.py │ │ ├── translation.py │ │ ├── translation_from_pretrained_bart.py │ │ ├── translation_from_pretrained_xlm.py │ │ ├── translation_lev.py │ │ └── translation_multi_simple_epoch.py │ ├── token_generation_constraints.py │ ├── tokenizer.py │ ├── trainer.py │ └── utils.py ├── fairseq_cli │ ├── __init__.py │ ├── eval_lm.py │ ├── generate.py │ ├── inference.py │ ├── interactive.py │ ├── preprocess.py │ ├── score.py │ ├── train.py │ └── validate.py ├── my_scripts │ ├── __init__.py │ ├── binarize.sh │ ├── bpe.sh │ ├── infer.sh │ └── train.sh ├── py_rouge_test.py ├── requirements.txt ├── scripts │ ├── __init__.py │ ├── average_checkpoints.py │ ├── build_sym_alignment.py │ ├── compare_namespaces.py │ ├── compound_split_bleu.sh │ ├── constraints │ │ ├── extract.py │ │ └── validate.py │ ├── convert_dictionary.lua │ ├── convert_model.lua │ ├── count_docs.py │ ├── read_binarized.py │ ├── rm_pt.py │ ├── sacrebleu.sh │ ├── shard_docs.py │ ├── split_train_valid_docs.py │ ├── spm_decode.py │ ├── spm_encode.py │ └── spm_train.py ├── setup.py ├── summaries │ └── samsum.txt └── train.py ├── pgn ├── .gitattributes ├── .gitignore ├── README.md ├── ckpt │ └── .gitignore ├── data │ └── .gitignore ├── embeddings_to_torch.py ├── logs │ └── .gitignore ├── onmt │ ├── 
__init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── model_builder.cpython-37.pyc │ │ ├── opts.cpython-37.pyc │ │ ├── train_single.cpython-37.pyc │ │ └── trainer.cpython-37.pyc │ ├── bin │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── preprocess.cpython-37.pyc │ │ │ ├── train.cpython-37.pyc │ │ │ └── translate.cpython-37.pyc │ │ ├── preprocess.py │ │ ├── server.py │ │ ├── train.py │ │ └── translate.py │ ├── decoders │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── cnn_decoder.cpython-37.pyc │ │ │ ├── decoder.cpython-37.pyc │ │ │ ├── ensemble.cpython-37.pyc │ │ │ └── transformer.cpython-37.pyc │ │ ├── cnn_decoder.py │ │ ├── decoder.py │ │ ├── ensemble.py │ │ └── transformer.py │ ├── encoders │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── audio_encoder.cpython-37.pyc │ │ │ ├── cnn_encoder.cpython-37.pyc │ │ │ ├── encoder.cpython-37.pyc │ │ │ ├── image_encoder.cpython-37.pyc │ │ │ ├── mean_encoder.cpython-37.pyc │ │ │ ├── rnn_encoder.cpython-37.pyc │ │ │ └── transformer.cpython-37.pyc │ │ ├── audio_encoder.py │ │ ├── cnn_encoder.py │ │ ├── encoder.py │ │ ├── image_encoder.py │ │ ├── mean_encoder.py │ │ ├── rnn_encoder.py │ │ └── transformer.py │ ├── inputters │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── audio_dataset.cpython-37.pyc │ │ │ ├── datareader_base.cpython-37.pyc │ │ │ ├── dataset_base.cpython-37.pyc │ │ │ ├── image_dataset.cpython-37.pyc │ │ │ ├── inputter.cpython-37.pyc │ │ │ ├── text_dataset.cpython-37.pyc │ │ │ └── vec_dataset.cpython-37.pyc │ │ ├── audio_dataset.py │ │ ├── datareader_base.py │ │ ├── dataset_base.py │ │ ├── image_dataset.py │ │ ├── inputter.py │ │ ├── text_dataset.py │ │ └── vec_dataset.py │ ├── model_builder.py │ ├── models │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── model.cpython-37.pyc │ │ │ ├── model_saver.cpython-37.pyc │ │ │ ├── 
sru.cpython-37.pyc │ │ │ └── stacked_rnn.cpython-37.pyc │ │ ├── model.py │ │ ├── model_saver.py │ │ ├── sru.py │ │ └── stacked_rnn.py │ ├── modules │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── average_attn.cpython-37.pyc │ │ │ ├── conv_multi_step_attention.cpython-37.pyc │ │ │ ├── copy_generator.cpython-37.pyc │ │ │ ├── embeddings.cpython-37.pyc │ │ │ ├── gate.cpython-37.pyc │ │ │ ├── global_attention.cpython-37.pyc │ │ │ ├── multi_headed_attn.cpython-37.pyc │ │ │ ├── position_ffn.cpython-37.pyc │ │ │ ├── sparse_activations.cpython-37.pyc │ │ │ ├── sparse_losses.cpython-37.pyc │ │ │ ├── util_class.cpython-37.pyc │ │ │ └── weight_norm.cpython-37.pyc │ │ ├── average_attn.py │ │ ├── conv_multi_step_attention.py │ │ ├── copy_generator.py │ │ ├── embeddings.py │ │ ├── gate.py │ │ ├── global_attention.py │ │ ├── multi_headed_attn.py │ │ ├── position_ffn.py │ │ ├── sparse_activations.py │ │ ├── sparse_losses.py │ │ ├── structured_attention.py │ │ ├── util_class.py │ │ └── weight_norm.py │ ├── opts.py │ ├── train_single.py │ ├── trainer.py │ ├── translate │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── beam.cpython-37.pyc │ │ │ ├── beam_search.cpython-37.pyc │ │ │ ├── decode_strategy.cpython-37.pyc │ │ │ ├── penalties.cpython-37.pyc │ │ │ ├── random_sampling.cpython-37.pyc │ │ │ ├── translation.cpython-37.pyc │ │ │ ├── translation_server.cpython-37.pyc │ │ │ └── translator.cpython-37.pyc │ │ ├── beam.py │ │ ├── beam_search.py │ │ ├── decode_strategy.py │ │ ├── penalties.py │ │ ├── process_zh.py │ │ ├── random_sampling.py │ │ ├── translation.py │ │ ├── translation_server.py │ │ └── translator.py │ └── utils │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── cnn_factory.cpython-37.pyc │ │ ├── distributed.cpython-37.pyc │ │ ├── earlystopping.cpython-37.pyc │ │ ├── logging.cpython-37.pyc │ │ ├── loss.cpython-37.pyc │ │ ├── misc.cpython-37.pyc │ │ ├── 
optimizers.cpython-37.pyc │ │ ├── parse.cpython-37.pyc │ │ ├── report_manager.cpython-37.pyc │ │ ├── rnn_factory.cpython-37.pyc │ │ └── statistics.cpython-37.pyc │ │ ├── cnn_factory.py │ │ ├── distributed.py │ │ ├── earlystopping.py │ │ ├── logging.py │ │ ├── loss.py │ │ ├── misc.py │ │ ├── optimizers.py │ │ ├── parse.py │ │ ├── report_manager.py │ │ ├── rnn_factory.py │ │ └── statistics.py ├── preprocess.py ├── requirements.txt ├── scripts │ ├── embedding.sh │ ├── infer.sh │ ├── preprocess.sh │ └── train.sh ├── summaries │ └── ami.txt ├── test_rouge.py ├── train.py └── translate.py └── pic └── main.png /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea -------------------------------------------------------------------------------- /annotator/.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /annotator/README.md: -------------------------------------------------------------------------------- 1 | # DialoGPT Annotator 2 | 3 | This code mainly shows how we annotate a dialogue from AMI and SAMSum using [DialoGPT](https://arxiv.org/abs/1911.00536). 4 | 5 | ## Requirements 6 | * `conda create -n tfs python=3.7` 7 | * `pip install -r requirements.txt` 8 | 9 | ## Get loss 10 | Firstly, run the following command, you will get a dir ***loss/samsum/bpe*** or ***loss/ami/bpe*** that stores three files: ***train_loss.json***, ***valid_loss.json*** and ***test_loss.json***. 
11 | * For SAMSum: `python get_loss.py -d samsum` 12 | * For AMI: `python get_loss.py -d ami` 13 | 14 | Secondly, recover the word-level loss; this produces a directory ***loss/samsum/word*** or ***loss/ami/word***. 15 | > Note that DialoGPT uses BPE to tokenize texts; thus, losses are calculated at the sub-word level. We recover the word-level predicted loss by averaging the losses of multiple sub-words. 16 | * For SAMSum: `python recover_word_loss.py -d samsum` 17 | * For AMI: `python recover_word_loss.py -d ami` 18 | 19 | ## Get dialogue context representation 20 | Run the following commands to get a directory ***rep/samsum*** or ***rep/ami*** that stores three files: ***train_rep.json***, ***valid_rep.json*** and ***test_rep.json***. 21 | * For SAMSum: `python get_representation_samsum.py` 22 | * For AMI: `python get_representation_ami.py` 23 | 24 | ## Calculate cosine similarity 25 | Run the following commands to get a directory ***rep/samsum/sim*** or ***rep/ami/sim*** that stores three files: ***train_sim.json***, ***valid_sim.json*** and ***test_sim.json***. 26 | * For SAMSum: `python cosine_sim.py -d samsum` 27 | * For AMI: `python cosine_sim.py -d ami` 28 | 29 | ## Annotate 30 | Run the following commands to get a directory ***data/samsum/final*** or ***data/ami/final*** that stores the final output files.
31 | * For SAMSum: `python annotate.py -d samsum` 32 | * For AMI: `python annotate.py -d ami` -------------------------------------------------------------------------------- /annotator/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.3.0 2 | numpy==1.19.4 3 | sacremoses==0.0.43 4 | sentencepiece==0.1.91 5 | six==1.15.0 6 | stanza==1.1.1 7 | tokenizers==0.9.4 8 | tqdm==4.55.1 9 | transformers==3.5.1 -------------------------------------------------------------------------------- /bart/.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /bart/.gitignore: -------------------------------------------------------------------------------- 1 | idea/ 2 | __pycache__ -------------------------------------------------------------------------------- /bart/.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "fairseq/models/huggingface/transformers"] 2 | path = fairseq/models/huggingface/transformers 3 | url = https://github.com/myleott/transformers.git 4 | branch = fairseq 5 | [submodule "fairseq/model_parallel/megatron"] 6 | path = fairseq/model_parallel/megatron 7 | url = https://github.com/ngoyal2707/Megatron-LM 8 | branch = fairseq 9 | -------------------------------------------------------------------------------- /bart/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Facebook, Inc. and its affiliates. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /bart/README.md: -------------------------------------------------------------------------------- 1 | # BART for SAMSum Dataset 2 | 3 | This code is based on [Fairseq](https://github.com/pytorch/fairseq). 4 | 5 | ## Requirements 6 | * We use Conda python 3.7 and strongly recommend that you create a new environment. 7 | * `conda create -n bart python=3.7`. 8 | * Run the following command. 9 | * `pip install --editable ./` 10 | * `pip install -r requirements.txt` 11 | 12 | ## Data 13 | You can get data [here](https://drive.google.com/drive/folders/1wLea1LdEv1jFQMtXr3bXJFiKvZnrGQLO?usp=sharing). Put them under the dir **data/\***. 14 | 15 | ## Reproduce Results 16 | You can follow the following steps to reproduce the best results in our paper. 
17 | 18 | ### Download checkpoints 19 | Download checkpoints [here](https://drive.google.com/drive/folders/1Osr3HXUPuGmh6-nCm8eSt_1ISJOaBhxy?usp=sharing). Save the checkpoint as **ckpt/samsum.pt**. 20 | 21 | ### Preprocess 22 | * `sh ./my_scripts/bpe.sh` 23 | * `sh ./my_scripts/binarize.sh` 24 | 25 | ### Translate 26 | * Produce final summaries. 27 | * `sh ./my_scripts/infer.sh` 28 | 29 | ### Test ROUGE score 30 | * `python py_rouge_test.py -c summaries/samsum.txt` 31 | 32 | ### ROUGE score 33 | ||ROUGE-1| ROUGE-2 | ROUGE-L | 34 | | :---: | :---: | :---: | :---: | 35 | | SAMSum | 53.70 | 28.79 | 50.81| 36 | 37 | ## From Scratch 38 | 39 | ### Download BART checkpoint 40 | Download the [bart.large](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz) checkpoint. Put it under the directory **bart/bart.large/\***. 41 | 42 | ### Preprocess 43 | * `sh ./my_scripts/bpe.sh` 44 | * `sh ./my_scripts/binarize.sh` 45 | 46 | ### Train 47 | * `sh ./my_scripts/train.sh` 48 | 49 | ### Translate 50 | Run the following command: 51 | * `sh ./my_scripts/infer.sh` 52 | * Before running it, set the **ckpt_dir** parameter in the script.
53 | -------------------------------------------------------------------------------- /bart/config/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - params: training_params 3 | - task: language_modeling 4 | - model: transformer_lm 5 | - criterion: cross_entropy 6 | - optimizer: adam 7 | - lr_scheduler: inverse_sqrt 8 | -------------------------------------------------------------------------------- /bart/config/config_eval_lm.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - params: eval_lm_params 3 | - task: language_modeling 4 | - model: transformer_lm 5 | - criterion: cross_entropy 6 | - optimizer: adam 7 | - lr_scheduler: inverse_sqrt 8 | -------------------------------------------------------------------------------- /bart/config/criterion/adaptive_loss.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | sentence_avg: ${params.optimization.sentence_avg} 3 | ddp_backend: ${params.distributed_training.ddp_backend} 4 | -------------------------------------------------------------------------------- /bart/config/criterion/cross_entropy.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | sentence_avg: ${params.optimization.sentence_avg} 3 | ddp_backend: ${params.distributed_training.ddp_backend} 4 | -------------------------------------------------------------------------------- /bart/config/lr_scheduler/cosine.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | warmup_updates: 0 3 | warmup_init_lr: -1 4 | max_lr: 1.0 5 | t_mult: 1.0 6 | lr_period_updates: -1 7 | lr_shrink: 0.1 8 | -------------------------------------------------------------------------------- /bart/config/lr_scheduler/inverse_sqrt.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _group_ 2 | warmup_updates: 4000 3 | warmup_init_lr: -1 4 | -------------------------------------------------------------------------------- /bart/config/model/transformer_lm.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | activation_fn: "relu" 3 | dropout: 0.1 4 | attention_dropout: 0.0 5 | activation_dropout: 0.0 6 | relu_dropout: 0.0 7 | decoder_embed_dim: 512 8 | decoder_output_dim: 512 9 | decoder_input_dim: 512 10 | decoder_ffn_embed_dim: 2048 11 | decoder_layers: 6 12 | decoder_attention_heads: 8 13 | decoder_normalize_before: true 14 | no_decoder_final_norm: false 15 | adaptive_softmax_cutoff: null 16 | adaptive_softmax_dropout: 0 17 | adaptive_softmax_factor: 4 18 | no_token_positional_embeddings: false 19 | share_decoder_input_output_embed: false 20 | character_embeddings: false 21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" 22 | character_embedding_dim: 4 23 | char_embedder_highway_layers: 2 24 | adaptive_input: false 25 | adaptive_input_factor: 4 26 | adaptive_input_cutoff: null 27 | tie_adaptive_weights: false 28 | tie_adaptive_proj: false 29 | decoder_learned_pos: false 30 | decoder_layerdrop: 0 31 | decoder_layers_to_keep: null 32 | layernorm_embedding: false 33 | no_scale_embedding: false 34 | quant_noise_pq: 0 35 | quant_noise_pq_block_size: 8 36 | quant_noise_scalar: 0 37 | -------------------------------------------------------------------------------- /bart/config/model/transformer_lm_baevski_gbw.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | activation_fn: "relu" 3 | dropout: 0.1 4 | attention_dropout: 0.1 5 | activation_dropout: 0.0 6 | relu_dropout: 0.0 7 | decoder_embed_dim: 512 8 | decoder_output_dim: 512 9 | decoder_input_dim: 512 10 | decoder_ffn_embed_dim: 4096 11 
| decoder_layers: 12 12 | decoder_attention_heads: 16 13 | decoder_normalize_before: true 14 | no_decoder_final_norm: true 15 | adaptive_softmax_cutoff: null 16 | adaptive_softmax_dropout: 0 17 | adaptive_softmax_factor: 4 18 | no_token_positional_embeddings: false 19 | share_decoder_input_output_embed: false 20 | character_embeddings: false 21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" 22 | character_embedding_dim: 4 23 | char_embedder_highway_layers: 2 24 | adaptive_input: false 25 | adaptive_input_factor: 4 26 | adaptive_input_cutoff: null 27 | tie_adaptive_weights: false 28 | tie_adaptive_proj: false 29 | decoder_learned_pos: false 30 | decoder_layerdrop: 0 31 | decoder_layers_to_keep: null 32 | layernorm_embedding: false 33 | no_scale_embedding: false 34 | quant_noise_pq: 0 35 | quant_noise_pq_block_size: 8 36 | quant_noise_scalar: 0 37 | -------------------------------------------------------------------------------- /bart/config/model/transformer_lm_baevski_wiki103.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | activation_fn: "relu" 3 | dropout: 0.3 4 | attention_dropout: 0.1 5 | activation_dropout: 0.1 6 | relu_dropout: 0.1 7 | decoder_embed_dim: 1024 8 | decoder_output_dim: 1024 9 | decoder_input_dim: 1024 10 | decoder_ffn_embed_dim: 4096 11 | decoder_layers: 16 12 | decoder_attention_heads: 8 13 | decoder_normalize_before: true 14 | no_decoder_final_norm: true 15 | adaptive_softmax_cutoff: "20000,60000" 16 | adaptive_softmax_dropout: 0.2 17 | adaptive_softmax_factor: 4 18 | no_token_positional_embeddings: false 19 | share_decoder_input_output_embed: false 20 | character_embeddings: false 21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" 22 | character_embedding_dim: 4 23 | char_embedder_highway_layers: 2 24 | adaptive_input: true 25 | adaptive_input_factor: 4 26 | adaptive_input_cutoff: 
"20000,60000" 27 | tie_adaptive_weights: true 28 | tie_adaptive_proj: true 29 | decoder_learned_pos: false 30 | decoder_layerdrop: 0 31 | decoder_layers_to_keep: null 32 | layernorm_embedding: false 33 | no_scale_embedding: false 34 | quant_noise_pq: 0 35 | quant_noise_pq_block_size: 8 36 | quant_noise_scalar: 0 37 | -------------------------------------------------------------------------------- /bart/config/model/transformer_lm_big.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | activation_fn: "relu" 3 | dropout: 0.1 4 | attention_dropout: 0.0 5 | activation_dropout: 0.0 6 | relu_dropout: 0.0 7 | decoder_embed_dim: 1024 8 | decoder_output_dim: 1024 9 | decoder_input_dim: 1024 10 | decoder_ffn_embed_dim: 4096 11 | decoder_layers: 12 12 | decoder_attention_heads: 16 13 | decoder_normalize_before: true 14 | no_decoder_final_norm: false 15 | adaptive_softmax_cutoff: null 16 | adaptive_softmax_dropout: 0 17 | adaptive_softmax_factor: 4 18 | no_token_positional_embeddings: false 19 | share_decoder_input_output_embed: false 20 | character_embeddings: false 21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" 22 | character_embedding_dim: 4 23 | char_embedder_highway_layers: 2 24 | adaptive_input: false 25 | adaptive_input_factor: 4 26 | adaptive_input_cutoff: null 27 | tie_adaptive_weights: false 28 | tie_adaptive_proj: false 29 | decoder_learned_pos: false 30 | decoder_layerdrop: 0 31 | decoder_layers_to_keep: null 32 | layernorm_embedding: false 33 | no_scale_embedding: false 34 | quant_noise_pq: 0 35 | quant_noise_pq_block_size: 8 36 | quant_noise_scalar: 0 37 | -------------------------------------------------------------------------------- /bart/config/model/transformer_lm_gbw.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | activation_fn: "relu" 3 | dropout: 0.1 4 | attention_dropout: 0.1 5 | 
activation_dropout: 0.0 6 | relu_dropout: 0.0 7 | decoder_embed_dim: 512 8 | decoder_output_dim: 512 9 | decoder_input_dim: 512 10 | decoder_ffn_embed_dim: 4096 11 | decoder_layers: 12 12 | decoder_attention_heads: 16 13 | decoder_normalize_before: true 14 | no_decoder_final_norm: true 15 | adaptive_softmax_cutoff: null 16 | adaptive_softmax_dropout: 0 17 | adaptive_softmax_factor: 4 18 | no_token_positional_embeddings: false 19 | share_decoder_input_output_embed: false 20 | character_embeddings: false 21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" 22 | character_embedding_dim: 4 23 | char_embedder_highway_layers: 2 24 | adaptive_input: false 25 | adaptive_input_factor: 4 26 | adaptive_input_cutoff: null 27 | tie_adaptive_weights: false 28 | tie_adaptive_proj: false 29 | decoder_learned_pos: false 30 | decoder_layerdrop: 0 31 | decoder_layers_to_keep: null 32 | layernorm_embedding: false 33 | no_scale_embedding: false 34 | quant_noise_pq: 0 35 | quant_noise_pq_block_size: 8 36 | quant_noise_scalar: 0 37 | -------------------------------------------------------------------------------- /bart/config/model/transformer_lm_gpt.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | activation_fn: "gelu" 3 | dropout: 0.1 4 | attention_dropout: 0.1 5 | activation_dropout: 0.0 6 | relu_dropout: 0.0 7 | decoder_embed_dim: 768 8 | decoder_output_dim: 768 9 | decoder_input_dim: 768 10 | decoder_ffn_embed_dim: 3072 11 | decoder_layers: 12 12 | decoder_attention_heads: 12 13 | decoder_normalize_before: true 14 | no_decoder_final_norm: false 15 | adaptive_softmax_cutoff: null 16 | adaptive_softmax_dropout: 0 17 | adaptive_softmax_factor: 4 18 | no_token_positional_embeddings: false 19 | share_decoder_input_output_embed: false 20 | character_embeddings: false 21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" 22 | 
character_embedding_dim: 4 23 | char_embedder_highway_layers: 2 24 | adaptive_input: false 25 | adaptive_input_factor: 4 26 | adaptive_input_cutoff: null 27 | tie_adaptive_weights: false 28 | tie_adaptive_proj: false 29 | decoder_learned_pos: false 30 | decoder_layerdrop: 0 31 | decoder_layers_to_keep: null 32 | layernorm_embedding: false 33 | no_scale_embedding: false 34 | quant_noise_pq: 0 35 | quant_noise_pq_block_size: 8 36 | quant_noise_scalar: 0 37 | -------------------------------------------------------------------------------- /bart/config/model/transformer_lm_gpt2_big.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | activation_fn: "gelu" 3 | dropout: 0.1 4 | attention_dropout: 0.1 5 | activation_dropout: 0.0 6 | relu_dropout: 0.0 7 | decoder_embed_dim: 1600 8 | decoder_output_dim: 1600 9 | decoder_input_dim: 1600 10 | decoder_ffn_embed_dim: 6400 11 | decoder_layers: 48 12 | decoder_attention_heads: 25 13 | decoder_normalize_before: true 14 | no_decoder_final_norm: false 15 | adaptive_softmax_cutoff: null 16 | adaptive_softmax_dropout: 0 17 | adaptive_softmax_factor: 4 18 | no_token_positional_embeddings: false 19 | share_decoder_input_output_embed: false 20 | character_embeddings: false 21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" 22 | character_embedding_dim: 4 23 | char_embedder_highway_layers: 2 24 | adaptive_input: false 25 | adaptive_input_factor: 4 26 | adaptive_input_cutoff: null 27 | tie_adaptive_weights: false 28 | tie_adaptive_proj: false 29 | decoder_learned_pos: false 30 | decoder_layerdrop: 0 31 | decoder_layers_to_keep: null 32 | layernorm_embedding: false 33 | no_scale_embedding: false 34 | quant_noise_pq: 0 35 | quant_noise_pq_block_size: 8 36 | quant_noise_scalar: 0 37 | -------------------------------------------------------------------------------- /bart/config/model/transformer_lm_gpt2_medium.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _group_ 2 | activation_fn: "gelu" 3 | dropout: 0.1 4 | attention_dropout: 0.1 5 | activation_dropout: 0.0 6 | relu_dropout: 0.0 7 | decoder_embed_dim: 1280 8 | decoder_output_dim: 1280 9 | decoder_input_dim: 1280 10 | decoder_ffn_embed_dim: 5120 11 | decoder_layers: 36 12 | decoder_attention_heads: 20 13 | decoder_normalize_before: true 14 | no_decoder_final_norm: false 15 | adaptive_softmax_cutoff: null 16 | adaptive_softmax_dropout: 0 17 | adaptive_softmax_factor: 4 18 | no_token_positional_embeddings: false 19 | share_decoder_input_output_embed: false 20 | character_embeddings: false 21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" 22 | character_embedding_dim: 4 23 | char_embedder_highway_layers: 2 24 | adaptive_input: false 25 | adaptive_input_factor: 4 26 | adaptive_input_cutoff: null 27 | tie_adaptive_weights: false 28 | tie_adaptive_proj: false 29 | decoder_learned_pos: false 30 | decoder_layerdrop: 0 31 | decoder_layers_to_keep: null 32 | layernorm_embedding: false 33 | no_scale_embedding: false 34 | quant_noise_pq: 0 35 | quant_noise_pq_block_size: 8 36 | quant_noise_scalar: 0 37 | -------------------------------------------------------------------------------- /bart/config/model/transformer_lm_gpt2_small.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | activation_fn: "gelu" 3 | dropout: 0.1 4 | attention_dropout: 0.1 5 | activation_dropout: 0.0 6 | relu_dropout: 0.0 7 | decoder_embed_dim: 1024 8 | decoder_output_dim: 1024 9 | decoder_input_dim: 1024 10 | decoder_ffn_embed_dim: 4096 11 | decoder_layers: 24 12 | decoder_attention_heads: 16 13 | decoder_normalize_before: true 14 | no_decoder_final_norm: false 15 | adaptive_softmax_cutoff: null 16 | adaptive_softmax_dropout: 0 17 | adaptive_softmax_factor: 4 18 | no_token_positional_embeddings: false 19 
| share_decoder_input_output_embed: false 20 | character_embeddings: false 21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" 22 | character_embedding_dim: 4 23 | char_embedder_highway_layers: 2 24 | adaptive_input: false 25 | adaptive_input_factor: 4 26 | adaptive_input_cutoff: null 27 | tie_adaptive_weights: false 28 | tie_adaptive_proj: false 29 | decoder_learned_pos: false 30 | decoder_layerdrop: 0 31 | decoder_layers_to_keep: null 32 | layernorm_embedding: false 33 | no_scale_embedding: false 34 | quant_noise_pq: 0 35 | quant_noise_pq_block_size: 8 36 | quant_noise_scalar: 0 37 | -------------------------------------------------------------------------------- /bart/config/model/transformer_lm_wiki103.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | activation_fn: "relu" 3 | dropout: 0.3 4 | attention_dropout: 0.1 5 | activation_dropout: 0.1 6 | relu_dropout: 0.1 7 | decoder_embed_dim: 1024 8 | decoder_output_dim: 1024 9 | decoder_input_dim: 1024 10 | decoder_ffn_embed_dim: 4096 11 | decoder_layers: 16 12 | decoder_attention_heads: 8 13 | decoder_normalize_before: true 14 | no_decoder_final_norm: true 15 | adaptive_softmax_cutoff: "20000,60000" 16 | adaptive_softmax_dropout: 0.2 17 | adaptive_softmax_factor: 4 18 | no_token_positional_embeddings: false 19 | share_decoder_input_output_embed: false 20 | character_embeddings: false 21 | character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" 22 | character_embedding_dim: 4 23 | char_embedder_highway_layers: 2 24 | adaptive_input: true 25 | adaptive_input_factor: 4 26 | adaptive_input_cutoff: "20000,60000" 27 | tie_adaptive_weights: true 28 | tie_adaptive_proj: true 29 | decoder_learned_pos: false 30 | decoder_layerdrop: 0 31 | decoder_layers_to_keep: null 32 | layernorm_embedding: false 33 | no_scale_embedding: false 34 | quant_noise_pq: 0 35 | quant_noise_pq_block_size: 
8 36 | quant_noise_scalar: 0 37 | -------------------------------------------------------------------------------- /bart/config/optimizer/adam.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | adam_betas: "(0.9, 0.999)" 3 | adam_eps: 1.0e-8 4 | weight_decay: 0 5 | use_old_adam: false 6 | -------------------------------------------------------------------------------- /bart/config/optimizer/nag.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | momentum: 0.99 3 | weight_decay: 0.0 4 | -------------------------------------------------------------------------------- /bart/config/task/language_modeling.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | data: ??? 3 | sample_break_mode: "none" 4 | tokens_per_sample: 1024 5 | output_dictionary_size: -1 6 | self_target: false 7 | future_target: false 8 | past_target: false 9 | add_bos_token: false 10 | max_target_positions: null 11 | -------------------------------------------------------------------------------- /bart/data/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/bart/data/.gitignore -------------------------------------------------------------------------------- /bart/examples/.gitignore: -------------------------------------------------------------------------------- 1 | !*/*.sh 2 | !*/*.md 3 | -------------------------------------------------------------------------------- /bart/examples/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from fairseq import __version__ # noqa 7 | -------------------------------------------------------------------------------- /bart/examples/backtranslation/deduplicate_lines.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | import fileinput 9 | import hashlib 10 | import sys 11 | from multiprocessing import Pool 12 | 13 | 14 | def get_hashes_and_lines(raw_line): 15 | hash = hashlib.md5(raw_line).hexdigest() 16 | return hash, raw_line 17 | 18 | 19 | def main(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("--workers", type=int, default=10) 22 | parser.add_argument("files", nargs="*", help="input files") 23 | args = parser.parse_args() 24 | 25 | seen = set() 26 | with fileinput.input(args.files, mode="rb") as h: 27 | pool = Pool(args.workers) 28 | results = pool.imap_unordered(get_hashes_and_lines, h, 1000) 29 | for i, (hash, raw_line) in enumerate(results): 30 | if hash not in seen: 31 | seen.add(hash) 32 | sys.stdout.buffer.write(raw_line) 33 | if i % 1000000 == 0: 34 | print(i, file=sys.stderr, end="", flush=True) 35 | elif i % 100000 == 0: 36 | print(".", file=sys.stderr, end="", flush=True) 37 | print(file=sys.stderr, flush=True) 38 | 39 | 40 | if __name__ == "__main__": 41 | main() 42 | -------------------------------------------------------------------------------- /bart/examples/backtranslation/sacrebleu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ $# -ne 5 ]; then 4 | echo "usage: $0 [dataset=wmt14/full] [langpair=en-de] [databin] [bpecode] [model]" 5 | exit 6 | fi 7 | 8 | 9 | DATASET=$1 10 | LANGPAIR=$2 11 | DATABIN=$3 12 | BPECODE=$4 13 | MODEL=$5 14 | 15 | SRCLANG=$(echo $LANGPAIR | cut 
-d '-' -f 1) 16 | TGTLANG=$(echo $LANGPAIR | cut -d '-' -f 2) 17 | 18 | 19 | BPEROOT=examples/backtranslation/subword-nmt/subword_nmt 20 | if [ ! -e $BPEROOT ]; then 21 | BPEROOT=subword-nmt/subword_nmt 22 | if [ ! -e $BPEROOT ]; then 23 | echo 'Cloning Subword NMT repository (for BPE pre-processing)...' 24 | git clone https://github.com/rsennrich/subword-nmt.git 25 | fi 26 | fi 27 | 28 | 29 | sacrebleu -t $DATASET -l $LANGPAIR --echo src \ 30 | | sacremoses tokenize -a -l $SRCLANG -q \ 31 | | python $BPEROOT/apply_bpe.py -c $BPECODE \ 32 | | fairseq-interactive $DATABIN --path $MODEL \ 33 | -s $SRCLANG -t $TGTLANG \ 34 | --beam 5 --remove-bpe --buffer-size 1024 --max-tokens 8000 \ 35 | | grep ^H- | cut -f 3- \ 36 | | sacremoses detokenize -l $TGTLANG -q \ 37 | | sacrebleu -t $DATASET -l $LANGPAIR 38 | -------------------------------------------------------------------------------- /bart/examples/backtranslation/tokenized_bleu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ $# -ne 5 ]; then 4 | echo "usage: $0 [dataset=wmt14/full] [langpair=en-de] [databin] [bpecode] [model]" 5 | exit 6 | fi 7 | 8 | 9 | DATASET=$1 10 | LANGPAIR=$2 11 | DATABIN=$3 12 | BPECODE=$4 13 | MODEL=$5 14 | 15 | SRCLANG=$(echo $LANGPAIR | cut -d '-' -f 1) 16 | TGTLANG=$(echo $LANGPAIR | cut -d '-' -f 2) 17 | 18 | 19 | BPEROOT=examples/backtranslation/subword-nmt/subword_nmt 20 | if [ ! -e $BPEROOT ]; then 21 | BPEROOT=subword-nmt/subword_nmt 22 | if [ ! -e $BPEROOT ]; then 23 | echo 'Cloning Subword NMT repository (for BPE pre-processing)...' 
24 | git clone https://github.com/rsennrich/subword-nmt.git 25 | fi 26 | fi 27 | 28 | 29 | TMP_REF=$(mktemp) 30 | 31 | sacrebleu -t $DATASET -l $LANGPAIR --echo ref -q \ 32 | | sacremoses normalize -l $TGTLANG -q \ 33 | | sacremoses tokenize -a -l $TGTLANG -q \ 34 | > $TMP_REF 35 | 36 | sacrebleu -t $DATASET -l $LANGPAIR --echo src -q \ 37 | | sacremoses normalize -l $SRCLANG -q \ 38 | | sacremoses tokenize -a -l $SRCLANG -q \ 39 | | python $BPEROOT/apply_bpe.py -c $BPECODE \ 40 | | fairseq-interactive $DATABIN --path $MODEL \ 41 | -s $SRCLANG -t $TGTLANG \ 42 | --beam 5 --remove-bpe --buffer-size 1024 --max-tokens 8000 \ 43 | | grep ^H- | cut -f 3- \ 44 | | fairseq-score --ref $TMP_REF 45 | 46 | rm -f $TMP_REF 47 | -------------------------------------------------------------------------------- /bart/examples/constrained_decoding/normalize.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import sys 9 | 10 | from sacremoses.normalize import MosesPunctNormalizer 11 | 12 | 13 | def main(args): 14 | normalizer = MosesPunctNormalizer(lang=args.lang, penn=args.penn) 15 | for line in sys.stdin: 16 | print(normalizer.normalize(line.rstrip()), flush=True) 17 | 18 | 19 | if __name__ == "__main__": 20 | import argparse 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--lang", "-l", default="en") 24 | parser.add_argument("--penn", "-p", action="store_true") 25 | args = parser.parse_args() 26 | 27 | main(args) 28 | -------------------------------------------------------------------------------- /bart/examples/constrained_decoding/tok.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Facebook, Inc. 
and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import sys 9 | 10 | import sacremoses 11 | 12 | 13 | def main(args): 14 | """Tokenizes, preserving tabs""" 15 | mt = sacremoses.MosesTokenizer(lang=args.lang) 16 | 17 | def tok(s): 18 | return mt.tokenize(s, return_str=True) 19 | 20 | for line in sys.stdin: 21 | parts = list(map(tok, line.split("\t"))) 22 | print(*parts, sep="\t", flush=True) 23 | 24 | 25 | if __name__ == "__main__": 26 | import argparse 27 | 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument("--lang", "-l", default="en") 30 | parser.add_argument("--penn", "-p", action="store_true") 31 | parser.add_argument("--fields", "-f", help="fields to tokenize") 32 | args = parser.parse_args() 33 | 34 | main(args) 35 | -------------------------------------------------------------------------------- /bart/examples/language_model/README.conv.md: -------------------------------------------------------------------------------- 1 | # Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017) 2 | 3 | ## Example usage 4 | 5 | First download and preprocess the data following the main [language modeling README](README.md). 
6 | 7 | Then to train a convolutional LM using the `fconv_lm_dauphin_wikitext103` 8 | architecture: 9 | ```bash 10 | fairseq-train --task language_modeling \ 11 | data-bin/wikitext-103 \ 12 | --save-dir checkpoints/fconv_wikitext-103 \ 13 | --arch fconv_lm_dauphin_wikitext103 \ 14 | --adaptive-softmax-cutoff 10000,20000,200000 \ 15 | --dropout 0.2 \ 16 | --criterion adaptive_loss \ 17 | --optimizer nag --clip-norm 0.1 --weight-decay 5e-06 \ 18 | --lr 1.0 --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \ 19 | --max-tokens 1024 --tokens-per-sample 1024 \ 20 | --ddp-backend no_c10d \ 21 | --max-epoch 35 22 | ``` 23 | 24 | And evaluate with: 25 | ```bash 26 | fairseq-eval-lm data-bin/wikitext-103 --path checkpoints/fconv_wiki103/checkpoint_best.pt 27 | ``` 28 | 29 | ## Citation 30 | 31 | ```bibtex 32 | @inproceedings{dauphin2017language, 33 | title={Language Modeling with Gated Convolutional Networks}, 34 | author={Dauphin, Yann N and Fan, Angela and Auli, Michael and Grangier, David}, 35 | booktitle={Proceedings of the 34th International Conference on Machine Learning-Volume 70}, 36 | pages={933--941}, 37 | year={2017}, 38 | organization={JMLR} 39 | } 40 | ``` 41 | -------------------------------------------------------------------------------- /bart/examples/language_model/prepare-wikitext-103.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh 3 | 4 | URLS=( 5 | "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip" 6 | ) 7 | FILES=( 8 | "wikitext-103-v1.zip" 9 | ) 10 | 11 | for ((i=0;i<${#URLS[@]};++i)); do 12 | file=${FILES[i]} 13 | if [ -f $file ]; then 14 | echo "$file already exists, skipping download" 15 | else 16 | url=${URLS[i]} 17 | wget "$url" 18 | if [ -f $file ]; then 19 | echo "$url successfully downloaded." 20 | else 21 | echo "$url not successfully downloaded." 
22 | exit -1 23 | fi 24 | if [ ${file: -4} == ".tgz" ]; then 25 | tar zxvf $file 26 | elif [ ${file: -4} == ".tar" ]; then 27 | tar xvf $file 28 | elif [ ${file: -4} == ".zip" ]; then 29 | unzip $file 30 | fi 31 | fi 32 | done 33 | cd .. 34 | -------------------------------------------------------------------------------- /bart/examples/latent_depth/src/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . import multilingual_translation_latent_depth # noqa 7 | from .loss import latent_depth # noqa 8 | from .models import latent_multilingual_transformer # noqa 9 | from .modules import latent_layers # noqa 10 | -------------------------------------------------------------------------------- /bart/examples/latent_depth/src/loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/bart/examples/latent_depth/src/loss/__init__.py -------------------------------------------------------------------------------- /bart/examples/latent_depth/src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/bart/examples/latent_depth/src/models/__init__.py -------------------------------------------------------------------------------- /bart/examples/latent_depth/src/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/bart/examples/latent_depth/src/modules/__init__.py 
-------------------------------------------------------------------------------- /bart/examples/linformer/README.md: -------------------------------------------------------------------------------- 1 | # Linformer: Self-Attention with Linear Complexity (Wang et al., 2020) 2 | 3 | This example contains code to train Linformer models as described in our paper 4 | [Linformer: Self-Attention with Linear Complexity](https://arxiv.org/abs/2006.04768). 5 | 6 | ## Training a new Linformer RoBERTa model 7 | 8 | You can mostly follow the [RoBERTa pretraining README](/examples/roberta/README.pretraining.md), 9 | updating your training command with `--user-dir examples/linformer/src --arch linformer_roberta_base`. 10 | 11 | ## Citation 12 | 13 | If you use our work, please cite: 14 | 15 | ```bibtex 16 | @article{wang2020linformer, 17 | title={Linformer: Self-Attention with Linear Complexity}, 18 | author={Wang, Sinong and Li, Belinda and Khabsa, Madian and Fang, Han and Ma, Hao}, 19 | journal={arXiv preprint arXiv:2006.04768}, 20 | year={2020} 21 | } 22 | ``` 23 | -------------------------------------------------------------------------------- /bart/examples/linformer/src/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from .models import linformer_roberta # noqa 7 | -------------------------------------------------------------------------------- /bart/examples/linformer/src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/bart/examples/linformer/src/models/__init__.py -------------------------------------------------------------------------------- /bart/examples/linformer/src/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/bart/examples/linformer/src/modules/__init__.py -------------------------------------------------------------------------------- /bart/examples/m2m_100/process_data/remove_too_much_punc.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import argparse 3 | from string import punctuation 4 | 5 | def len_no_punc(s, punc): 6 | return len([ch for ch in s if ch in punc]) 7 | 8 | def filter_overpunc(len_npunc, len_sen): 9 | return len_npunc < 0.5*len_sen 10 | 11 | def main(args): 12 | punc = punctuation + "—|–" 13 | print('Processing file {}'.format(args.input)) 14 | with gzip.open(args.input, 'rt', encoding=args.encoding) as tsv: 15 | with open(args.bitext + '.' + args.src_lang, 'wt', encoding=args.encoding) as fsrc: 16 | with open(args.bitext + '.' 
+ args.tgt_lang, 'wt', encoding=args.encoding) as ftgt: 17 | line = tsv.readline() 18 | fields = line.split('\t') 19 | 20 | src, tgt = fields[1], fields[2] 21 | 22 | nchar_npunc_src = len_no_punc(src, punc) 23 | nchar_npunc_tgt = len_no_punc(tgt, punc) 24 | 25 | if filter_overpunc(nchar_npunc_src, len(src)) and filter_overpunc(nchar_npunc_tgt, len(tgt)): 26 | fsrc.write(src.strip() + '\n') 27 | ftgt.write(tgt.strip() + '\n') 28 | 29 | if __name__ == '__main__': 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument("--input", required=True, type=str) 32 | parser.add_argument('--encoding', default='utf-8', help='character encoding for input/output') 33 | parser.add_argument('--bitext', type=str, required=True, help='language direction') 34 | parser.add_argument('--src-lang', type=str, required=True, help='Source language') 35 | parser.add_argument('--tgt-lang', type=str, required=True, help='Target language') 36 | main(parser.parse_args()) 37 | -------------------------------------------------------------------------------- /bart/examples/m2m_100/tokenizers/README.md: -------------------------------------------------------------------------------- 1 | # M2M-100 Tokenization 2 | 3 | We apply different tokenization strategies for different languages following the existing literature. Here we provide tok.sh a tokenizer that can be used to reproduce our results. 4 | 5 | To reproduce the results, follow these steps: 6 | 7 | ``` 8 | tgt_lang=... 9 | reference_translation=... 
10 | cat generation_output | grep -P "^H" | sort -V | cut -f 3- | sh tok.sh $tgt_lang > hyp 11 | cat $reference_translation |sh tok.sh $tgt_lang > ref 12 | sacrebleu -tok 'none' ref < hyp 13 | ``` 14 | 15 | ## Installation 16 | 17 | Tools needed for all the languages except Arabic can be installed by running install_dependencies.sh 18 | If you want to evaluate Arabic models, please follow the instructions provided here: http://alt.qcri.org/tools/arabic-normalizer/ to install 19 | -------------------------------------------------------------------------------- /bart/examples/m2m_100/tokenizers/seg_ja.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | SCRIPT=`realpath $0` 7 | KYTEA=`dirname $SCRIPT`/thirdparty/kytea 8 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$KYTEA/lib:/usr/local/lib 9 | export PATH=$PATH:"$KYTEA/bin" 10 | 11 | cat - | tr -d "[:blank:]" | kytea -notags 12 | -------------------------------------------------------------------------------- /bart/examples/m2m_100/tokenizers/seg_ko.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | SCRIPT=`realpath $0` 7 | MECAB=`dirname $SCRIPT`/thirdparty/mecab-0.996-ko-0.9.2 8 | 9 | export PATH=$PATH:"$MECAB/bin":"$MECAB/lib" 10 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:"$MECAB/lib" 11 | 12 | cat - | mecab -O wakati 13 | -------------------------------------------------------------------------------- /bart/examples/m2m_100/tokenizers/thirdparty/.gitignore: -------------------------------------------------------------------------------- 1 | seg_my.py 2 | indic_nlp_library/ 3 | indic_nlp_resources/ 4 | kytea/ 5 | mecab-0.996-ko-0.9.2.tar.gz 6 | mecab-0.996-ko-0.9.2/ 7 | mosesdecoder/ 8 | wat2020.my-en.zip 9 | wat2020.my-en/ 10 | wmt16-scripts/ 11 | mecab-ko-dic-2.1.1-20180720/ 12 | mecab-ko-dic-2.1.1-20180720.tar.gz -------------------------------------------------------------------------------- /bart/examples/m2m_100/tokenizers/tokenize_indic.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # Use: echo {text} | python tokenize_indic.py {language} 8 | 9 | import sys 10 | 11 | from indicnlp.normalize.indic_normalize import IndicNormalizerFactory 12 | from indicnlp.tokenize.indic_tokenize import trivial_tokenize 13 | 14 | 15 | factory = IndicNormalizerFactory() 16 | normalizer = factory.get_normalizer( 17 | sys.argv[1], remove_nuktas=False, nasals_mode="do_nothing" 18 | ) 19 | 20 | for line in sys.stdin: 21 | normalized_line = normalizer.normalize(line.strip()) 22 | tokenized_line = " ".join(trivial_tokenize(normalized_line, sys.argv[1])) 23 | print(tokenized_line) 24 | -------------------------------------------------------------------------------- /bart/examples/m2m_100/tokenizers/tokenize_thai.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import sys 8 | 9 | from pythainlp import word_tokenize 10 | 11 | 12 | for line in sys.stdin: 13 | print(" ".join(word_tokenize(line.strip()))) 14 | -------------------------------------------------------------------------------- /bart/examples/m2m_100/tokenizers/tokenize_zh.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | 8 | import fileinput 9 | 10 | import sacrebleu 11 | 12 | 13 | for line in fileinput.input(): 14 | print(sacrebleu.tokenize_zh(line)) 15 | -------------------------------------------------------------------------------- /bart/examples/m2m_100/tokenizers/tokenizer_ar.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | # Please follow the instructions here http://alt.qcri.org/tools/arabic-normalizer/ 8 | # to install tools needed for Arabic 9 | 10 | echo "Please install Arabic tools: http://alt.qcri.org/tools/arabic-normalizer/" 11 | echo "Then update environment variables in tokenizer_ar.sh" 12 | exit 1 13 | 14 | SVMTOOL=... 15 | GOMOSESGO=... 16 | QCRI_ARABIC_NORMALIZER=... 17 | 18 | export PERL5LIB="$SVMTOOL/lib":"$GOMOSESGO/bin/MADA-3.2":$PERL5LIB 19 | 20 | 21 | tempfile=$(mktemp) 22 | cat - > $tempfile 23 | 24 | cd $QCRI_ARABIC_NORMALIZER 25 | 26 | bash qcri_normalizer_mada3.2_aramorph1.2.1.sh $tempfile 27 | cat $tempfile.mada_norm-aramorph.europarl_tok 28 | -------------------------------------------------------------------------------- /bart/examples/megatron_11b/detok.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import argparse 8 | import fileinput 9 | 10 | import sacremoses 11 | 12 | 13 | def main(): 14 | parser = argparse.ArgumentParser(description="") 15 | parser.add_argument("files", nargs="*", help="input files") 16 | args = parser.parse_args() 17 | 18 | detok = sacremoses.MosesDetokenizer() 19 | 20 | for line in fileinput.input(args.files, openhook=fileinput.hook_compressed): 21 | print( 22 | detok.detokenize(line.strip().split(" ")) 23 | .replace(" @", "") 24 | .replace("@ ", "") 25 | .replace(" =", "=") 26 | .replace("= ", "=") 27 | .replace(" – ", "–") 28 | ) 29 | 30 | 31 | if __name__ == "__main__": 32 | main() 33 | -------------------------------------------------------------------------------- /bart/examples/multilingual/finetune_multilingual_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | path_2_data=$1 # which contains binarized data for each directions 4 | lang_list=$2 # 5 | lang_pairs=$3 #a list language pairs to train multilingual models, e.g. 
"en-fr,en-cs,fr-en,cs-en" 6 | # pretrained can be an mBART pretrained model as well 7 | pretrained_model=$4 # 8 | 9 | 10 | fairseq-train "$path_2_data" \ 11 | --encoder-normalize-before --decoder-normalize-before \ 12 | --arch transformer --layernorm-embedding \ 13 | --task translation_multi_simple_epoch \ 14 | --finetune-from-model "$pretrained_model" \ 15 | --sampling-method "temperature" \ 16 | --sampling-temperature "1.5" \ 17 | --encoder-langtok "src" \ 18 | --decoder-langtok \ 19 | --lang-dict "$lang_list" \ 20 | --lang-pairs "$lang_pairs" \ 21 | --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \ 22 | --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ 23 | --lr-scheduler inverse_sqrt --lr 3e-05 --min-lr -1 --warmup-updates 2500 --max-update 40000 \ 24 | --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ 25 | --max-tokens 1024 --update-freq 2 \ 26 | --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \ 27 | --seed 222 --log-format simple --log-interval 2 28 | -------------------------------------------------------------------------------- /bart/examples/multilingual/multilingual_fairseq_gen.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | lang_pairs="en-fr,en-cs,fr-en,cs-en" 4 | path_2_data=$1 # 5 | lang_list=$2 # 6 | model=$3 # 7 | source_lang=cs 8 | target_lang=en 9 | 10 | fairseq-generate "$path_2_data" \ 11 | --path "$model" \ 12 | --task translation_multi_simple_epoch \ 13 | --gen-subset test \ 14 | --source-lang "$source_lang" \ 15 | --target-lang "$target_lang" \ 16 | --sacrebleu --remove-bpe 'sentencepiece'\ 17 | --batch-size 32 \ 18 | --encoder-langtok "src" \ 19 | --decoder-langtok \ 20 | --lang-dict "$lang_list" \ 21 | --lang-pairs "$lang_pairs" 22 | -------------------------------------------------------------------------------- /bart/examples/multilingual/train_multilingual_model.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | path_2_data=$1 # which contains binarized data for each directions 4 | lang_list=$2 # 5 | lang_pairs=$3 #a list language pairs to train multilingual models, e.g. "en-fr,en-cs,fr-en,cs-en" 6 | 7 | fairseq-train "$path_2_data" \ 8 | --encoder-normalize-before --decoder-normalize-before \ 9 | --arch transformer --layernorm-embedding \ 10 | --task translation_multi_simple_epoch \ 11 | --sampling-method "temperature" \ 12 | --sampling-temperature 1.5 \ 13 | --encoder-langtok "src" \ 14 | --decoder-langtok \ 15 | --lang-dict "$lang_list" \ 16 | --lang-pairs "$lang_pairs" \ 17 | --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \ 18 | --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ 19 | --lr-scheduler inverse_sqrt --lr 3e-05 --min-lr -1 --warmup-updates 2500 --max-update 40000 \ 20 | --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ 21 | --max-tokens 1024 --update-freq 2 \ 22 | --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \ 23 | --seed 222 --log-format simple --log-interval 2 24 | -------------------------------------------------------------------------------- /bart/examples/noisychannel/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .rerank_options import * # noqa 7 | -------------------------------------------------------------------------------- /bart/examples/pointer_generator/src/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . import transformer_pg # noqa 7 | -------------------------------------------------------------------------------- /bart/examples/quant_noise/transformer_quantization_config.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # This file defines example configuration arguments for quantizing 7 | # a transformer model with product quantization 8 | 9 | # Number of Centroids for Product Quantization, by default 256 (byte-aligned) 10 | n_centroids: 11 | Linear: 12 | key: in_features 13 | value: {"*": 256} 14 | Embedding: 15 | key: embedding_dim 16 | value: {"*": 256} 17 | 18 | # Block Sizes for Product Quantization 19 | # We suggest: 8 for FFN, 4 for ATTN, 4 for embedding projections, 8 for embeddings 20 | block_sizes: 21 | Linear: 22 | key: fuzzy_name 23 | value: {fc: 8, attn: 4, emb: 4} 24 | Embedding: 25 | key: fuzzy_name 26 | value: {emb: 8} 27 | 28 | # Layers to Quantize Sequentially 29 | # We suggest: first FFN, then EMB, then ATTN 30 | layers_to_quantize: 31 | - decoder\\.layers\\.\d+\\.fc[12] 32 | - decoder\\.embed_tokens\\.embeddings\\.[012]\\.[01] 33 | - decoder\\.layers\\.\d+\\.self_attn\\.(k_proj|v_proj|q_proj|out_proj) 34 | -------------------------------------------------------------------------------- /bart/examples/roberta/commonsense_qa/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . 
import commonsense_qa_task # noqa 7 | -------------------------------------------------------------------------------- /bart/examples/roberta/commonsense_qa/download_cqa_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | OUTDIR=data/CommonsenseQA 8 | 9 | mkdir -p $OUTDIR 10 | 11 | wget -O $OUTDIR/train.jsonl https://s3.amazonaws.com/commensenseqa/train_rand_split.jsonl 12 | wget -O $OUTDIR/valid.jsonl https://s3.amazonaws.com/commensenseqa/dev_rand_split.jsonl 13 | wget -O $OUTDIR/test.jsonl https://s3.amazonaws.com/commensenseqa/test_rand_split_no_answers.jsonl 14 | wget -O $OUTDIR/dict.txt https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt 15 | -------------------------------------------------------------------------------- /bart/examples/roberta/wsc/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . import wsc_criterion # noqa 7 | from . import wsc_task # noqa 8 | -------------------------------------------------------------------------------- /bart/examples/rxf/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . 
import src # noqa 7 | -------------------------------------------------------------------------------- /bart/examples/rxf/src/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . import label_smoothed_cross_entropy_r3f, sentence_prediction_r3f # noqa 7 | -------------------------------------------------------------------------------- /bart/examples/simultaneous_translation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . import criterions, eval, models # noqa 7 | -------------------------------------------------------------------------------- /bart/examples/simultaneous_translation/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import importlib 7 | import os 8 | 9 | 10 | for file in os.listdir(os.path.dirname(__file__)): 11 | if file.endswith(".py") and not file.startswith("_"): 12 | criterion_name = file[: file.find(".py")] 13 | importlib.import_module( 14 | "examples.simultaneous_translation.criterions." + criterion_name 15 | ) 16 | -------------------------------------------------------------------------------- /bart/examples/simultaneous_translation/eval/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /bart/examples/simultaneous_translation/eval/agents/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import importlib 7 | import os 8 | 9 | from fairseq import registry 10 | 11 | 12 | build_agent, register_agent, MONOTONIC_AGENT, _ = registry.setup_registry( 13 | "--agent-type" 14 | ) 15 | 16 | 17 | DEFAULT_EOS = "" 18 | GET = 0 19 | SEND = 1 20 | 21 | for file in os.listdir(os.path.dirname(__file__)): 22 | if file.endswith(".py") and not file.startswith("_"): 23 | module = file[: file.find(".py")] 24 | importlib.import_module("agents." + module) 25 | -------------------------------------------------------------------------------- /bart/examples/simultaneous_translation/eval/scorers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import importlib 7 | import os 8 | 9 | from fairseq import registry 10 | 11 | 12 | (build_scorer, register_scorer, SCORER_REGISTRIES, _) = registry.setup_registry( 13 | "--scorer-type" 14 | ) 15 | 16 | for file in os.listdir(os.path.dirname(__file__)): 17 | if file.endswith(".py") and not file.startswith("_"): 18 | module = file[: file.find(".py")] 19 | importlib.import_module("scorers." 
+ module) 20 | -------------------------------------------------------------------------------- /bart/examples/simultaneous_translation/eval/scorers/text_scorer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . import register_scorer 7 | from .scorer import SimulScorer 8 | 9 | 10 | @register_scorer("text") 11 | class SimulTextScorer(SimulScorer): 12 | def __init__(self, args): 13 | super().__init__(args) 14 | self.data = { 15 | "src": self._load_text_file(args.src_file, split=True), 16 | "tgt": self._load_text_file(args.tgt_file, split=False), 17 | } 18 | 19 | def send_src(self, sent_id, *args): 20 | if self.steps[sent_id] >= len(self.data["src"][sent_id]): 21 | dict_to_return = { 22 | "sent_id": sent_id, 23 | "segment_id": self.steps[sent_id], 24 | "segment": self.eos, 25 | } 26 | # Consider EOS 27 | self.steps[sent_id] = len(self.data["src"][sent_id]) + 1 28 | else: 29 | dict_to_return = { 30 | "sent_id": sent_id, 31 | "segment_id": self.steps[sent_id], 32 | "segment": self.data["src"][sent_id][self.steps[sent_id]], 33 | } 34 | 35 | self.steps[sent_id] += 1 36 | 37 | return dict_to_return 38 | 39 | def src_lengths(self): 40 | # +1 for eos 41 | return [len(sent) + 1 for sent in self.data["src"]] 42 | -------------------------------------------------------------------------------- /bart/examples/simultaneous_translation/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import importlib 7 | import os 8 | 9 | 10 | for file in os.listdir(os.path.dirname(__file__)): 11 | if file.endswith(".py") and not file.startswith("_"): 12 | model_name = file[: file.find(".py")] 13 | importlib.import_module( 14 | "examples.simultaneous_translation.models." + model_name 15 | ) 16 | -------------------------------------------------------------------------------- /bart/examples/simultaneous_translation/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import importlib 7 | import os 8 | 9 | from fairseq import registry 10 | 11 | 12 | ( 13 | build_monotonic_attention, 14 | register_monotonic_attention, 15 | MONOTONIC_ATTENTION_REGISTRY, 16 | _, 17 | ) = registry.setup_registry("--simul-type") 18 | 19 | for file in os.listdir(os.path.dirname(__file__)): 20 | if file.endswith(".py") and not file.startswith("_"): 21 | model_name = file[: file.find(".py")] 22 | importlib.import_module( 23 | "examples.simultaneous_translation.modules." + model_name 24 | ) 25 | -------------------------------------------------------------------------------- /bart/examples/simultaneous_translation/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import importlib 7 | import os 8 | 9 | 10 | # automatically import any Python files in the criterions/ directory 11 | for file in os.listdir(os.path.dirname(__file__)): 12 | if file.endswith(".py") and not file.startswith("_"): 13 | module = file[: file.find(".py")] 14 | importlib.import_module("examples.simultaneous_translation.utils." + module) 15 | -------------------------------------------------------------------------------- /bart/examples/speech_recognition/__init__.py: -------------------------------------------------------------------------------- 1 | from . import criterions, models, tasks # noqa 2 | -------------------------------------------------------------------------------- /bart/examples/speech_recognition/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | 5 | # ASG loss requires wav2letter 6 | files_to_skip = set() 7 | try: 8 | import wav2letter 9 | except ImportError: 10 | files_to_skip.add("ASG_loss.py") 11 | 12 | for file in os.listdir(os.path.dirname(__file__)): 13 | if file.endswith(".py") and not file.startswith("_") and file not in files_to_skip: 14 | criterion_name = file[: file.find(".py")] 15 | importlib.import_module( 16 | "examples.speech_recognition.criterions." + criterion_name 17 | ) 18 | -------------------------------------------------------------------------------- /bart/examples/speech_recognition/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from .asr_dataset import AsrDataset 7 | 8 | 9 | __all__ = [ 10 | "AsrDataset", 11 | ] 12 | -------------------------------------------------------------------------------- /bart/examples/speech_recognition/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | 5 | for file in os.listdir(os.path.dirname(__file__)): 6 | if file.endswith(".py") and not file.startswith("_"): 7 | model_name = file[: file.find(".py")] 8 | importlib.import_module("examples.speech_recognition.models." + model_name) 9 | -------------------------------------------------------------------------------- /bart/examples/speech_recognition/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | 5 | for file in os.listdir(os.path.dirname(__file__)): 6 | if file.endswith(".py") and not file.startswith("_"): 7 | task_name = file[: file.find(".py")] 8 | importlib.import_module("examples.speech_recognition.tasks." + task_name) 9 | -------------------------------------------------------------------------------- /bart/examples/translation_moe/src/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . import translation_moe # noqa 7 | -------------------------------------------------------------------------------- /bart/examples/translation_moe/src/logsumexp_moe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import torch 7 | 8 | 9 | class LogSumExpMoE(torch.autograd.Function): 10 | """Standard LogSumExp forward pass, but use *posterior* for the backward. 11 | 12 | See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade" 13 | (Shen et al., 2019) `_. 14 | """ 15 | 16 | @staticmethod 17 | def forward(ctx, logp, posterior, dim=-1): 18 | ctx.save_for_backward(posterior) 19 | ctx.dim = dim 20 | return torch.logsumexp(logp, dim=dim) 21 | 22 | @staticmethod 23 | def backward(ctx, grad_output): 24 | (posterior,) = ctx.saved_tensors 25 | grad_logp = grad_output.unsqueeze(ctx.dim) * posterior 26 | return grad_logp, None, None 27 | -------------------------------------------------------------------------------- /bart/examples/unsupervised_quality_estimation/aggregate_scores.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import argparse 7 | import sys 8 | 9 | import numpy as np 10 | 11 | 12 | aggregate_funcs = { 13 | "std": np.std, 14 | "var": np.var, 15 | "median": np.median, 16 | "mean": np.mean, 17 | "min": np.min, 18 | "max": np.max, 19 | } 20 | 21 | 22 | def main(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("-i", "--input_file", required=True, type=str) 25 | parser.add_argument("-n", "--repeat_times", required=True, type=int) 26 | parser.add_argument("-o", "--output_file", required=False) 27 | parser.add_argument("-f", "--func", required=False, default="mean") 28 | args = parser.parse_args() 29 | 30 | stream = open(args.output_file, "w") if args.output_file else sys.stdout 31 | 32 | segment_scores = [] 33 | for line in open(args.input_file): 34 | segment_scores.append(float(line.strip())) 35 | if len(segment_scores) == args.repeat_times: 36 | stream.write("{}\n".format(aggregate_funcs[args.func](segment_scores))) 37 | segment_scores = [] 38 | 39 | 40 | if __name__ == "__main__": 41 | main() 42 | -------------------------------------------------------------------------------- /bart/examples/unsupervised_quality_estimation/repeat_lines.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import argparse 7 | import sys 8 | 9 | 10 | def _normalize_spaces(line): 11 | return " ".join(line.split()) 12 | 13 | 14 | def main(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("-i", "--input_file", required=True, type=str) 17 | parser.add_argument("-n", "--repeat_times", required=True, type=int) 18 | parser.add_argument("-o", "--output_file", required=False, type=str) 19 | args = parser.parse_args() 20 | stream = open(args.output_file, "w") if args.output_file else sys.stdout 21 | 22 | for line in open(args.input_file): 23 | for _ in range(args.repeat_times): 24 | stream.write(_normalize_spaces(line) + "\n") 25 | 26 | 27 | if __name__ == "__main__": 28 | main() 29 | -------------------------------------------------------------------------------- /bart/fairseq/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | __all__ = ['pdb'] 7 | __version__ = '0.9.0' 8 | 9 | import sys 10 | 11 | # backwards compatibility to support `from fairseq.meters import AverageMeter` 12 | from fairseq.logging import meters, metrics, progress_bar # noqa 13 | sys.modules['fairseq.meters'] = meters 14 | sys.modules['fairseq.metrics'] = metrics 15 | sys.modules['fairseq.progress_bar'] = progress_bar 16 | 17 | import fairseq.criterions # noqa 18 | import fairseq.models # noqa 19 | import fairseq.modules # noqa 20 | import fairseq.optim # noqa 21 | import fairseq.optim.lr_scheduler # noqa 22 | import fairseq.pdb # noqa 23 | import fairseq.scoring # noqa 24 | import fairseq.tasks # noqa 25 | import fairseq.token_generation_constraints # noqa 26 | 27 | import fairseq.benchmark # noqa 28 | import fairseq.model_parallel # noqa 29 | -------------------------------------------------------------------------------- /bart/fairseq/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/bart/fairseq/__init__.pyc -------------------------------------------------------------------------------- /bart/fairseq/benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # import models/tasks to register them 7 | from . import ( # noqa 8 | dummy_lm, 9 | dummy_masked_lm, 10 | dummy_model, 11 | dummy_mt, 12 | ) 13 | -------------------------------------------------------------------------------- /bart/fairseq/clib/libbleu/module.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2017-present, Facebook, Inc. 3 | * All rights reserved. 
4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include 10 | 11 | 12 | static PyMethodDef method_def[] = { 13 | {NULL, NULL, 0, NULL} 14 | }; 15 | 16 | static struct PyModuleDef module_def = { 17 | PyModuleDef_HEAD_INIT, 18 | "libbleu", /* name of module */ 19 | NULL, /* module documentation, may be NULL */ 20 | -1, /* size of per-interpreter state of the module, 21 | or -1 if the module keeps state in global variables. */ 22 | method_def 23 | }; 24 | 25 | 26 | #if PY_MAJOR_VERSION == 2 27 | PyMODINIT_FUNC init_libbleu() 28 | #else 29 | PyMODINIT_FUNC PyInit_libbleu() 30 | #endif 31 | { 32 | PyObject *m = PyModule_Create(&module_def); 33 | if (!m) { 34 | return NULL; 35 | } 36 | return m; 37 | } 38 | -------------------------------------------------------------------------------- /bart/fairseq/clib/libnat_cuda/edit_dist.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2017-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | 13 | torch::Tensor LevenshteinDistanceCuda( 14 | torch::Tensor source, 15 | torch::Tensor target, 16 | torch::Tensor source_length, 17 | torch::Tensor target_length); 18 | 19 | torch::Tensor GenerateDeletionLabelCuda( 20 | torch::Tensor source, 21 | torch::Tensor operations); 22 | 23 | std::pair GenerateInsertionLabelCuda( 24 | torch::Tensor source, 25 | torch::Tensor operations); 26 | -------------------------------------------------------------------------------- /bart/fairseq/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import importlib 7 | import os 8 | from argparse import Namespace 9 | from typing import Union 10 | 11 | from fairseq import registry 12 | from fairseq.criterions.fairseq_criterion import ( # noqa 13 | FairseqCriterion, 14 | LegacyFairseqCriterion, 15 | ) 16 | from omegaconf import DictConfig 17 | 18 | 19 | ( 20 | build_criterion_, 21 | register_criterion, 22 | CRITERION_REGISTRY, 23 | CRITERION_DATACLASS_REGISTRY, 24 | ) = registry.setup_registry( 25 | "--criterion", base_class=FairseqCriterion, default="cross_entropy" 26 | ) 27 | 28 | 29 | def build_criterion(criterion_cfg: Union[DictConfig, Namespace], task): 30 | return build_criterion_(criterion_cfg, task) 31 | 32 | 33 | # automatically import any Python files in the criterions/ directory 34 | for file in os.listdir(os.path.dirname(__file__)): 35 | if file.endswith(".py") and not file.startswith("_"): 36 | file_name = file[: file.find(".py")] 37 | importlib.import_module("fairseq.criterions." + file_name) 38 | -------------------------------------------------------------------------------- /bart/fairseq/data/append_token_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from . 
import BaseWrapperDataset 10 | 11 | 12 | class AppendTokenDataset(BaseWrapperDataset): 13 | 14 | def __init__(self, dataset, token=None): 15 | super().__init__(dataset) 16 | self.token = token 17 | if token is not None: 18 | self._sizes = np.array(dataset.sizes) + 1 19 | else: 20 | self._sizes = dataset.sizes 21 | 22 | def __getitem__(self, idx): 23 | item = self.dataset[idx] 24 | if self.token is not None: 25 | item = torch.cat([item, item.new([self.token])]) 26 | return item 27 | 28 | @property 29 | def sizes(self): 30 | return self._sizes 31 | 32 | def num_tokens(self, index): 33 | n = self.dataset.num_tokens(index) 34 | if self.token is not None: 35 | n += 1 36 | return n 37 | 38 | def size(self, index): 39 | n = self.dataset.size(index) 40 | if self.token is not None: 41 | n += 1 42 | return n 43 | -------------------------------------------------------------------------------- /bart/fairseq/data/audio/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/bart/fairseq/data/audio/__init__.py -------------------------------------------------------------------------------- /bart/fairseq/data/colorize_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | from . 
import BaseWrapperDataset 9 | 10 | 11 | class ColorizeDataset(BaseWrapperDataset): 12 | """ Adds 'colors' property to net input that is obtained from the provided color getter for use by models """ 13 | def __init__(self, dataset, color_getter): 14 | super().__init__(dataset) 15 | self.color_getter = color_getter 16 | 17 | def collater(self, samples): 18 | base_collate = super().collater(samples) 19 | if len(base_collate) > 0: 20 | base_collate["net_input"]["colors"] = torch.tensor( 21 | list(self.color_getter(self.dataset, s["id"]) for s in samples), 22 | dtype=torch.long, 23 | ) 24 | return base_collate 25 | -------------------------------------------------------------------------------- /bart/fairseq/data/concat_sentences_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | from . 
import FairseqDataset 9 | 10 | 11 | class ConcatSentencesDataset(FairseqDataset): 12 | 13 | def __init__(self, *datasets): 14 | super().__init__() 15 | self.datasets = datasets 16 | assert all(len(ds) == len(datasets[0]) for ds in datasets), \ 17 | 'datasets must have the same length' 18 | 19 | def __getitem__(self, index): 20 | return torch.cat([ds[index] for ds in self.datasets]) 21 | 22 | def __len__(self): 23 | return len(self.datasets[0]) 24 | 25 | def collater(self, samples): 26 | return self.datasets[0].collater(samples) 27 | 28 | @property 29 | def sizes(self): 30 | return sum(ds.sizes for ds in self.datasets) 31 | 32 | def num_tokens(self, index): 33 | return sum(ds.num_tokens(index) for ds in self.datasets) 34 | 35 | def size(self, index): 36 | return sum(ds.size(index) for ds in self.datasets) 37 | 38 | def ordered_indices(self): 39 | return self.datasets[0].ordered_indices() 40 | 41 | @property 42 | def supports_prefetch(self): 43 | return any( 44 | getattr(ds, 'supports_prefetch', False) for ds in self.datasets 45 | ) 46 | 47 | def prefetch(self, indices): 48 | for ds in self.datasets: 49 | if getattr(ds, 'supports_prefetch', False): 50 | ds.prefetch(indices) 51 | 52 | def set_epoch(self, epoch): 53 | super().set_epoch(epoch) 54 | for ds in self.datasets: 55 | if hasattr(ds, 'set_epoch'): 56 | ds.set_epoch(epoch) 57 | -------------------------------------------------------------------------------- /bart/fairseq/data/data_utils_fast.cpython-36m-darwin.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/bart/fairseq/data/data_utils_fast.cpython-36m-darwin.so -------------------------------------------------------------------------------- /bart/fairseq/data/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | import importlib 8 | import os 9 | 10 | from fairseq import registry 11 | 12 | 13 | build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY, _ = registry.setup_registry( 14 | '--tokenizer', 15 | default=None, 16 | ) 17 | 18 | 19 | build_bpe, register_bpe, BPE_REGISTRY, _ = registry.setup_registry( 20 | '--bpe', 21 | default=None, 22 | ) 23 | 24 | 25 | # automatically import any Python files in the encoders/ directory 26 | for file in os.listdir(os.path.dirname(__file__)): 27 | if file.endswith('.py') and not file.startswith('_'): 28 | module = file[:file.find('.py')] 29 | importlib.import_module('fairseq.data.encoders.' + module) 30 | -------------------------------------------------------------------------------- /bart/fairseq/data/encoders/byte_bpe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | from fairseq import file_utils 8 | from fairseq.data.encoders import register_bpe 9 | from fairseq.data.encoders.byte_utils import (byte_encode, smart_byte_decode, 10 | SPACE, SPACE_ESCAPE) 11 | 12 | 13 | @register_bpe('byte_bpe') 14 | class ByteBPE(object): 15 | @staticmethod 16 | def add_args(parser): 17 | # fmt: off 18 | parser.add_argument('--sentencepiece-model-path', type=str, 19 | help='path to sentencepiece model') 20 | # fmt: on 21 | 22 | def __init__(self, args): 23 | vocab = file_utils.cached_path(args.sentencepiece_model_path) 24 | try: 25 | import sentencepiece as spm 26 | self.sp = spm.SentencePieceProcessor() 27 | self.sp.Load(vocab) 28 | except ImportError: 29 | raise ImportError('Please install sentencepiece with: pip install sentencepiece') 30 | 31 | def encode(self, x: str) -> str: 32 | byte_encoded = byte_encode(x) 33 | return SPACE.join(self.sp.EncodeAsPieces(byte_encoded)) 34 | 35 | @staticmethod 36 | def decode(x: str) -> str: 37 | unescaped = x.replace(SPACE, '').replace(SPACE_ESCAPE, SPACE) 38 | return smart_byte_decode(unescaped) 39 | -------------------------------------------------------------------------------- /bart/fairseq/data/encoders/bytes.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | from fairseq.data.encoders import register_bpe 8 | from fairseq.data.encoders.byte_utils import (byte_encode, smart_byte_decode, 9 | SPACE, SPACE_ESCAPE) 10 | 11 | 12 | @register_bpe('bytes') 13 | class Bytes(object): 14 | def __init__(self, args): 15 | pass 16 | 17 | @staticmethod 18 | def add_args(parser): 19 | pass 20 | 21 | @staticmethod 22 | def encode(x: str) -> str: 23 | encoded = byte_encode(x) 24 | escaped = encoded.replace(SPACE, SPACE_ESCAPE) 25 | return SPACE.join(list(escaped)) 26 | 27 | @staticmethod 28 | def decode(x: str) -> str: 29 | unescaped = x.replace(SPACE, '').replace(SPACE_ESCAPE, SPACE) 30 | return smart_byte_decode(unescaped) 31 | -------------------------------------------------------------------------------- /bart/fairseq/data/encoders/characters.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | from fairseq.data.encoders import register_bpe 8 | 9 | SPACE = chr(32) 10 | SPACE_ESCAPE = chr(9601) 11 | 12 | 13 | @register_bpe('characters') 14 | class Characters(object): 15 | def __init__(self, args): 16 | pass 17 | 18 | @staticmethod 19 | def add_args(parser): 20 | pass 21 | 22 | @staticmethod 23 | def encode(x: str) -> str: 24 | escaped = x.replace(SPACE, SPACE_ESCAPE) 25 | return SPACE.join(list(escaped)) 26 | 27 | @staticmethod 28 | def decode(x: str) -> str: 29 | return x.replace(SPACE, '').replace(SPACE_ESCAPE, SPACE) 30 | -------------------------------------------------------------------------------- /bart/fairseq/data/encoders/fastbpe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from fairseq import file_utils 7 | from fairseq.data.encoders import register_bpe 8 | 9 | 10 | @register_bpe('fastbpe') 11 | class fastBPE(object): 12 | 13 | @staticmethod 14 | def add_args(parser): 15 | # fmt: off 16 | parser.add_argument('--bpe-codes', type=str, 17 | help='path to fastBPE BPE') 18 | # fmt: on 19 | 20 | def __init__(self, args): 21 | if args.bpe_codes is None: 22 | raise ValueError('--bpe-codes is required for --bpe=fastbpe') 23 | codes = file_utils.cached_path(args.bpe_codes) 24 | try: 25 | import fastBPE 26 | self.bpe = fastBPE.fastBPE(codes) 27 | self.bpe_symbol = "@@ " 28 | except ImportError: 29 | raise ImportError('Please install fastBPE with: pip install fastBPE') 30 | 31 | def encode(self, x: str) -> str: 32 | return self.bpe.apply([x])[0] 33 | 34 | def decode(self, x: str) -> str: 35 | return (x + ' ').replace(self.bpe_symbol, '').rstrip() 36 | -------------------------------------------------------------------------------- /bart/fairseq/data/encoders/hf_bert_bpe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from fairseq.data.encoders import register_bpe 7 | 8 | 9 | @register_bpe('bert') 10 | class BertBPE(object): 11 | 12 | @staticmethod 13 | def add_args(parser): 14 | # fmt: off 15 | parser.add_argument('--bpe-cased', action='store_true', 16 | help='set for cased BPE', 17 | default=False) 18 | parser.add_argument('--bpe-vocab-file', type=str, 19 | help='bpe vocab file.') 20 | # fmt: on 21 | 22 | def __init__(self, args): 23 | try: 24 | from transformers import BertTokenizer 25 | except ImportError: 26 | raise ImportError( 27 | 'Please install transformers with: pip install transformers' 28 | ) 29 | 30 | if 'bpe_vocab_file' in args: 31 | self.bert_tokenizer = BertTokenizer( 32 | args.bpe_vocab_file, 33 | do_lower_case=not args.bpe_cased 34 | ) 35 | else: 36 | vocab_file_name = 'bert-base-cased' if args.bpe_cased else 'bert-base-uncased' 37 | self.bert_tokenizer = BertTokenizer.from_pretrained(vocab_file_name) 38 | 39 | def encode(self, x: str) -> str: 40 | return ' '.join(self.bert_tokenizer.tokenize(x)) 41 | 42 | def decode(self, x: str) -> str: 43 | return self.bert_tokenizer.clean_up_tokenization( 44 | self.bert_tokenizer.convert_tokens_to_string(x.split(' ')) 45 | ) 46 | 47 | def is_beginning_of_word(self, x: str) -> bool: 48 | return not x.startswith('##') 49 | -------------------------------------------------------------------------------- /bart/fairseq/data/encoders/hf_byte_bpe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from fairseq.data.encoders import register_bpe 7 | 8 | 9 | @register_bpe('hf_byte_bpe') 10 | class HuggingFaceByteLevelBPE(object): 11 | 12 | @staticmethod 13 | def add_args(parser): 14 | # fmt: off 15 | parser.add_argument('--bpe-merges', help='path to merges.txt') 16 | parser.add_argument('--bpe-vocab', help='path to vocab.json') 17 | parser.add_argument('--bpe-add-prefix-space', action='store_true', 18 | help='add prefix space before encoding') 19 | # fmt: on 20 | 21 | def __init__(self, args): 22 | try: 23 | from tokenizers import ByteLevelBPETokenizer 24 | except ImportError: 25 | raise ImportError( 26 | 'Please install huggingface/tokenizers with: ' 27 | 'pip install tokenizers' 28 | ) 29 | 30 | self.bpe = ByteLevelBPETokenizer( 31 | args.bpe_vocab, 32 | args.bpe_merges, 33 | add_prefix_space=getattr(args, 'bpe_add_prefix_space', False), 34 | ) 35 | 36 | def encode(self, x: str) -> str: 37 | return ' '.join(map(str, self.bpe.encode(x).ids)) 38 | 39 | def decode(self, x: str) -> str: 40 | return self.bpe.decode([ 41 | int(tok) if tok not in {'', ''} else tok 42 | for tok in x.split() 43 | ]) 44 | 45 | def is_beginning_of_word(self, x: str) -> bool: 46 | return self.decode(x).startswith(' ') 47 | -------------------------------------------------------------------------------- /bart/fairseq/data/encoders/nltk_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from fairseq.data.encoders import register_tokenizer 7 | 8 | 9 | @register_tokenizer('nltk') 10 | class NLTKTokenizer(object): 11 | 12 | def __init__(self, source_lang=None, target_lang=None): 13 | try: 14 | from nltk.tokenize import word_tokenize 15 | self.word_tokenize = word_tokenize 16 | except ImportError: 17 | raise ImportError('Please install nltk with: pip install nltk') 18 | 19 | def encode(self, x: str) -> str: 20 | return ' '.join(self.word_tokenize(x)) 21 | 22 | def decode(self, x: str) -> str: 23 | return x 24 | -------------------------------------------------------------------------------- /bart/fairseq/data/encoders/sentencepiece_bpe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from fairseq import file_utils 7 | from fairseq.data.encoders import register_bpe 8 | 9 | 10 | @register_bpe('sentencepiece') 11 | class SentencepieceBPE(object): 12 | 13 | @staticmethod 14 | def add_args(parser): 15 | # fmt: off 16 | parser.add_argument('--sentencepiece-model', type=str, 17 | help='path to sentencepiece model') 18 | # fmt: on 19 | 20 | def __init__(self, args): 21 | sentencepiece_model = file_utils.cached_path(args.sentencepiece_model) 22 | try: 23 | import sentencepiece as spm 24 | self.sp = spm.SentencePieceProcessor() 25 | self.sp.Load(sentencepiece_model) 26 | except ImportError: 27 | raise ImportError('Please install sentencepiece with: pip install sentencepiece') 28 | 29 | def encode(self, x: str) -> str: 30 | return ' '.join(self.sp.EncodeAsPieces(x)) 31 | 32 | def decode(self, x: str) -> str: 33 | return x.replace(' ', '').replace('\u2581', ' ').strip() 34 | 35 | def is_beginning_of_word(self, x: str) -> bool: 36 | if x in ['', '', '', '']: 37 | # special elements are always considered beginnings 38 | # 
HACK: this logic is already present in fairseq/tasks/masked_lm.py 39 | # but these special tokens are also contained in the sentencepiece 40 | # vocabulary which causes duplicate special tokens. This hack makes 41 | # sure that they are all taken into account. 42 | return True 43 | return x.startswith('\u2581') 44 | -------------------------------------------------------------------------------- /bart/fairseq/data/encoders/space_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import re 7 | 8 | from fairseq.data.encoders import register_tokenizer 9 | 10 | 11 | @register_tokenizer('space') 12 | class SpaceTokenizer(object): 13 | 14 | def __init__(self, source_lang=None, target_lang=None): 15 | self.space_tok = re.compile(r"\s+") 16 | 17 | def encode(self, x: str) -> str: 18 | return self.space_tok.sub(' ', x) 19 | 20 | def decode(self, x: str) -> str: 21 | return x 22 | -------------------------------------------------------------------------------- /bart/fairseq/data/encoders/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import torch 7 | from fairseq.data import encoders 8 | 9 | 10 | def get_whole_word_mask(args, dictionary): 11 | bpe = encoders.build_bpe(args) 12 | if bpe is not None: 13 | def is_beginning_of_word(i): 14 | if i < dictionary.nspecial: 15 | # special elements are always considered beginnings 16 | return True 17 | tok = dictionary[i] 18 | if tok.startswith('madeupword'): 19 | return True 20 | try: 21 | return bpe.is_beginning_of_word(tok) 22 | except ValueError: 23 | return True 24 | mask_whole_words = torch.ByteTensor(list( 25 | map(is_beginning_of_word, range(len(dictionary))) 26 | )) 27 | return mask_whole_words 28 | return None 29 | -------------------------------------------------------------------------------- /bart/fairseq/data/id_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | from . import FairseqDataset 9 | 10 | 11 | class IdDataset(FairseqDataset): 12 | 13 | def __getitem__(self, index): 14 | return index 15 | 16 | def __len__(self): 17 | return 0 18 | 19 | def collater(self, samples): 20 | return torch.tensor(samples) 21 | -------------------------------------------------------------------------------- /bart/fairseq/data/legacy/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from .masked_lm_dictionary import BertDictionary, MaskedLMDictionary 7 | from .block_pair_dataset import BlockPairDataset 8 | from .masked_lm_dataset import MaskedLMDataset 9 | 10 | __all__ = [ 11 | 'BertDictionary', 12 | 'BlockPairDataset', 13 | 'MaskedLMDataset', 14 | 'MaskedLMDictionary', 15 | ] 16 | -------------------------------------------------------------------------------- /bart/fairseq/data/legacy/masked_lm_dictionary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from fairseq.data import Dictionary 7 | 8 | 9 | class MaskedLMDictionary(Dictionary): 10 | """ 11 | Dictionary for Masked Language Modelling tasks. This extends Dictionary by 12 | adding the mask symbol. 13 | """ 14 | def __init__( 15 | self, 16 | pad='', 17 | eos='', 18 | unk='', 19 | mask='', 20 | ): 21 | super().__init__(pad=pad, eos=eos, unk=unk) 22 | self.mask_word = mask 23 | self.mask_index = self.add_symbol(mask) 24 | self.nspecial = len(self.symbols) 25 | 26 | def mask(self): 27 | """Helper to get index of mask symbol""" 28 | return self.mask_index 29 | 30 | 31 | class BertDictionary(MaskedLMDictionary): 32 | """ 33 | Dictionary for BERT task. This extends MaskedLMDictionary by adding support 34 | for cls and sep symbols. 
35 | """ 36 | def __init__( 37 | self, 38 | pad='', 39 | eos='', 40 | unk='', 41 | mask='', 42 | cls='', 43 | sep='' 44 | ): 45 | super().__init__(pad=pad, eos=eos, unk=unk, mask=mask) 46 | self.cls_word = cls 47 | self.sep_word = sep 48 | self.cls_index = self.add_symbol(cls) 49 | self.sep_index = self.add_symbol(sep) 50 | self.nspecial = len(self.symbols) 51 | 52 | def cls(self): 53 | """Helper to get index of cls symbol""" 54 | return self.cls_index 55 | 56 | def sep(self): 57 | """Helper to get index of sep symbol""" 58 | return self.sep_index 59 | -------------------------------------------------------------------------------- /bart/fairseq/data/list_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . import BaseWrapperDataset 7 | 8 | 9 | class ListDataset(BaseWrapperDataset): 10 | 11 | def __init__(self, dataset, sizes=None): 12 | super().__init__(dataset) 13 | self._sizes = sizes 14 | 15 | def __iter__(self): 16 | for x in self.dataset: 17 | yield x 18 | 19 | def collater(self, samples): 20 | return samples 21 | 22 | @property 23 | def sizes(self): 24 | return self._sizes 25 | 26 | def num_tokens(self, index): 27 | return self.sizes[index] 28 | 29 | def size(self, index): 30 | return self.sizes[index] 31 | 32 | def set_epoch(self, epoch): 33 | pass 34 | -------------------------------------------------------------------------------- /bart/fairseq/data/lru_cache_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from functools import lru_cache 7 | 8 | from . 
import BaseWrapperDataset 9 | 10 | 11 | class LRUCacheDataset(BaseWrapperDataset): 12 | 13 | def __init__(self, dataset, token=None): 14 | super().__init__(dataset) 15 | 16 | @lru_cache(maxsize=8) 17 | def __getitem__(self, index): 18 | return self.dataset[index] 19 | 20 | @lru_cache(maxsize=8) 21 | def collater(self, samples): 22 | return self.dataset.collater(samples) 23 | -------------------------------------------------------------------------------- /bart/fairseq/data/multilingual/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /bart/fairseq/data/num_samples_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . import FairseqDataset 7 | 8 | 9 | class NumSamplesDataset(FairseqDataset): 10 | 11 | def __getitem__(self, index): 12 | return 1 13 | 14 | def __len__(self): 15 | return 0 16 | 17 | def collater(self, samples): 18 | return sum(samples) 19 | -------------------------------------------------------------------------------- /bart/fairseq/data/numel_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from . 
import BaseWrapperDataset 10 | 11 | 12 | class NumelDataset(BaseWrapperDataset): 13 | 14 | def __init__(self, dataset, reduce=False): 15 | super().__init__(dataset) 16 | self.reduce = reduce 17 | 18 | def __getitem__(self, index): 19 | item = self.dataset[index] 20 | if torch.is_tensor(item): 21 | return torch.numel(item) 22 | else: 23 | return np.size(item) 24 | 25 | def __len__(self): 26 | return len(self.dataset) 27 | 28 | def collater(self, samples): 29 | if self.reduce: 30 | return sum(samples) 31 | else: 32 | return torch.tensor(samples) 33 | -------------------------------------------------------------------------------- /bart/fairseq/data/offset_tokens_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . import BaseWrapperDataset 7 | 8 | 9 | class OffsetTokensDataset(BaseWrapperDataset): 10 | 11 | def __init__(self, dataset, offset): 12 | super().__init__(dataset) 13 | self.offset = offset 14 | 15 | def __getitem__(self, idx): 16 | return self.dataset[idx] + self.offset 17 | -------------------------------------------------------------------------------- /bart/fairseq/data/pad_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from fairseq.data import data_utils 7 | 8 | from . 
import BaseWrapperDataset 9 | 10 | 11 | class PadDataset(BaseWrapperDataset): 12 | 13 | def __init__(self, dataset, pad_idx, left_pad): 14 | super().__init__(dataset) 15 | self.pad_idx = pad_idx 16 | self.left_pad = left_pad 17 | 18 | def collater(self, samples): 19 | return data_utils.collate_tokens(samples, self.pad_idx, left_pad=self.left_pad) 20 | 21 | 22 | class LeftPadDataset(PadDataset): 23 | 24 | def __init__(self, dataset, pad_idx): 25 | super().__init__(dataset, pad_idx, left_pad=True) 26 | 27 | 28 | class RightPadDataset(PadDataset): 29 | 30 | def __init__(self, dataset, pad_idx): 31 | super().__init__(dataset, pad_idx, left_pad=False) 32 | -------------------------------------------------------------------------------- /bart/fairseq/data/prepend_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from . 
import BaseWrapperDataset 10 | 11 | 12 | class PrependDataset(BaseWrapperDataset): 13 | def __init__(self, dataset, prepend_getter, ensure_first_token_is=None): 14 | super().__init__(dataset) 15 | self.prepend_getter = prepend_getter 16 | self.ensure_first_token = ensure_first_token_is 17 | 18 | def __getitem__(self, idx): 19 | item = self.dataset[idx] 20 | is_tuple = isinstance(item, tuple) 21 | src = item[0] if is_tuple else item 22 | 23 | assert self.ensure_first_token is None or src[0] == self.ensure_first_token 24 | prepend_idx = self.prepend_getter(self.dataset, idx) 25 | assert isinstance(prepend_idx, int) 26 | src[0] = prepend_idx 27 | item = tuple((src,) + item[1:]) if is_tuple else src 28 | return item 29 | -------------------------------------------------------------------------------- /bart/fairseq/data/prepend_token_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from . 
import BaseWrapperDataset 10 | 11 | 12 | class PrependTokenDataset(BaseWrapperDataset): 13 | 14 | def __init__(self, dataset, token=None): 15 | super().__init__(dataset) 16 | self.token = token 17 | if token is not None: 18 | self._sizes = np.array(dataset.sizes) + 1 19 | else: 20 | self._sizes = dataset.sizes 21 | 22 | def __getitem__(self, idx): 23 | item = self.dataset[idx] 24 | if self.token is not None: 25 | item = torch.cat([item.new([self.token]), item]) 26 | return item 27 | 28 | @property 29 | def sizes(self): 30 | return self._sizes 31 | 32 | def num_tokens(self, index): 33 | n = self.dataset.num_tokens(index) 34 | if self.token is not None: 35 | n += 1 36 | return n 37 | 38 | def size(self, index): 39 | n = self.dataset.size(index) 40 | if self.token is not None: 41 | n += 1 42 | return n 43 | -------------------------------------------------------------------------------- /bart/fairseq/data/raw_label_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | from . import FairseqDataset 9 | 10 | 11 | class RawLabelDataset(FairseqDataset): 12 | 13 | def __init__(self, labels): 14 | super().__init__() 15 | self.labels = labels 16 | 17 | def __getitem__(self, index): 18 | return self.labels[index] 19 | 20 | def __len__(self): 21 | return len(self.labels) 22 | 23 | def collater(self, samples): 24 | return torch.tensor(samples) 25 | -------------------------------------------------------------------------------- /bart/fairseq/data/replace_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . import BaseWrapperDataset 7 | 8 | 9 | class ReplaceDataset(BaseWrapperDataset): 10 | """Replaces tokens found in the dataset by a specified replacement token 11 | 12 | Args: 13 | dataset (~torch.utils.data.Dataset): dataset to replace tokens in 14 | replace_map(Dictionary[int,int]): map of token to replace -> replacement token 15 | offsets (List[int]): do not replace tokens before (from left if pos, right if neg) this offset. should be 16 | as many as the number of objects returned by the underlying dataset __getitem__ method. 17 | """ 18 | 19 | def __init__(self, dataset, replace_map, offsets): 20 | super().__init__(dataset) 21 | assert len(replace_map) > 0 22 | self.replace_map = replace_map 23 | self.offsets = offsets 24 | 25 | def __getitem__(self, index): 26 | item = self.dataset[index] 27 | is_tuple = isinstance(item, tuple) 28 | srcs = item if is_tuple else [item] 29 | 30 | for offset, src in zip(self.offsets, srcs): 31 | for k, v in self.replace_map.items(): 32 | src_off = src[offset:] if offset >= 0 else src[:offset] 33 | src_off.masked_fill_(src_off == k, v) 34 | 35 | item = srcs if is_tuple else srcs[0] 36 | return item 37 | -------------------------------------------------------------------------------- /bart/fairseq/data/roll_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | from . 
import BaseWrapperDataset 9 | 10 | 11 | class RollDataset(BaseWrapperDataset): 12 | 13 | def __init__(self, dataset, shifts): 14 | super().__init__(dataset) 15 | self.shifts = shifts 16 | 17 | def __getitem__(self, index): 18 | item = self.dataset[index] 19 | return torch.roll(item, self.shifts) 20 | -------------------------------------------------------------------------------- /bart/fairseq/data/sort_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | 8 | from . import BaseWrapperDataset 9 | 10 | 11 | class SortDataset(BaseWrapperDataset): 12 | 13 | def __init__(self, dataset, sort_order): 14 | super().__init__(dataset) 15 | if not isinstance(sort_order, (list, tuple)): 16 | sort_order = [sort_order] 17 | self.sort_order = sort_order 18 | 19 | assert all(len(so) == len(dataset) for so in sort_order) 20 | 21 | def ordered_indices(self): 22 | return np.lexsort(self.sort_order) 23 | -------------------------------------------------------------------------------- /bart/fairseq/data/strip_token_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . 
import BaseWrapperDataset 7 | 8 | 9 | class StripTokenDataset(BaseWrapperDataset): 10 | 11 | def __init__(self, dataset, id_to_strip): 12 | super().__init__(dataset) 13 | self.id_to_strip = id_to_strip 14 | 15 | def __getitem__(self, index): 16 | item = self.dataset[index] 17 | while len(item) > 0 and item[-1] == self.id_to_strip: 18 | item = item[:-1] 19 | while len(item) > 0 and item[0] == self.id_to_strip: 20 | item = item[1:] 21 | return item 22 | -------------------------------------------------------------------------------- /bart/fairseq/data/token_block_utils_fast.cpython-36m-darwin.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/bart/fairseq/data/token_block_utils_fast.cpython-36m-darwin.so -------------------------------------------------------------------------------- /bart/fairseq/dataclass/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .utils import ChoiceEnum, FairseqDataclass 7 | 8 | 9 | __all__ = ["FairseqDataclass", "ChoiceEnum"] 10 | -------------------------------------------------------------------------------- /bart/fairseq/dataclass/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from fairseq.dataclass.utils import ChoiceEnum 7 | 8 | 9 | LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"]) 10 | DDP_BACKEND_CHOICES = ChoiceEnum(["c10d", "no_c10d"]) 11 | DISTRIBUTED_WRAPPER_CHOICES = ChoiceEnum(["DDP", "SlowMo"]) 12 | ZERO_SHARDING_CHOICES = ChoiceEnum(["none", "os"]) 13 | PIPELINE_CHECKPOINT_CHOICES = ChoiceEnum(["always", "never", "except_last"]) 14 | -------------------------------------------------------------------------------- /bart/fairseq/libbleu.cpython-36m-darwin.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/bart/fairseq/libbleu.cpython-36m-darwin.so -------------------------------------------------------------------------------- /bart/fairseq/libnat.cpython-36m-darwin.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/bart/fairseq/libnat.cpython-36m-darwin.so -------------------------------------------------------------------------------- /bart/fairseq/logging/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/bart/fairseq/logging/__init__.py -------------------------------------------------------------------------------- /bart/fairseq/logging/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/bart/fairseq/logging/__init__.pyc -------------------------------------------------------------------------------- /bart/fairseq/model_parallel/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. 
and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . import criterions, modules, models # noqa 7 | -------------------------------------------------------------------------------- /bart/fairseq/model_parallel/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import importlib 7 | import os 8 | 9 | 10 | # automatically import any Python files in the criterions/ directory 11 | for file in os.listdir(os.path.dirname(__file__)): 12 | if file.endswith('.py') and not file.startswith('_'): 13 | module = file[:file.find('.py')] 14 | importlib.import_module('fairseq.model_parallel.criterions.' + module) 15 | -------------------------------------------------------------------------------- /bart/fairseq/model_parallel/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import importlib 7 | import os 8 | 9 | 10 | # automatically import any Python files in the models/ directory 11 | models_dir = os.path.dirname(__file__) 12 | for file in os.listdir(models_dir): 13 | path = os.path.join(models_dir, file) 14 | if not file.startswith('_') and not file.startswith('.') and (file.endswith('.py') or os.path.isdir(path)): 15 | model_name = file[:file.find('.py')] if file.endswith('.py') else file 16 | module = importlib.import_module('fairseq.model_parallel.models.' 
+ model_name) 17 | -------------------------------------------------------------------------------- /bart/fairseq/model_parallel/models/pipeline_parallel_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .model import * # noqa 7 | -------------------------------------------------------------------------------- /bart/fairseq/model_parallel/models/roberta/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .model import * # noqa 7 | -------------------------------------------------------------------------------- /bart/fairseq/model_parallel/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from .multihead_attention import ModelParallelMultiheadAttention 7 | from .transformer_layer import ModelParallelTransformerEncoderLayer, ModelParallelTransformerDecoderLayer 8 | from .transformer_sentence_encoder_layer import ModelParallelTransformerSentenceEncoderLayer 9 | from .transformer_sentence_encoder import ModelParallelTransformerSentenceEncoder 10 | 11 | __all__ = [ 12 | 'ModelParallelMultiheadAttention', 13 | 'ModelParallelTransformerEncoderLayer', 14 | 'ModelParallelTransformerDecoderLayer', 15 | 'ModelParallelTransformerSentenceEncoder', 16 | 'ModelParallelTransformerSentenceEncoderLayer', 17 | ] 18 | -------------------------------------------------------------------------------- /bart/fairseq/models/bart/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .hub_interface import * # noqa 7 | from .model import * # noqa 8 | -------------------------------------------------------------------------------- /bart/fairseq/models/huggingface/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import importlib 7 | import os 8 | 9 | 10 | # automatically import any Python files in the models/huggingface/ directory 11 | models_dir = os.path.dirname(__file__) 12 | for file in os.listdir(models_dir): 13 | path = os.path.join(models_dir, file) 14 | if ( 15 | not file.startswith('_') 16 | and not file.startswith('.') 17 | and (file.endswith('.py') or os.path.isdir(path)) 18 | ): 19 | model_name = file[:file.find('.py')] if file.endswith('.py') else file 20 | module = importlib.import_module('fairseq.models.huggingface.' + model_name) 21 | -------------------------------------------------------------------------------- /bart/fairseq/models/nat/__init__.py: -------------------------------------------------------------------------------- 1 | from .fairseq_nat_model import * 2 | from .nonautoregressive_transformer import * 3 | from .nat_crf_transformer import * 4 | from .iterative_nonautoregressive_transformer import * 5 | from .cmlm_transformer import * 6 | from .levenshtein_transformer import * 7 | from .insertion_transformer import * 8 | -------------------------------------------------------------------------------- /bart/fairseq/models/roberta/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .hub_interface import * # noqa 7 | from .model import * # noqa 8 | from .model_camembert import * # noqa 9 | from .model_xlmr import * # noqa 10 | -------------------------------------------------------------------------------- /bart/fairseq/models/roberta/model_xlmr.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | """ 6 | Unsupervised Cross-lingual Representation Learning at Scale 7 | """ 8 | 9 | from fairseq.models import register_model 10 | 11 | from .hub_interface import RobertaHubInterface 12 | from .model import RobertaModel 13 | 14 | 15 | @register_model('xlmr') 16 | class XLMRModel(RobertaModel): 17 | 18 | @classmethod 19 | def hub_models(cls): 20 | return { 21 | 'xlmr.base': 'http://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.tar.gz', 22 | 'xlmr.large': 'http://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.tar.gz', 23 | } 24 | 25 | @classmethod 26 | def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_name_or_path='.', bpe='sentencepiece', **kwargs): 27 | from fairseq import hub_utils 28 | x = hub_utils.from_pretrained( 29 | model_name_or_path, 30 | checkpoint_file, 31 | data_name_or_path, 32 | archive_map=cls.hub_models(), 33 | bpe=bpe, 34 | load_checkpoint_heads=True, 35 | **kwargs, 36 | ) 37 | return RobertaHubInterface(x['args'], x['task'], x['models'][0]) 38 | -------------------------------------------------------------------------------- /bart/fairseq/models/wav2vec/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .wav2vec import * # noqa 7 | from .wav2vec2 import * # noqa 8 | from .wav2vec2_asr import * # noqa 9 | -------------------------------------------------------------------------------- /bart/fairseq/modules/conv_tbc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | from torch.nn.modules.utils import _single 8 | 9 | 10 | class ConvTBC(torch.nn.Module): 11 | """1D convolution over an input of shape (time x batch x channel) 12 | 13 | The implementation uses gemm to perform the convolution. This implementation 14 | is faster than cuDNN for small kernel sizes. 15 | """ 16 | def __init__(self, in_channels, out_channels, kernel_size, padding=0): 17 | super(ConvTBC, self).__init__() 18 | self.in_channels = in_channels 19 | self.out_channels = out_channels 20 | self.kernel_size = _single(kernel_size) 21 | self.padding = _single(padding) 22 | 23 | self.weight = torch.nn.Parameter(torch.Tensor( 24 | self.kernel_size[0], in_channels, out_channels)) 25 | self.bias = torch.nn.Parameter(torch.Tensor(out_channels)) 26 | 27 | def forward(self, input): 28 | return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding[0]) 29 | 30 | def __repr__(self): 31 | s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}' 32 | ', padding={padding}') 33 | if self.bias is None: 34 | s += ', bias=False' 35 | s += ')' 36 | return s.format(name=self.__class__.__name__, **self.__dict__) 37 | -------------------------------------------------------------------------------- /bart/fairseq/modules/dynamicconv_layer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from .dynamicconv_layer import DynamicconvLayer # noqa 7 | -------------------------------------------------------------------------------- /bart/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | #include 9 | #include 10 | 11 | std::vector dynamicconv_cuda_forward( 12 | at::Tensor input, 13 | at::Tensor filters, 14 | int padding_l); 15 | 16 | std::vector dynamicconv_cuda_backward( 17 | at::Tensor gradOutput, 18 | int padding_l, 19 | at::Tensor input, 20 | at::Tensor filters); 21 | 22 | 23 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 24 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 25 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 26 | 27 | std::vector dynamicconv_forward( 28 | at::Tensor input, 29 | at::Tensor filters, 30 | int padding_l) { 31 | 32 | CHECK_INPUT(input); 33 | CHECK_INPUT(filters); 34 | 35 | return dynamicconv_cuda_forward(input, filters, 36 | padding_l); 37 | } 38 | 39 | std::vector dynamicconv_backward( 40 | at::Tensor gradOutput, 41 | int padding_l, 42 | at::Tensor input, 43 | at::Tensor filters) { 44 | 45 | CHECK_INPUT(gradOutput); 46 | CHECK_INPUT(input); 47 | CHECK_INPUT(filters); 48 | 49 | return dynamicconv_cuda_backward(gradOutput, padding_l, 50 | input, filters); 51 | } 52 | 53 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 54 | m.def("forward", &dynamicconv_forward, "dynamicconv forward (CUDA)"); 55 | m.def("backward", &dynamicconv_backward, "dynamicconv backward (CUDA)"); 56 | } 57 | -------------------------------------------------------------------------------- /bart/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cuh: 
-------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | #include 9 | #include 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | #include 23 | #include 24 | #include 25 | 26 | #define SHFL_MASK 0xffffffff 27 | 28 | template 29 | __global__ 30 | void dynamicconv_forward_kernel(const scalar_t* input, 31 | const scalar_t* weight, 32 | int minibatch, 33 | int sequenceLength, 34 | int numFeatures, 35 | int numFiltersInBlock, 36 | int numHeads, 37 | scalar_t* output); 38 | 39 | template 40 | __global__ 41 | void dynamicconv_backward_kernel( 42 | const scalar_t* gradOutput, // B * C * T 43 | const scalar_t* input, // B * C * T 44 | const scalar_t* weight, 45 | int minibatch, 46 | int sequenceLength, 47 | int numFeatures, 48 | int numFiltersInBlock, 49 | int numHeads, 50 | scalar_t* gradWeight, 51 | scalar_t* gradInput); // B * H * k * T 52 | -------------------------------------------------------------------------------- /bart/fairseq/modules/dynamicconv_layer/dynamiconv_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | std::vector dynamicconv_cpu_forward( 5 | float* input, 6 | float* filters, 7 | int padding_l); 8 | 9 | std::vector dynamicconv_cpu_backward( 10 | float* gradOutput, 11 | int padding_l, 12 | float* input, 13 | float* filters); 14 | 15 | std::vector dynamicconv_forward( 16 | float* input, 17 | float* filters, 18 | int padding_l) { 19 | 20 | return dynamicconv_cpu_forward(input, filters, padding_l); 21 | } 22 | 23 | std::vector dynamicconv_backward( 24 | float* gradOutput, 25 | int padding_l, 26 | float* input, 27 | float* filters) { 28 | 29 | return 
dynamicconv_cpu_backward(gradOutput, padding_l, input, filters); 30 | } 31 | 32 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 33 | m.def("forward", &dynamicconv_forward, "dynamicconv forward (CPU)"); 34 | m.def("backward", &dynamicconv_backward, "dynamicconv backward (CPU)"); 35 | } 36 | -------------------------------------------------------------------------------- /bart/fairseq/modules/dynamicconv_layer/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from setuptools import setup 8 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 9 | 10 | setup( 11 | name='dynamicconv_layer', 12 | ext_modules=[ 13 | CUDAExtension( 14 | name='dynamicconv_cuda', 15 | sources=[ 16 | 'dynamicconv_cuda.cpp', 17 | 'dynamicconv_cuda_kernel.cu', 18 | ], 19 | ), 20 | ], 21 | cmdclass={ 22 | 'build_ext': BuildExtension 23 | }) 24 | -------------------------------------------------------------------------------- /bart/fairseq/modules/fp32_group_norm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | """ 6 | Layer norm done in fp32 (for fp16 training) 7 | """ 8 | 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | 13 | class Fp32GroupNorm(nn.GroupNorm): 14 | def __init__(self, *args, **kwargs): 15 | super().__init__(*args, **kwargs) 16 | 17 | def forward(self, input): 18 | output = F.group_norm( 19 | input.float(), 20 | self.num_groups, 21 | self.weight.float() if self.weight is not None else None, 22 | self.bias.float() if self.bias is not None else None, 23 | self.eps, 24 | ) 25 | return output.type_as(input) 26 | -------------------------------------------------------------------------------- /bart/fairseq/modules/gelu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | """ 6 | See "Gaussian Error Linear Units (GELUs)" by Dan Hendrycks and Kevin Gimpel with 7 | the corresponding GitHub repo: https://github.com/hendrycks/GELUs 8 | """ 9 | 10 | import math 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | 16 | def gelu_accurate(x): 17 | if not hasattr(gelu_accurate, "_a"): 18 | gelu_accurate._a = math.sqrt(2 / math.pi) 19 | return ( 20 | 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) 21 | ) 22 | 23 | 24 | def gelu(x: torch.Tensor) -> torch.Tensor: 25 | return torch.nn.functional.gelu(x.float()).type_as(x) 26 | -------------------------------------------------------------------------------- /bart/fairseq/modules/grad_multiply.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import torch 7 | 8 | 9 | class GradMultiply(torch.autograd.Function): 10 | @staticmethod 11 | def forward(ctx, x, scale): 12 | ctx.scale = scale 13 | res = x.new(x) 14 | return res 15 | 16 | @staticmethod 17 | def backward(ctx, grad): 18 | return grad * ctx.scale, None 19 | -------------------------------------------------------------------------------- /bart/fairseq/modules/layer_drop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | """ 6 | LayerDrop as described in https://arxiv.org/abs/1909.11556. 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | 13 | class LayerDropModuleList(nn.ModuleList): 14 | """ 15 | A LayerDrop implementation based on :class:`torch.nn.ModuleList`. 16 | 17 | We refresh the choice of which layers to drop every time we iterate 18 | over the LayerDropModuleList instance. During evaluation we always 19 | iterate over all layers. 
20 | 21 | Usage:: 22 | 23 | layers = LayerDropList(p=0.5, modules=[layer1, layer2, layer3]) 24 | for layer in layers: # this might iterate over layers 1 and 3 25 | x = layer(x) 26 | for layer in layers: # this might iterate over all layers 27 | x = layer(x) 28 | for layer in layers: # this might not iterate over any layers 29 | x = layer(x) 30 | 31 | Args: 32 | p (float): probability of dropping out each layer 33 | modules (iterable, optional): an iterable of modules to add 34 | """ 35 | 36 | def __init__(self, p, modules=None): 37 | super().__init__(modules) 38 | self.p = p 39 | 40 | def __iter__(self): 41 | dropout_probs = torch.empty(len(self)).uniform_() 42 | for i, m in enumerate(super().__iter__()): 43 | if not self.training or (dropout_probs[i] > self.p): 44 | yield m 45 | -------------------------------------------------------------------------------- /bart/fairseq/modules/layer_norm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | try: 12 | from apex.normalization import FusedLayerNorm as _FusedLayerNorm 13 | 14 | has_fused_layernorm = True 15 | 16 | class FusedLayerNorm(_FusedLayerNorm): 17 | @torch.jit.unused 18 | def forward(self, x): 19 | if not x.is_cuda: 20 | return super().forward(x) 21 | else: 22 | with torch.cuda.device(x.device): 23 | return super().forward(x) 24 | 25 | except ImportError: 26 | has_fused_layernorm = False 27 | 28 | 29 | def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False): 30 | if torch.jit.is_scripting(): 31 | export = True 32 | if not export and torch.cuda.is_available() and has_fused_layernorm: 33 | return FusedLayerNorm(normalized_shape, eps, elementwise_affine) 34 | return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) 35 | 36 | 37 | class Fp32LayerNorm(nn.LayerNorm): 38 | def __init__(self, *args, **kwargs): 39 | super().__init__(*args, **kwargs) 40 | 41 | def forward(self, input): 42 | output = F.layer_norm( 43 | input.float(), 44 | self.normalized_shape, 45 | self.weight.float() if self.weight is not None else None, 46 | self.bias.float() if self.bias is not None else None, 47 | self.eps, 48 | ) 49 | return output.type_as(input) 50 | -------------------------------------------------------------------------------- /bart/fairseq/modules/lightconv_layer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from .lightconv_layer import LightconvLayer # noqa 7 | -------------------------------------------------------------------------------- /bart/fairseq/modules/lightconv_layer/lightconv_cuda.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | #include 9 | #include 10 | 11 | std::vector lightconv_cuda_forward( 12 | at::Tensor input, 13 | at::Tensor filters, 14 | int padding_l); 15 | 16 | std::vector lightconv_cuda_backward( 17 | at::Tensor gradOutput, 18 | int padding_l, 19 | at::Tensor input, 20 | at::Tensor filters); 21 | 22 | 23 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 24 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 25 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 26 | 27 | std::vector lightconv_forward( 28 | at::Tensor input, 29 | at::Tensor filters, 30 | int padding_l) { 31 | 32 | CHECK_INPUT(input); 33 | CHECK_INPUT(filters); 34 | 35 | return lightconv_cuda_forward(input, filters, padding_l); 36 | } 37 | 38 | std::vector lightconv_backward( 39 | at::Tensor gradOutput, 40 | int padding_l, 41 | at::Tensor input, 42 | at::Tensor filters) { 43 | 44 | CHECK_INPUT(gradOutput); 45 | CHECK_INPUT(input); 46 | CHECK_INPUT(filters); 47 | 48 | return lightconv_cuda_backward(gradOutput, padding_l, input, filters); 49 | } 50 | 51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 52 | m.def("forward", &lightconv_forward, "lighconv forward (CUDA)"); 53 | m.def("backward", &lightconv_backward, "lighconv backward (CUDA)"); 54 | } 55 | -------------------------------------------------------------------------------- /bart/fairseq/modules/lightconv_layer/setup.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from setuptools import setup 8 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 9 | 10 | setup( 11 | name='lightconv_layer', 12 | ext_modules=[ 13 | CUDAExtension('lightconv_cuda', [ 14 | 'lightconv_cuda.cpp', 15 | 'lightconv_cuda_kernel.cu', 16 | ]), 17 | ], 18 | cmdclass={ 19 | 'build_ext': BuildExtension 20 | }) 21 | -------------------------------------------------------------------------------- /bart/fairseq/modules/positional_embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch.nn as nn 7 | from .learned_positional_embedding import LearnedPositionalEmbedding 8 | from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding 9 | 10 | 11 | def PositionalEmbedding( 12 | num_embeddings: int, 13 | embedding_dim: int, 14 | padding_idx: int, 15 | learned: bool = False, 16 | ): 17 | if learned: 18 | # if padding_idx is specified then offset the embedding ids by 19 | # this index and adjust num_embeddings appropriately 20 | # TODO: The right place for this offset would be inside 21 | # LearnedPositionalEmbedding. Move this there for a cleaner implementation. 
22 | if padding_idx is not None: 23 | num_embeddings = num_embeddings + padding_idx + 1 24 | m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx) 25 | nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) 26 | if padding_idx is not None: 27 | nn.init.constant_(m.weight[padding_idx], 0) 28 | else: 29 | m = SinusoidalPositionalEmbedding( 30 | embedding_dim, padding_idx, init_size=num_embeddings + padding_idx + 1, 31 | ) 32 | return m 33 | -------------------------------------------------------------------------------- /bart/fairseq/modules/quantization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/bart/fairseq/modules/quantization/__init__.py -------------------------------------------------------------------------------- /bart/fairseq/modules/quantization/pq/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .utils import SizeTracker, quantize_model_ # NOQA 7 | -------------------------------------------------------------------------------- /bart/fairseq/modules/quantization/pq/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from .qconv import PQConv2d # NOQA 7 | from .qlinear import PQLinear # NOQA 8 | from .qemb import PQEmbedding # NOQA 9 | -------------------------------------------------------------------------------- /bart/fairseq/modules/quantization/scalar/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .utils import quantize_model_ # NOQA 7 | -------------------------------------------------------------------------------- /bart/fairseq/modules/quantization/scalar/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .qconv import IntConv2d # NOQA 7 | from .qlinear import IntLinear # NOQA 8 | from .qemb import IntEmbedding # NOQA 9 | from .qact import ActivationQuantizer # NOQA 10 | -------------------------------------------------------------------------------- /bart/fairseq/modules/same_pad.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | from torch import nn 8 | 9 | 10 | class SamePad(nn.Module): 11 | def __init__(self, kernel_size): 12 | super().__init__() 13 | self.remove = kernel_size % 2 == 0  # only even kernel sizes need trimming in forward() 14 | 15 | def forward(self, x): 16 | if self.remove: 17 | x = x[:, :, :-1]  # drop the last position of dim 2 (presumably the time axis of a (B, C, T) conv output -- confirm with callers) 18 | return x 19 | -------------------------------------------------------------------------------- /bart/fairseq/modules/scalar_bias.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | # 6 | 7 | import torch 8 | 9 | 10 | class ScalarBias(torch.autograd.Function): 11 | """ 12 | Prepends one slot filled with bias_init along `dim`, used in self-attention mechanism to allow 13 | the model to optionally attend to this vector instead of the past 14 | """ 15 | 16 | @staticmethod 17 | def forward(ctx, input, dim, bias_init): 18 | size = list(input.size()) 19 | size[dim] += 1 20 | output = input.new(*size).fill_(bias_init)  # one extra slot along `dim`; index 0 keeps bias_init 21 | output.narrow(dim, 1, size[dim] - 1).copy_(input)  # indices 1.. along `dim` receive the original input 22 | ctx.dim = dim 23 | return output 24 | 25 | @staticmethod 26 | def backward(ctx, grad): 27 | return grad.narrow(ctx.dim, 1, grad.size(ctx.dim) - 1), None, None  # discard the bias slot's gradient; dim and bias_init get no gradient 28 | 29 | 30 | def scalar_bias(input, dim, bias_init=0): 31 | return ScalarBias.apply(input, dim, bias_init)  # functional wrapper: prepend one bias_init-filled slot along `dim` 32 | -------------------------------------------------------------------------------- /bart/fairseq/modules/sparse_transformer_sentence_encoder_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree.
5 | 6 | from fairseq.modules import TransformerSentenceEncoderLayer 7 | from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention 8 | 9 | 10 | class SparseTransformerSentenceEncoderLayer(TransformerSentenceEncoderLayer): 11 | """ 12 | Implements a Sparse Transformer Encoder Layer: a TransformerSentenceEncoderLayer whose self-attention is replaced by SparseMultiheadAttention 13 | """ 14 | 15 | def __init__( 16 | self, 17 | embedding_dim: int = 768, 18 | ffn_embedding_dim: int = 3072, 19 | num_attention_heads: int = 8, 20 | dropout: float = 0.1, 21 | attention_dropout: float = 0.1, 22 | activation_dropout: float = 0.1, 23 | activation_fn: str = 'relu', 24 | export: bool = False, 25 | is_bidirectional: bool = True, 26 | stride: int = 32, 27 | expressivity: int = 8, 28 | ) -> None: 29 | 30 | super().__init__( 31 | embedding_dim, ffn_embedding_dim, num_attention_heads, dropout, 32 | attention_dropout, activation_dropout, activation_fn, export 33 | ) 34 | 35 | self.self_attn = SparseMultiheadAttention(  # overrides self.self_attn (presumably created by the parent __init__ -- confirm) 36 | self.embedding_dim, 37 | num_attention_heads, 38 | dropout=attention_dropout, 39 | add_bias_kv=False, 40 | add_zero_attn=False, 41 | self_attention=True, 42 | is_bidirectional=is_bidirectional, 43 | stride=stride, 44 | expressivity=expressivity, 45 | ) 46 | -------------------------------------------------------------------------------- /bart/fairseq/modules/transpose_last.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree.
5 | """ 6 | Transpose the last two dimensions of the input (optionally after indexing it with deconstruct_idx). 7 | """ 8 | 9 | import torch.nn as nn 10 | 11 | 12 | class TransposeLast(nn.Module): 13 | def __init__(self, deconstruct_idx=None): 14 | super().__init__() 15 | self.deconstruct_idx = deconstruct_idx 16 | 17 | def forward(self, x): 18 | if self.deconstruct_idx is not None: 19 | x = x[self.deconstruct_idx]  # select one element first (presumably to unpack a tuple output of the previous layer -- confirm) 20 | return x.transpose(-2, -1) 21 | -------------------------------------------------------------------------------- /bart/fairseq/modules/unfold.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch.nn.functional as F 7 | 8 | 9 | def unfold1d(x, kernel_size, padding_l, pad_value=0): 10 | '''Unfold T x B x C to T x B x C x K, where out[t, b, c, k] = padded_x[t + k, b, c].''' 11 | if kernel_size > 1: 12 | T, B, C = x.size() 13 | x = F.pad(x, (0, 0, 0, 0, padding_l, kernel_size - 1 - padding_l), value=pad_value)  # pad only the first (T) dim: padding_l before, kernel_size-1-padding_l after 14 | x = x.as_strided((T, B, C, kernel_size), (B*C, C, 1, B*C))  # zero-copy sliding-window view: each step along K advances one time step (stride B*C) 15 | else: 16 | x = x.unsqueeze(3)  # kernel_size <= 1: just add a trailing window axis of size 1 17 | return x 18 | -------------------------------------------------------------------------------- /bart/fairseq/optim/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree.
5 | 6 | import importlib 7 | import os 8 | from argparse import Namespace 9 | from typing import Union 10 | 11 | from fairseq import registry 12 | from fairseq.optim.bmuf import FairseqBMUF # noqa 13 | from fairseq.optim.fairseq_optimizer import ( # noqa 14 | FairseqOptimizer, 15 | LegacyFairseqOptimizer, 16 | ) 17 | from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer 18 | from fairseq.optim.shard import shard_ 19 | from omegaconf import DictConfig 20 | 21 | 22 | __all__ = [ 23 | "FairseqOptimizer", 24 | "FP16Optimizer", 25 | "MemoryEfficientFP16Optimizer", 26 | "shard_", 27 | ] 28 | 29 | 30 | ( 31 | _build_optimizer, 32 | register_optimizer, 33 | OPTIMIZER_REGISTRY, 34 | OPTIMIZER_DATACLASS_REGISTRY, 35 | ) = registry.setup_registry("--optimizer", base_class=FairseqOptimizer, required=True) 36 | 37 | 38 | def build_optimizer( 39 | optimizer_cfg: Union[DictConfig, Namespace], params, *extra_args, **extra_kwargs 40 | ): 41 | if all(isinstance(p, dict) for p in params): 42 | params = [t for p in params for t in p.values()] 43 | params = list(filter(lambda p: p.requires_grad, params)) 44 | return _build_optimizer(optimizer_cfg, params, *extra_args, **extra_kwargs) 45 | 46 | 47 | # automatically import any Python files in the optim/ directory 48 | for file in os.listdir(os.path.dirname(__file__)): 49 | if file.endswith(".py") and not file.startswith("_"): 50 | file_name = file[: file.find(".py")] 51 | importlib.import_module("fairseq.optim." + file_name) 52 | -------------------------------------------------------------------------------- /bart/fairseq/optim/adagrad.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch.optim 7 | 8 | from . 
import register_optimizer, LegacyFairseqOptimizer 9 | 10 | 11 | @register_optimizer('adagrad') 12 | class Adagrad(LegacyFairseqOptimizer): 13 | def __init__(self, args, params): 14 | super().__init__(args) 15 | self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config) 16 | 17 | @staticmethod 18 | def add_args(parser): 19 | """Add optimizer-specific arguments to the parser.""" 20 | # fmt: off 21 | parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', 22 | help='weight decay') 23 | # fmt: on 24 | 25 | @property 26 | def optimizer_config(self): 27 | """ 28 | Return a kwarg dictionary that will be used to override optimizer 29 | args stored in checkpoints. This allows us to load a checkpoint and 30 | resume training using a different set of optimizer args, e.g., with a 31 | different learning rate. 32 | """ 33 | return { 34 | 'lr': self.args.lr[0], 35 | 'weight_decay': self.args.weight_decay, 36 | } 37 | 38 | @property 39 | def supports_flat_params(self): 40 | return True 41 | -------------------------------------------------------------------------------- /bart/fairseq/optim/lr_scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import importlib 7 | import os 8 | from argparse import Namespace 9 | from typing import Union 10 | 11 | from fairseq import registry 12 | from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import ( # noqa 13 | FairseqLRScheduler, 14 | LegacyFairseqLRScheduler, 15 | ) 16 | from omegaconf import DictConfig 17 | 18 | 19 | ( 20 | build_lr_scheduler_, 21 | register_lr_scheduler, 22 | LR_SCHEDULER_REGISTRY, 23 | LR_SCHEDULER_DATACLASS_REGISTRY, 24 | ) = registry.setup_registry( 25 | "--lr-scheduler", base_class=FairseqLRScheduler, default="fixed" 26 | ) 27 | 28 | 29 | def build_lr_scheduler(lr_scheduler_cfg: Union[DictConfig, Namespace], optimizer): 30 | return build_lr_scheduler_(lr_scheduler_cfg, optimizer) 31 | 32 | 33 | # automatically import any Python files in the optim/lr_scheduler/ directory 34 | for file in os.listdir(os.path.dirname(__file__)): 35 | if file.endswith(".py") and not file.startswith("_"): 36 | file_name = file[: file.find(".py")] 37 | importlib.import_module("fairseq.optim.lr_scheduler." + file_name) 38 | -------------------------------------------------------------------------------- /bart/fairseq/optim/sgd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch.optim 7 | 8 | from . 
import register_optimizer, LegacyFairseqOptimizer 9 | 10 | 11 | @register_optimizer('sgd') 12 | class SGD(LegacyFairseqOptimizer): 13 | def __init__(self, args, params): 14 | super().__init__(args) 15 | self._optimizer = torch.optim.SGD(params, **self.optimizer_config) 16 | 17 | @staticmethod 18 | def add_args(parser): 19 | """Add optimizer-specific arguments to the parser.""" 20 | # fmt: off 21 | parser.add_argument('--momentum', default=0.0, type=float, metavar='M', 22 | help='momentum factor') 23 | parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', 24 | help='weight decay') 25 | # fmt: on 26 | 27 | @property 28 | def optimizer_config(self): 29 | """ 30 | Return a kwarg dictionary that will be used to override optimizer 31 | args stored in checkpoints. This allows us to load a checkpoint and 32 | resume training using a different set of optimizer args, e.g., with a 33 | different learning rate. 34 | """ 35 | return { 36 | 'lr': self.args.lr[0], 37 | 'momentum': self.args.momentum, 38 | 'weight_decay': self.args.weight_decay, 39 | } 40 | 41 | @property 42 | def supports_flat_params(self): 43 | return True 44 | -------------------------------------------------------------------------------- /bart/fairseq/optim/shard.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | try: 8 | from fairscale.optim import OSS 9 | _has_fairscale = True 10 | except ImportError: 11 | _has_fairscale = False 12 | 13 | 14 | def shard_(args, optimizer, group): 15 | if not _has_fairscale: 16 | raise ImportError( 17 | '\n\nPlease install the fairscale package:' 18 | '\n\n pip install fairscale' 19 | ) 20 | 21 | class FairseqOSS(OSS): 22 | @property 23 | def disable_mem_eff_fp16_loading_hack(self): 24 | return True 25 | 26 | def __getattr__(self, name): 27 | if name.startswith("supports") and hasattr(self.optim, name): 28 | return getattr(self.optim, name) 29 | raise AttributeError("'FairseqOSS' object has no attribute {0!r}".format(name)) 30 | 31 | torch_optimizer = optimizer.optimizer 32 | optim_cls = type(torch_optimizer) 33 | 34 | optimizer.optimizer = FairseqOSS(torch_optimizer.param_groups, optim_cls, group=group, **optimizer.optimizer_config) 35 | -------------------------------------------------------------------------------- /bart/fairseq/pdb.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import multiprocessing 7 | import os 8 | import pdb 9 | import sys 10 | 11 | 12 | __all__ = ['set_trace'] 13 | 14 | 15 | _stdin = [None] 16 | _stdin_lock = multiprocessing.Lock() 17 | try: 18 | _stdin_fd = sys.stdin.fileno() 19 | except Exception: 20 | _stdin_fd = None 21 | 22 | 23 | class MultiprocessingPdb(pdb.Pdb): 24 | """A Pdb wrapper that works in a multiprocessing environment. 
25 | 26 | Usage: `from fairseq import pdb; pdb.set_trace()` 27 | """ 28 | 29 | def __init__(self): 30 | pdb.Pdb.__init__(self, nosigint=True) 31 | 32 | def _cmdloop(self): 33 | stdin_bak = sys.stdin 34 | with _stdin_lock: 35 | try: 36 | if _stdin_fd is not None: 37 | if not _stdin[0]: 38 | _stdin[0] = os.fdopen(_stdin_fd) 39 | sys.stdin = _stdin[0] 40 | self.cmdloop() 41 | finally: 42 | sys.stdin = stdin_bak 43 | 44 | 45 | def set_trace(): 46 | pdb = MultiprocessingPdb() 47 | pdb.set_trace(sys._getframe().f_back) 48 | -------------------------------------------------------------------------------- /bart/fairseq/scoring/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | import importlib 8 | import os 9 | 10 | from fairseq import registry 11 | 12 | 13 | _build_scoring, register_scoring, SCORING_REGISTRY, _ = registry.setup_registry( 14 | "--scoring", default="bleu" 15 | ) 16 | 17 | 18 | def build_scorer(args, tgt_dict): 19 | from fairseq import utils 20 | 21 | if args.sacrebleu: 22 | utils.deprecation_warning( 23 | "--sacrebleu is deprecated. Please use --scoring sacrebleu instead." 24 | ) 25 | args.scoring = "sacrebleu" 26 | if args.scoring == "bleu": 27 | from fairseq.scoring import bleu 28 | return bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk()) 29 | else: 30 | return _build_scoring(args) 31 | 32 | 33 | # automatically import any Python files in the current directory 34 | for file in os.listdir(os.path.dirname(__file__)): 35 | if file.endswith(".py") and not file.startswith("_"): 36 | module = file[: file.find(".py")] 37 | importlib.import_module("fairseq.scoring." 
+ module) 38 | -------------------------------------------------------------------------------- /bart/fairseq/scoring/wer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from fairseq.scoring import register_scoring 7 | 8 | 9 | @register_scoring("wer") 10 | class WerScorer(object): 11 | def __init__(self, *unused): 12 | self.reset() 13 | 14 | def reset(self): 15 | self.distance = 0 16 | self.ref_length = 0 17 | 18 | def add_string(self, ref, pred): 19 | import editdistance 20 | ref_items = ref.split() 21 | pred_items = pred.split() 22 | self.distance += editdistance.eval(ref_items, pred_items) 23 | self.ref_length += len(ref_items) 24 | 25 | def result_string(self): 26 | return f"WER: {self.score()}" 27 | 28 | def score(self): 29 | return ( 30 | 100.0 * self.distance / self.ref_length if self.ref_length > 0 else 0 31 | ) 32 | -------------------------------------------------------------------------------- /bart/fairseq/tasks/translation_from_pretrained_xlm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary 7 | from fairseq.tasks.translation import TranslationTask 8 | 9 | from . import register_task 10 | 11 | 12 | @register_task("translation_from_pretrained_xlm") 13 | class TranslationFromPretrainedXLMTask(TranslationTask): 14 | """ 15 | Same as TranslationTask except use the MaskedLMDictionary class so that 16 | we can load data that was binarized with the MaskedLMDictionary class. 
17 | 18 | This task should be used for the entire training pipeline when we want to 19 | train an NMT model from a pretrained XLM checkpoint: binarizing NMT data, 20 | training NMT with the pretrained XLM checkpoint, and subsequent evaluation 21 | of that trained model. 22 | """ 23 | 24 | @classmethod 25 | def load_dictionary(cls, filename): 26 | """Load the masked LM dictionary from the filename 27 | 28 | Args: 29 | filename (str): the filename 30 | """ 31 | return MaskedLMDictionary.load(filename) 32 | -------------------------------------------------------------------------------- /bart/fairseq/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import re 7 | 8 | SPACE_NORMALIZER = re.compile(r"\s+") 9 | 10 | 11 | def tokenize_line(line): 12 | line = SPACE_NORMALIZER.sub(" ", line)  # collapse every whitespace run (tabs, newlines included) into one space 13 | line = line.strip() 14 | return line.split() 15 | -------------------------------------------------------------------------------- /bart/fairseq_cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/bart/fairseq_cli/__init__.py -------------------------------------------------------------------------------- /bart/my_scripts/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/10/26 10:23 AM 3 | # @Author : Xiachong Feng 4 | # @File : __init__.py.py 5 | # @Software: PyCharm 6 | -------------------------------------------------------------------------------- /bart/my_scripts/binarize.sh: -------------------------------------------------------------------------------- 1 | data=data 2 | destdir=data/bin 3 | 4 | python
fairseq_cli/preprocess.py \ 5 | --source-lang "source" \ 6 | --target-lang "target" \ 7 | --trainpref $data/train.bpe \ 8 | --validpref $data/valid.bpe \ 9 | --destdir $destdir \ 10 | --workers 60 \ 11 | --srcdict data/bpe/dict.txt \ 12 | --tgtdict data/bpe/dict.txt -------------------------------------------------------------------------------- /bart/my_scripts/bpe.sh: -------------------------------------------------------------------------------- 1 | DIR=data 2 | for SPLIT in train valid 3 | do 4 | for LANG in source target 5 | do 6 | python -m examples.roberta.multiprocessing_bpe_encoder \ 7 | --encoder-json data/bpe/encoder.json \ 8 | --vocab-bpe data/bpe/vocab.bpe \ 9 | --inputs "$DIR/$SPLIT.$LANG" \ 10 | --outputs "$DIR/$SPLIT.bpe.$LANG" \ 11 | --workers 60 \ 12 | --keep-empty; 13 | done 14 | done -------------------------------------------------------------------------------- /bart/my_scripts/infer.sh: -------------------------------------------------------------------------------- 1 | model_name=samsum 2 | ckpt_dir=ckpt 3 | out_name=samsum 4 | CUDA_VISIBLE_DEVICES=0 python fairseq_cli/inference.py \ 5 | --ckpt_dir ${ckpt_dir} \ 6 | --ckpt_file ${model_name}.pt \ 7 | --data_dir data/bin \ 8 | --test_file data/test.source \ 9 | --output_file summaries/${out_name}.txt \ 10 | --batch_size 50 \ 11 | --beam_size 4 \ 12 | --min_len 5 \ 13 | --max_len 100 \ 14 | --block_ngram 3 \ 15 | --len_penalty 0.5 16 | 17 | 18 | # model_name=checkpoint_best 19 | # ckpt_dir=ckpt/main 20 | # out_name=main 21 | # CUDA_VISIBLE_DEVICES=0 python fairseq_cli/inference.py \ 22 | # --ckpt_dir ${ckpt_dir} \ 23 | # --ckpt_file ${model_name}.pt \ 24 | # --data_dir data/bin \ 25 | # --test_file data/test.source \ 26 | # --output_file summaries/${out_name}.txt \ 27 | # --batch_size 10 \ 28 | # --beam_size 4 \ 29 | # --min_len 5 \ 30 | # --max_len 100 \ 31 | # --block_ngram 3 \ 32 | # --len_penalty 0.5 -------------------------------------------------------------------------------- 
/bart/my_scripts/train.sh: -------------------------------------------------------------------------------- 1 | cuda=0 2 | data=data/bin 3 | save_dir=main 4 | 5 | warmup_updates=400 6 | lr=3e-05 7 | dropout=0.1 8 | update_freq=32 9 | max_tokens=800 10 | total_num_update=100000 11 | bart=bart/bart.large/model.pt 12 | 13 | CUDA_VISIBLE_DEVICES=$cuda python train.py $data \ 14 | --restore-file $bart \ 15 | --max-tokens $max_tokens \ 16 | --task translation \ 17 | --source-lang source --target-lang target \ 18 | --truncate-source \ 19 | --layernorm-embedding \ 20 | --share-all-embeddings \ 21 | --share-decoder-input-output-embed \ 22 | --reset-optimizer --reset-dataloader --reset-meters \ 23 | --required-batch-size-multiple 1 \ 24 | --arch bart_large \ 25 | --criterion label_smoothed_cross_entropy \ 26 | --label-smoothing 0.1 \ 27 | --dropout $dropout --attention-dropout 0.1 \ 28 | --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-08 \ 29 | --clip-norm 0.1 \ 30 | --lr-scheduler polynomial_decay \ 31 | --lr $lr \ 32 | --update-freq $update_freq \ 33 | --skip-invalid-size-inputs-valid-test \ 34 | --find-unused-parameters \ 35 | --total-num-update $total_num_update \ 36 | --warmup-updates $warmup_updates \ 37 | --no-epoch-checkpoints \ 38 | --save-dir ckpt/$save_dir -------------------------------------------------------------------------------- /bart/requirements.txt: -------------------------------------------------------------------------------- 1 | requests 2 | py-rouge 3 | nltk -------------------------------------------------------------------------------- /bart/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/bart/scripts/__init__.py -------------------------------------------------------------------------------- /bart/scripts/compare_namespaces.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Helper script to compare two argparse.Namespace objects.""" 3 | 4 | from argparse import Namespace # noqa 5 | 6 | 7 | def main(): 8 | 9 | ns1 = eval(input('Namespace 1: ')) 10 | ns2 = eval(input('Namespace 2: ')) 11 | 12 | def keys(ns): 13 | ks = set() 14 | for k in dir(ns): 15 | if not k.startswith('_'): 16 | ks.add(k) 17 | return ks 18 | 19 | k1 = keys(ns1) 20 | k2 = keys(ns2) 21 | 22 | def print_keys(ks, ns1, ns2=None): 23 | for k in ks: 24 | if ns2 is None: 25 | print('{}\t{}'.format(k, getattr(ns1, k, None))) 26 | else: 27 | print('{}\t{}\t{}'.format(k, getattr(ns1, k, None), getattr(ns2, k, None))) 28 | 29 | print('Keys unique to namespace 1:') 30 | print_keys(k1 - k2, ns1) 31 | print() 32 | 33 | print('Keys unique to namespace 2:') 34 | print_keys(k2 - k1, ns2) 35 | print() 36 | 37 | print('Overlapping keys with different values:') 38 | ks = [k for k in k1 & k2 if getattr(ns1, k, 'None') != getattr(ns2, k, 'None')] 39 | print_keys(ks, ns1, ns2) 40 | print() 41 | 42 | 43 | if __name__ == '__main__': 44 | main() 45 | -------------------------------------------------------------------------------- /bart/scripts/compound_split_bleu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ $# -ne 1 ]; then 4 | echo "usage: $0 GENERATE_PY_OUTPUT" 5 | exit 1 6 | fi 7 | 8 | GEN=$1 9 | 10 | SYS=$GEN.sys 11 | REF=$GEN.ref 12 | 13 | if [ $(tail -n 1 $GEN | grep BLEU | wc -l) -ne 1 ]; then 14 | echo "not done generating" 15 | exit 16 | fi 17 | 18 | grep ^H $GEN | awk -F '\t' '{print $NF}' | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $SYS 19 | grep ^T $GEN | cut -f2- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $REF 20 | fairseq-score --sys $SYS --ref $REF 21 | -------------------------------------------------------------------------------- /bart/scripts/constraints/validate.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import sys 9 | 10 | """Reads in a fairseq output file, and verifies that the constraints 11 | (C- lines) are present in the output (the first H- line). Assumes that 12 | constraints are listed prior to the first hypothesis. 13 | """ 14 | 15 | constraints = [] 16 | found = 0 17 | total = 0 18 | for line in sys.stdin: 19 | if line.startswith("C-"): 20 | constraints.append(line.rstrip().split("\t")[1]) 21 | elif line.startswith("H-"): 22 | text = line.split("\t")[2] 23 | 24 | for constraint in constraints: 25 | total += 1 26 | if constraint in text: 27 | found += 1 28 | else: 29 | print(f"No {constraint} in {text}", file=sys.stderr) 30 | 31 | constraints = [] 32 | 33 | print(f"Found {found} / {total} = {100 * found / total:.1f}%") 34 | -------------------------------------------------------------------------------- /bart/scripts/convert_dictionary.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) Facebook, Inc. and its affiliates. 2 | -- 3 | -- This source code is licensed under the MIT license found in the 4 | -- LICENSE file in the root directory of this source tree. 5 | -- 6 | -- Usage: convert_dictionary.lua 7 | require 'fairseq' 8 | require 'torch' 9 | require 'paths' 10 | 11 | if #arg < 1 then 12 | print('usage: convert_dictionary.lua ') 13 | os.exit(1) 14 | end 15 | if not paths.filep(arg[1]) then 16 | print('error: file does not exit: ' .. 
arg[1]) 17 | os.exit(1) 18 | end 19 | 20 | dict = torch.load(arg[1]) 21 | dst = paths.basename(arg[1]):gsub('%.th7$', '.txt') -- '%.' is a literal dot ('.' alone is a Lua pattern wildcard); '$' anchors to the end 22 | assert(dst:match('%.txt$'), 'output name must end in .txt: ' .. dst) 23 | 24 | f = assert(io.open(dst, 'w')) -- io.open returns nil plus an error message on failure 25 | for idx, symbol in ipairs(dict.index_to_symbol) do 26 | if idx > dict.cutoff then 27 | break 28 | end 29 | f:write(symbol) 30 | f:write(' ') 31 | f:write(dict.index_to_freq[idx]) 32 | f:write('\n') 33 | end 34 | f:close() 35 | -------------------------------------------------------------------------------- /bart/scripts/read_binarized.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | 9 | from fairseq.data import data_utils, Dictionary, indexed_dataset 10 | 11 | 12 | def get_parser(): 13 | parser = argparse.ArgumentParser( 14 | description='writes text from binarized file to stdout') 15 | # fmt: off 16 | parser.add_argument('--dataset-impl', help='dataset implementation', 17 | choices=indexed_dataset.get_available_dataset_impl()) 18 | parser.add_argument('--dict', metavar='FP', help='dictionary containing known words', default=None) 19 | parser.add_argument('--input', metavar='FP', required=True, help='binarized file to read') 20 | # fmt: on 21 | 22 | return parser 23 | 24 | 25 | def main(): 26 | parser = get_parser() 27 | args = parser.parse_args() 28 | 29 | dictionary = Dictionary.load(args.dict) if args.dict is not None else None 30 | dataset = data_utils.load_indexed_dataset( 31 | args.input, 32 | dictionary, 33 | dataset_impl=args.dataset_impl, 34 | default='lazy', 35 | ) 36 | 37 | for tensor_line in dataset: 38 | if dictionary is None: 39 | line = ' '.join([str(int(x)) for x in tensor_line]) 40 | else: 41 | line = dictionary.string(tensor_line) 42 | 43 | print(line) 44 | 45 | 46 | if __name__
== '__main__': 47 | main() 48 | -------------------------------------------------------------------------------- /bart/scripts/sacrebleu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ $# -ne 4 ]; then 4 | echo "usage: $0 TESTSET SRCLANG TGTLANG GEN" 5 | exit 1 6 | fi 7 | 8 | TESTSET=$1 9 | SRCLANG=$2 10 | TGTLANG=$3 11 | 12 | GEN=$4 13 | 14 | if ! command -v sacremoses &> /dev/null 15 | then 16 | echo "sacremoses could not be found, please install with: pip install sacremoses" 17 | exit 1 18 | fi 19 | 20 | grep ^H $GEN \ 21 | | sed 's/^H\-//' \ 22 | | sort -n -k 1 \ 23 | | cut -f 3 \ 24 | | sacremoses detokenize \ 25 | > $GEN.sorted.detok 26 | 27 | sacrebleu --test-set $TESTSET --language-pair "${SRCLANG}-${TGTLANG}" < $GEN.sorted.detok 28 | -------------------------------------------------------------------------------- /bart/scripts/shard_docs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Split a large file into shards while respecting document boundaries. Documents 8 | should be separated by a single empty line.
9 | """ 10 | 11 | import argparse 12 | import contextlib 13 | 14 | 15 | def main(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('input') 18 | parser.add_argument('--num-shards', type=int) 19 | args = parser.parse_args() 20 | 21 | assert args.num_shards is not None and args.num_shards > 1 22 | 23 | with open(args.input, 'r', encoding='utf-8') as h: 24 | with contextlib.ExitStack() as stack: 25 | outputs = [ 26 | stack.enter_context(open(args.input + ".shard" + str(i), "w", encoding="utf-8")) 27 | for i in range(args.num_shards) 28 | ] 29 | 30 | doc = [] 31 | first_doc = [True]*args.num_shards 32 | 33 | def output_doc(i): 34 | if not first_doc[i]: 35 | outputs[i].write("\n") 36 | first_doc[i] = False 37 | for line in doc: 38 | outputs[i].write(line) 39 | doc.clear() 40 | 41 | num_docs = 0 42 | for line in h: 43 | if line.strip() == "": # empty line indicates new document 44 | output_doc(num_docs % args.num_shards) 45 | num_docs += 1 46 | else: 47 | doc.append(line) 48 | output_doc(num_docs % args.num_shards) 49 | 50 | 51 | if __name__ == '__main__': 52 | main() 53 | -------------------------------------------------------------------------------- /bart/scripts/spm_decode.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | from __future__ import absolute_import, division, print_function, unicode_literals 9 | 10 | import argparse 11 | 12 | import sentencepiece as spm 13 | 14 | 15 | def main(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--model", required=True, 18 | help="sentencepiece model to use for decoding") 19 | parser.add_argument("--input", required=True, help="input file to decode") 20 | parser.add_argument("--input_format", choices=["piece", "id"], default="piece") 21 | args = parser.parse_args() 22 | 23 | sp = spm.SentencePieceProcessor() 24 | sp.Load(args.model) 25 | 26 | if args.input_format == "piece": 27 | def decode(l): 28 | return "".join(sp.DecodePieces(l)) 29 | elif args.input_format == "id": 30 | def decode(l): 31 | return "".join(sp.DecodeIds(l)) 32 | else: 33 | raise NotImplementedError 34 | 35 | def tok2int(tok): 36 | # remap reference-side (represented as <>) to 0 37 | return int(tok) if tok != "<>" else 0 38 | 39 | with open(args.input, "r", encoding="utf-8") as h: 40 | for line in h: 41 | if args.input_format == "id": 42 | print(decode(list(map(tok2int, line.rstrip().split())))) 43 | elif args.input_format == "piece": 44 | print(decode(line.rstrip().split())) 45 | 46 | if __name__ == "__main__": 47 | main() 48 | -------------------------------------------------------------------------------- /bart/scripts/spm_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | from __future__ import absolute_import, division, print_function, unicode_literals 9 | 10 | import sys 11 | 12 | import sentencepiece as spm 13 | 14 | 15 | if __name__ == "__main__": 16 | spm.SentencePieceTrainer.Train(" ".join(sys.argv[1:]))  # forward all CLI arguments as one space-joined option string; NOTE(review): arguments containing spaces are not re-quoted 17 | -------------------------------------------------------------------------------- /bart/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Legacy entry point. Use fairseq_cli/train.py or fairseq-train instead. 8 | """ 9 | 10 | from fairseq_cli.train import cli_main 11 | 12 | 13 | if __name__ == '__main__': 14 | cli_main() 15 | -------------------------------------------------------------------------------- /pgn/.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /pgn/.gitignore: -------------------------------------------------------------------------------- 1 | .idea/* -------------------------------------------------------------------------------- /pgn/README.md: -------------------------------------------------------------------------------- 1 | # Pointer-Generator for AMI Meeting Dataset 2 | This code is based on [OpenNMT](https://github.com/OpenNMT/OpenNMT-py). 3 | 4 | ## Requirements 5 | * We use Conda python 3.7 and strongly recommend that you create a new environment. 6 | * `conda create -n pgn python=3.7`. 7 | * Run the following command. 8 | * `pip install -r requirements.txt`. 9 | 10 | ## Data 11 | You can get data [here](https://drive.google.com/drive/folders/1VjuDhFxiv8t590-s_4HTX6BqOHhU89Ci?usp=sharing).
Put them under the **data/** directory. 12 | 13 | ## Reproduce Results 14 | Follow the steps below to reproduce the best results reported in our paper. 15 | 16 | ### Download checkpoints 17 | Download the checkpoint [here](https://drive.google.com/drive/folders/1A9xjS_x1yhjwmtmOlyur16LCvOoOprwL?usp=sharing) and save it as **ckpt/ami.pt**. 18 | 19 | ### Translate 20 | * `sh ./scripts/infer.sh` 21 | 22 | ### Test ROUGE score 23 | * Point `pyrouge.Rouge155()` to your local ROUGE installation path. 24 | * The output format is `>> ROUGE(1/2/L): xx.xx-xx.xx-xx.xx`. 25 | * `python test_rouge.py -c summaries/ami.txt` 26 | 27 | ### ROUGE score 28 | You should get the following ROUGE scores. 29 | 30 | ||ROUGE-1| ROUGE-2 | ROUGE-L | 31 | | :---: | :---: | :---: | :---: | 32 | | AMI | 50.91 | 17.75 | 24.59 | 33 | 34 | ## From Scratch 35 | ### Preprocess 36 | Run the following commands: 37 | * `sh ./scripts/preprocess.sh` 38 | * `sh ./scripts/embedding.sh` 39 | 40 | ### Train 41 | Run the following command: 42 | * `sh ./scripts/train.sh` 43 | 44 | ### Translate 45 | Run the following command: 46 | * `sh ./scripts/infer.sh` 47 | * Set the **model_name** parameter before running the command above. 
-------------------------------------------------------------------------------- /pgn/ckpt/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/ckpt/.gitignore -------------------------------------------------------------------------------- /pgn/data/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/data/.gitignore -------------------------------------------------------------------------------- /pgn/logs/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/logs/.gitignore -------------------------------------------------------------------------------- /pgn/onmt/__init__.py: -------------------------------------------------------------------------------- 1 | """ Main entry point of the ONMT library """ 2 | from __future__ import division, print_function 3 | 4 | import onmt.inputters 5 | import onmt.encoders 6 | import onmt.decoders 7 | import onmt.models 8 | import onmt.utils 9 | import onmt.modules 10 | from onmt.trainer import Trainer 11 | import sys 12 | import onmt.utils.optimizers 13 | onmt.utils.optimizers.Optim = onmt.utils.optimizers.Optimizer  # legacy alias: Optimizer is also reachable under the old name Optim 14 | sys.modules["onmt.Optim"] = onmt.utils.optimizers  # register the optimizers module under the old module name onmt.Optim (presumably so older pickled checkpoints still load — confirm) 15 | 16 | # For Flake 17 | __all__ = [onmt.inputters, onmt.encoders, onmt.decoders, onmt.models, 18 | onmt.utils, onmt.modules, "Trainer"]  # NOTE(review): __all__ should list only name strings; the module objects here make a star-import of onmt raise TypeError 19 | 20 | __version__ = "1.0.0.rc2" 21 | -------------------------------------------------------------------------------- /pgn/onmt/__pycache__/__init__.cpython-37.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/__pycache__/model_builder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/__pycache__/model_builder.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/__pycache__/opts.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/__pycache__/opts.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/__pycache__/train_single.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/__pycache__/train_single.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/__pycache__/trainer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/__pycache__/trainer.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/bin/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/bin/__init__.py -------------------------------------------------------------------------------- /pgn/onmt/bin/__pycache__/__init__.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/bin/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/bin/__pycache__/preprocess.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/bin/__pycache__/preprocess.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/bin/__pycache__/train.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/bin/__pycache__/train.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/bin/__pycache__/translate.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/bin/__pycache__/translate.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/bin/translate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import unicode_literals 5 | from itertools import repeat 6 | 7 | from onmt.utils.logging import init_logger 8 | from onmt.utils.misc import split_corpus 9 | from onmt.translate.translator import build_translator 10 | 11 | import onmt.opts as opts 12 | from onmt.utils.parse import ArgumentParser 13 | 14 | 15 | def translate(opt):  # Translate opt.src in shards, pairing each source shard with the matching opt.tgt shard (or None when no target is given). 16 | ArgumentParser.validate_translate_opts(opt) 17 | logger = init_logger(opt.log_file) 18 | 19
| translator = build_translator(opt, report_score=True) 20 | src_shards = split_corpus(opt.src, opt.shard_size) 21 | tgt_shards = split_corpus(opt.tgt, opt.shard_size) \ 22 | if opt.tgt is not None else repeat(None)  # no target file: pair every source shard with None 23 | shard_pairs = zip(src_shards, tgt_shards) 24 | 25 | for i, (src_shard, tgt_shard) in enumerate(shard_pairs): 26 | logger.info("Translating shard %d." % i) 27 | translator.translate( 28 | src=src_shard, 29 | tgt=tgt_shard, 30 | src_dir=opt.src_dir, 31 | batch_size=opt.batch_size, 32 | batch_type=opt.batch_type, 33 | attn_debug=opt.attn_debug 34 | ) 35 | 36 | 37 | def _get_parser():  # build the translate.py argument parser from the config and translate option groups 38 | parser = ArgumentParser(description='translate.py') 39 | 40 | opts.config_opts(parser) 41 | opts.translate_opts(parser) 42 | return parser 43 | 44 | 45 | def main(): 46 | parser = _get_parser() 47 | 48 | opt = parser.parse_args() 49 | translate(opt) 50 | 51 | 52 | if __name__ == "__main__": 53 | main() 54 | -------------------------------------------------------------------------------- /pgn/onmt/decoders/__init__.py: -------------------------------------------------------------------------------- 1 | """Module defining decoders.""" 2 | from onmt.decoders.decoder import DecoderBase, InputFeedRNNDecoder, \ 3 | StdRNNDecoder 4 | from onmt.decoders.transformer import TransformerDecoder 5 | from onmt.decoders.cnn_decoder import CNNDecoder 6 | 7 | 8 | str2dec = {"rnn": StdRNNDecoder, "ifrnn": InputFeedRNNDecoder,  # decoder-type name -> decoder class 9 | "cnn": CNNDecoder, "transformer": TransformerDecoder} 10 | 11 | __all__ = ["DecoderBase", "TransformerDecoder", "StdRNNDecoder", "CNNDecoder", 12 | "InputFeedRNNDecoder", "str2dec"] 13 | -------------------------------------------------------------------------------- /pgn/onmt/decoders/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/decoders/__pycache__/__init__.cpython-37.pyc
-------------------------------------------------------------------------------- /pgn/onmt/decoders/__pycache__/cnn_decoder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/decoders/__pycache__/cnn_decoder.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/decoders/__pycache__/decoder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/decoders/__pycache__/decoder.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/decoders/__pycache__/ensemble.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/decoders/__pycache__/ensemble.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/decoders/__pycache__/transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/decoders/__pycache__/transformer.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | """Module defining encoders.""" 2 | from onmt.encoders.encoder import EncoderBase 3 | from onmt.encoders.transformer import TransformerEncoder 4 | from onmt.encoders.rnn_encoder import RNNEncoder 5 | from onmt.encoders.cnn_encoder import CNNEncoder 6 | from onmt.encoders.mean_encoder import MeanEncoder 7 | from
onmt.encoders.audio_encoder import AudioEncoder 8 | from onmt.encoders.image_encoder import ImageEncoder 9 | 10 | 11 | str2enc = {"rnn": RNNEncoder, "brnn": RNNEncoder, "cnn": CNNEncoder,  # encoder-type name -> encoder class ("rnn" and "brnn" both map to RNNEncoder) 12 | "transformer": TransformerEncoder, "img": ImageEncoder, 13 | "audio": AudioEncoder, "mean": MeanEncoder} 14 | 15 | __all__ = ["EncoderBase", "TransformerEncoder", "RNNEncoder", "CNNEncoder", 16 | "MeanEncoder", "str2enc"]  # NOTE(review): AudioEncoder and ImageEncoder are imported and registered in str2enc but not exported here 17 | -------------------------------------------------------------------------------- /pgn/onmt/encoders/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/encoders/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/encoders/__pycache__/audio_encoder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/encoders/__pycache__/audio_encoder.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/encoders/__pycache__/cnn_encoder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/encoders/__pycache__/cnn_encoder.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/encoders/__pycache__/encoder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/encoders/__pycache__/encoder.cpython-37.pyc --------------------------------------------------------------------------------
/pgn/onmt/encoders/__pycache__/image_encoder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/encoders/__pycache__/image_encoder.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/encoders/__pycache__/mean_encoder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/encoders/__pycache__/mean_encoder.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/encoders/__pycache__/rnn_encoder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/encoders/__pycache__/rnn_encoder.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/encoders/__pycache__/transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/encoders/__pycache__/transformer.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/encoders/encoder.py: -------------------------------------------------------------------------------- 1 | """Base class for encoders and generic multi encoders.""" 2 | 3 | import torch.nn as nn 4 | 5 | from onmt.utils.misc import aeq 6 | 7 | 8 | class EncoderBase(nn.Module): 9 | """ 10 | Base encoder class. Specifies the interface used by different encoder types 11 | and required by :class:`onmt.models.NMTModel`. 12 | 13 | ..
mermaid:: 14 | 15 | graph BT 16 | A[Input] 17 | subgraph RNN 18 | C[Pos 1] 19 | D[Pos 2] 20 | E[Pos N] 21 | end 22 | F[Memory_Bank] 23 | G[Final] 24 | A-->C 25 | A-->D 26 | A-->E 27 | C-->F 28 | D-->F 29 | E-->F 30 | E-->G 31 | """ 32 | 33 | @classmethod 34 | def from_opt(cls, opt, embeddings=None): 35 | raise NotImplementedError 36 | 37 | def _check_args(self, src, lengths=None, hidden=None):  # check (via aeq) that lengths, if given, has one entry per src.size(1) column; hidden is unused here 38 | n_batch = src.size(1) 39 | if lengths is not None: 40 | n_batch_, = lengths.size() 41 | aeq(n_batch, n_batch_) 42 | 43 | def forward(self, src, lengths=None): 44 | """ 45 | Args: 46 | src (LongTensor): 47 | padded sequences of sparse indices ``(src_len, batch, nfeat)`` 48 | lengths (LongTensor): length of each sequence ``(batch,)`` 49 | 50 | 51 | Returns: 52 | (FloatTensor, FloatTensor): 53 | 54 | * final encoder state, used to initialize decoder 55 | * memory bank for attention, ``(src_len, batch, hidden)`` 56 | """ 57 | 58 | raise NotImplementedError 59 | -------------------------------------------------------------------------------- /pgn/onmt/encoders/mean_encoder.py: -------------------------------------------------------------------------------- 1 | """Define a minimal encoder.""" 2 | from onmt.encoders.encoder import EncoderBase 3 | from onmt.utils.misc import sequence_mask 4 | import torch 5 | 6 | 7 | class MeanEncoder(EncoderBase): 8 | """A trivial non-recurrent encoder. Simply applies mean pooling.
9 | 10 | Args: 11 | num_layers (int): number of replicated layers 12 | embeddings (onmt.modules.Embeddings): embedding module to use 13 | """ 14 | 15 | def __init__(self, num_layers, embeddings): 16 | super(MeanEncoder, self).__init__() 17 | self.num_layers = num_layers 18 | self.embeddings = embeddings 19 | 20 | @classmethod 21 | def from_opt(cls, opt, embeddings): 22 | """Alternate constructor.""" 23 | return cls( 24 | opt.enc_layers, 25 | embeddings) 26 | 27 | def forward(self, src, lengths=None): 28 | """See :func:`EncoderBase.forward()`""" 29 | self._check_args(src, lengths) 30 | 31 | emb = self.embeddings(src) 32 | _, batch, emb_dim = emb.size()  # emb is unpacked as (src_len, batch, emb_dim) 33 | 34 | if lengths is not None: 35 | # we avoid padding while mean pooling 36 | mask = sequence_mask(lengths).float() 37 | mask = mask / lengths.unsqueeze(1).float()  # scale each row by 1/length so the bmm below averages only the unpadded positions 38 | mean = torch.bmm(mask.unsqueeze(1), emb.transpose(0, 1)).squeeze(1) 39 | else: 40 | mean = emb.mean(0) 41 | 42 | mean = mean.expand(self.num_layers, batch, emb_dim)  # same pooled vector replicated for every layer 43 | memory_bank = emb 44 | encoder_final = (mean, mean)  # same tensor in both slots (presumably the (h, c) pair an LSTM-style decoder expects — confirm) 45 | return encoder_final, memory_bank, lengths  # NOTE(review): returns three values, while EncoderBase.forward's docstring documents a two-element return 46 | -------------------------------------------------------------------------------- /pgn/onmt/inputters/__init__.py: -------------------------------------------------------------------------------- 1 | """Module defining inputters. 2 | 3 | Inputters implement the logic of transforming raw data to vectorized inputs, 4 | e.g., from a line of text to a sequence of embeddings.
5 | """ 6 | from onmt.inputters.inputter import \ 7 | load_old_vocab, get_fields, OrderedIterator, \ 8 | build_vocab, old_style_vocab, filter_example 9 | from onmt.inputters.dataset_base import Dataset 10 | from onmt.inputters.text_dataset import text_sort_key, TextDataReader 11 | from onmt.inputters.image_dataset import img_sort_key, ImageDataReader 12 | from onmt.inputters.audio_dataset import audio_sort_key, AudioDataReader 13 | from onmt.inputters.vec_dataset import vec_sort_key, VecDataReader 14 | from onmt.inputters.datareader_base import DataReaderBase 15 | 16 | 17 | str2reader = {  # data-type name -> reader class 18 | "text": TextDataReader, "img": ImageDataReader, "audio": AudioDataReader, 19 | "vec": VecDataReader} 20 | str2sortkey = {  # data-type name -> sort-key function 21 | 'text': text_sort_key, 'img': img_sort_key, 'audio': audio_sort_key, 22 | 'vec': vec_sort_key} 23 | 24 | 25 | __all__ = ['Dataset', 'load_old_vocab', 'get_fields', 'DataReaderBase', 26 | 'filter_example', 'old_style_vocab', 27 | 'build_vocab', 'OrderedIterator', 28 | 'text_sort_key', 'img_sort_key', 'audio_sort_key', 'vec_sort_key', 29 | 'TextDataReader', 'ImageDataReader', 'AudioDataReader', 30 | 'VecDataReader'] 31 | -------------------------------------------------------------------------------- /pgn/onmt/inputters/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/inputters/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/inputters/__pycache__/audio_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/inputters/__pycache__/audio_dataset.cpython-37.pyc --------------------------------------------------------------------------------
/pgn/onmt/inputters/__pycache__/datareader_base.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/inputters/__pycache__/datareader_base.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/inputters/__pycache__/dataset_base.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/inputters/__pycache__/dataset_base.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/inputters/__pycache__/image_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/inputters/__pycache__/image_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/inputters/__pycache__/inputter.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/inputters/__pycache__/inputter.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/inputters/__pycache__/text_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/inputters/__pycache__/text_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/inputters/__pycache__/vec_dataset.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/inputters/__pycache__/vec_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/inputters/datareader_base.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | 4 | # several data readers need optional dependencies. There's no 5 | # appropriate builtin exception 6 | class MissingDependencyException(Exception): 7 | pass 8 | 9 | 10 | class DataReaderBase(object): 11 | """Read data from file system and yield as dicts. 12 | 13 | Raises: 14 | onmt.inputters.datareader_base.MissingDependencyException: A number 15 | of DataReaders need specific additional packages. 16 | If any are missing, this will be raised. 17 | """ 18 | 19 | @classmethod 20 | def from_opt(cls, opt): 21 | """Alternative constructor. 22 | 23 | Args: 24 | opt (argparse.Namespace): The parsed arguments. 25 | """ 26 | 27 | return cls() 28 | 29 | @classmethod 30 | def _read_file(cls, path): 31 | """Line-by-line read a file as bytes.""" 32 | with open(path, "rb") as f:  # binary mode: yielded lines are undecoded bytes 33 | for line in f: 34 | yield line 35 | 36 | @staticmethod 37 | def _raise_missing_dep(*missing_deps):  # always raises, never returns 38 | """Raise missing dep exception with standard error message.""" 39 | raise MissingDependencyException( 40 | "Could not create reader.
Be sure to install " 41 | "the following dependencies: " + ", ".join(missing_deps)) 42 | 43 | def read(self, data, side, src_dir): 44 | """Read data from file system and yield as dicts.""" 45 | raise NotImplementedError()  # subclasses must override 46 | -------------------------------------------------------------------------------- /pgn/onmt/models/__init__.py: -------------------------------------------------------------------------------- 1 | """Module defining models.""" 2 | from onmt.models.model_saver import build_model_saver, ModelSaver 3 | from onmt.models.model import NMTModel 4 | 5 | __all__ = ["build_model_saver", "ModelSaver", "NMTModel"] 6 | -------------------------------------------------------------------------------- /pgn/onmt/models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/models/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/models/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/models/__pycache__/model_saver.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/models/__pycache__/model_saver.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/models/__pycache__/sru.cpython-37.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/models/__pycache__/sru.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/models/__pycache__/stacked_rnn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/models/__pycache__/stacked_rnn.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/modules/__init__.py: -------------------------------------------------------------------------------- 1 | """ Attention, normalization, embedding and copy-generation modules """ 2 | from onmt.modules.util_class import Elementwise 3 | from onmt.modules.gate import context_gate_factory, ContextGate 4 | from onmt.modules.global_attention import GlobalAttention 5 | from onmt.modules.conv_multi_step_attention import ConvMultiStepAttention 6 | from onmt.modules.copy_generator import CopyGenerator, CopyGeneratorLoss, \ 7 | CopyGeneratorLossCompute 8 | from onmt.modules.multi_headed_attn import MultiHeadedAttention 9 | from onmt.modules.embeddings import Embeddings, PositionalEncoding, \ 10 | VecEmbedding 11 | from onmt.modules.weight_norm import WeightNormConv2d 12 | from onmt.modules.average_attn import AverageAttention 13 | 14 | __all__ = ["Elementwise", "context_gate_factory", "ContextGate", 15 | "GlobalAttention", "ConvMultiStepAttention", "CopyGenerator", 16 | "CopyGeneratorLoss", "CopyGeneratorLossCompute", 17 | "MultiHeadedAttention", "Embeddings", "PositionalEncoding", 18 | "WeightNormConv2d", "AverageAttention", "VecEmbedding"] 19 | -------------------------------------------------------------------------------- /pgn/onmt/modules/__pycache__/__init__.cpython-37.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/modules/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/modules/__pycache__/average_attn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/modules/__pycache__/average_attn.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/modules/__pycache__/conv_multi_step_attention.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/modules/__pycache__/conv_multi_step_attention.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/modules/__pycache__/copy_generator.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/modules/__pycache__/copy_generator.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/modules/__pycache__/embeddings.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/modules/__pycache__/embeddings.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/modules/__pycache__/gate.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/modules/__pycache__/gate.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/modules/__pycache__/global_attention.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/modules/__pycache__/global_attention.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/modules/__pycache__/multi_headed_attn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/modules/__pycache__/multi_headed_attn.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/modules/__pycache__/position_ffn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/modules/__pycache__/position_ffn.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/modules/__pycache__/sparse_activations.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/modules/__pycache__/sparse_activations.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/modules/__pycache__/sparse_losses.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/modules/__pycache__/sparse_losses.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/modules/__pycache__/util_class.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/modules/__pycache__/util_class.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/modules/__pycache__/weight_norm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/modules/__pycache__/weight_norm.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/modules/position_ffn.py: -------------------------------------------------------------------------------- 1 | """Position feed-forward network from "Attention is All You Need".""" 2 | 3 | import torch.nn as nn 4 | 5 | 6 | class PositionwiseFeedForward(nn.Module): 7 | """ A two-layer Feed-Forward-Network with residual layer norm. 8 | 9 | Args: 10 | d_model (int): the size of input for the first-layer of the FFN. 11 | d_ff (int): the hidden layer size of the second-layer 12 | of the FNN. 13 | dropout (float): dropout probability in :math:`[0, 1)`. 14 | """ 15 | 16 | def __init__(self, d_model, d_ff, dropout=0.1): 17 | super(PositionwiseFeedForward, self).__init__() 18 | self.w_1 = nn.Linear(d_model, d_ff) 19 | self.w_2 = nn.Linear(d_ff, d_model) 20 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 21 | self.dropout_1 = nn.Dropout(dropout) 22 | self.relu = nn.ReLU() 23 | self.dropout_2 = nn.Dropout(dropout) 24 | 25 | def forward(self, x): 26 | """Layer definition. 
27 | 28 | Args: 29 | x: ``(batch_size, input_len, model_dim)`` 30 | 31 | Returns: 32 | (FloatTensor): Output ``(batch_size, input_len, model_dim)``. 33 | """ 34 | 35 | inter = self.dropout_1(self.relu(self.w_1(self.layer_norm(x)))) 36 | output = self.dropout_2(self.w_2(inter)) 37 | return output + x 38 | 39 | def update_dropout(self, dropout): 40 | self.dropout_1.p = dropout 41 | self.dropout_2.p = dropout 42 | -------------------------------------------------------------------------------- /pgn/onmt/modules/structured_attention.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.cuda 4 | 5 | 6 | class MatrixTree(nn.Module): 7 | """Implementation of the matrix-tree theorem for computing marginals 8 | of non-projective dependency parsing. This attention layer is used 9 | in the paper "Learning Structured Text Representations" 10 | :cite:`DBLP:journals/corr/LiuL17d`. 11 | """ 12 | 13 | def __init__(self, eps=1e-5): 14 | self.eps = eps 15 | super(MatrixTree, self).__init__() 16 | 17 | def forward(self, input): 18 | laplacian = input.exp() + self.eps 19 | output = input.clone() 20 | for b in range(input.size(0)): 21 | lap = laplacian[b].masked_fill( 22 | torch.eye(input.size(1), device=input.device).ne(0), 0) 23 | lap = -lap + torch.diag(lap.sum(0)) 24 | # store roots on diagonal 25 | lap[0] = input[b].diag().exp() 26 | inv_laplacian = lap.inverse() 27 | 28 | factor = inv_laplacian.diag().unsqueeze(1)\ 29 | .expand_as(input[b]).transpose(0, 1) 30 | term1 = input[b].exp().mul(factor).clone() 31 | term2 = input[b].exp().mul(inv_laplacian.transpose(0, 1)).clone() 32 | term1[:, 0] = 0 33 | term2[0] = 0 34 | output[b] = term1 - term2 35 | roots_output = input[b].diag().exp().mul( 36 | inv_laplacian.transpose(0, 1)[0]) 37 | output[b] = output[b] + torch.diag(roots_output) 38 | return output 39 | -------------------------------------------------------------------------------- 
/pgn/onmt/modules/util_class.py: -------------------------------------------------------------------------------- 1 | """ Misc classes """ 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | # At the moment this class is only used by embeddings.Embeddings look-up tables 7 | class Elementwise(nn.ModuleList): 8 | """ 9 | A simple network container. 10 | Parameters are a list of modules. 11 | Inputs are a 3d Tensor whose last dimension is the same length 12 | as the list. 13 | Outputs are the result of applying modules to inputs elementwise. 14 | An optional merge parameter allows the outputs to be reduced to a 15 | single Tensor. 16 | """ 17 | 18 | def __init__(self, merge=None, *args): 19 | assert merge in [None, 'first', 'concat', 'sum', 'mlp'] 20 | self.merge = merge 21 | super(Elementwise, self).__init__(*args) 22 | 23 | def forward(self, inputs): 24 | inputs_ = [feat.squeeze(2) for feat in inputs.split(1, dim=2)] 25 | assert len(self) == len(inputs_) 26 | outputs = [f(x) for f, x in zip(self, inputs_)] 27 | if self.merge == 'first': 28 | return outputs[0] 29 | elif self.merge == 'concat' or self.merge == 'mlp': 30 | return torch.cat(outputs, 2) 31 | elif self.merge == 'sum': 32 | return sum(outputs) 33 | else: 34 | return outputs 35 | 36 | 37 | class Cast(nn.Module): 38 | """ 39 | Basic layer that casts its input to a specific data type. The same tensor 40 | is returned if the data type is already correct. 
41 | """ 42 | 43 | def __init__(self, dtype): 44 | super(Cast, self).__init__() 45 | self._dtype = dtype 46 | 47 | def forward(self, x): 48 | return x.to(self._dtype) 49 | -------------------------------------------------------------------------------- /pgn/onmt/translate/__init__.py: -------------------------------------------------------------------------------- 1 | """ Modules for translation """ 2 | from onmt.translate.translator import Translator 3 | from onmt.translate.translation import Translation, TranslationBuilder 4 | from onmt.translate.beam import Beam, GNMTGlobalScorer 5 | from onmt.translate.beam_search import BeamSearch 6 | from onmt.translate.decode_strategy import DecodeStrategy 7 | from onmt.translate.random_sampling import RandomSampling 8 | from onmt.translate.penalties import PenaltyBuilder 9 | from onmt.translate.translation_server import TranslationServer, \ 10 | ServerModelError 11 | 12 | __all__ = ['Translator', 'Translation', 'Beam', 'BeamSearch', 13 | 'GNMTGlobalScorer', 'TranslationBuilder', 14 | 'PenaltyBuilder', 'TranslationServer', 'ServerModelError', 15 | "DecodeStrategy", "RandomSampling"] 16 | -------------------------------------------------------------------------------- /pgn/onmt/translate/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/translate/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/translate/__pycache__/beam.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/translate/__pycache__/beam.cpython-37.pyc -------------------------------------------------------------------------------- 
/pgn/onmt/translate/__pycache__/beam_search.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/translate/__pycache__/beam_search.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/translate/__pycache__/decode_strategy.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/translate/__pycache__/decode_strategy.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/translate/__pycache__/penalties.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/translate/__pycache__/penalties.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/translate/__pycache__/random_sampling.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/translate/__pycache__/random_sampling.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/translate/__pycache__/translation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/translate/__pycache__/translation.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/translate/__pycache__/translation_server.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/translate/__pycache__/translation_server.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/translate/__pycache__/translator.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/translate/__pycache__/translator.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/translate/process_zh.py: -------------------------------------------------------------------------------- 1 | from pyhanlp import HanLP 2 | from snownlp import SnowNLP 3 | import pkuseg 4 | 5 | 6 | # Chinese segmentation 7 | def zh_segmentator(line): 8 | return " ".join(pkuseg.pkuseg().cut(line)) 9 | 10 | 11 | # Chinese simplify -> Chinese traditional standard 12 | def zh_traditional_standard(line): 13 | return HanLP.convertToTraditionalChinese(line) 14 | 15 | 16 | # Chinese simplify -> Chinese traditional (HongKong) 17 | def zh_traditional_hk(line): 18 | return HanLP.s2hk(line) 19 | 20 | 21 | # Chinese simplify -> Chinese traditional (Taiwan) 22 | def zh_traditional_tw(line): 23 | return HanLP.s2tw(line) 24 | 25 | 26 | # Chinese traditional -> Chinese simplify (v1) 27 | def zh_simplify(line): 28 | return HanLP.convertToSimplifiedChinese(line) 29 | 30 | 31 | # Chinese traditional -> Chinese simplify (v2) 32 | def zh_simplify_v2(line): 33 | return SnowNLP(line).han 34 | -------------------------------------------------------------------------------- /pgn/onmt/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Module defining various utilities.""" 2 | from onmt.utils.misc import split_corpus, aeq, use_gpu, 
set_random_seed 3 | from onmt.utils.report_manager import ReportMgr, build_report_manager 4 | from onmt.utils.statistics import Statistics 5 | from onmt.utils.optimizers import MultipleOptimizer, \ 6 | Optimizer, AdaFactor 7 | from onmt.utils.earlystopping import EarlyStopping, scorers_from_opts 8 | 9 | __all__ = ["split_corpus", "aeq", "use_gpu", "set_random_seed", "ReportMgr", 10 | "build_report_manager", "Statistics", 11 | "MultipleOptimizer", "Optimizer", "AdaFactor", "EarlyStopping", 12 | "scorers_from_opts"] 13 | -------------------------------------------------------------------------------- /pgn/onmt/utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/utils/__pycache__/cnn_factory.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/utils/__pycache__/cnn_factory.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/utils/__pycache__/distributed.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/utils/__pycache__/distributed.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/utils/__pycache__/earlystopping.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/utils/__pycache__/earlystopping.cpython-37.pyc 
-------------------------------------------------------------------------------- /pgn/onmt/utils/__pycache__/logging.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/utils/__pycache__/logging.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/utils/__pycache__/loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/utils/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/utils/__pycache__/misc.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/utils/__pycache__/misc.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/utils/__pycache__/optimizers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/utils/__pycache__/optimizers.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/utils/__pycache__/parse.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/utils/__pycache__/parse.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/utils/__pycache__/report_manager.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/utils/__pycache__/report_manager.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/utils/__pycache__/rnn_factory.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/utils/__pycache__/rnn_factory.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/utils/__pycache__/statistics.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pgn/onmt/utils/__pycache__/statistics.cpython-37.pyc -------------------------------------------------------------------------------- /pgn/onmt/utils/logging.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import absolute_import 3 | 4 | import logging 5 | 6 | logger = logging.getLogger() 7 | 8 | 9 | def init_logger(log_file=None, log_file_level=logging.NOTSET): 10 | log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s") 11 | logger = logging.getLogger() 12 | logger.setLevel(logging.INFO) 13 | 14 | console_handler = logging.StreamHandler() 15 | console_handler.setFormatter(log_format) 16 | logger.handlers = [console_handler] 17 | 18 | if log_file and log_file != '': 19 | file_handler = logging.FileHandler(log_file) 20 | file_handler.setLevel(log_file_level) 21 | file_handler.setFormatter(log_format) 22 | logger.addHandler(file_handler) 23 | 24 | return logger 25 | -------------------------------------------------------------------------------- 
/pgn/onmt/utils/rnn_factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | RNN tools 3 | """ 4 | import torch.nn as nn 5 | import onmt.models 6 | 7 | 8 | def rnn_factory(rnn_type, **kwargs): 9 | """ rnn factory, Use pytorch version when available. """ 10 | no_pack_padded_seq = False 11 | if rnn_type == "SRU": 12 | # SRU doesn't support PackedSequence. 13 | no_pack_padded_seq = True 14 | rnn = onmt.models.sru.SRU(**kwargs) 15 | else: 16 | rnn = getattr(nn, rnn_type)(**kwargs) 17 | return rnn, no_pack_padded_seq 18 | -------------------------------------------------------------------------------- /pgn/preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from onmt.bin.preprocess import main 3 | 4 | 5 | if __name__ == "__main__": 6 | main() 7 | -------------------------------------------------------------------------------- /pgn/requirements.txt: -------------------------------------------------------------------------------- 1 | six 2 | tqdm==4.30.0 3 | torch==1.2 4 | torchtext==0.4.0 5 | future 6 | configargparse 7 | tensorboard==1.14 8 | flask 9 | pyonmttok==1.*;platform_system=='Linux' 10 | PyYAML -------------------------------------------------------------------------------- /pgn/scripts/embedding.sh: -------------------------------------------------------------------------------- 1 | dict_file=data/ami.vocab.pt 2 | output_file=data/embeddings 3 | 4 | python embeddings_to_torch.py -emb_file_both data/glove.6B.300d.txt \ 5 | -dict_file ${dict_file} \ 6 | -output_file ${output_file} -------------------------------------------------------------------------------- /pgn/scripts/infer.sh: -------------------------------------------------------------------------------- 1 | cuda=0 2 | model_name=ami 3 | data_prefix=data 4 | model=ckpt/${model_name}.pt 5 | output=summaries/${model_name}.txt 6 | 7 | CUDA_VISIBLE_DEVICES=${cuda} python translate.py 
-batch_size 1 \ 8 | -src ${data_prefix}/test.txt.src \ 9 | -tgt ${data_prefix}/test.txt.tgt \ 10 | -beam_size 10 \ 11 | -share_vocab \ 12 | -dynamic_dict \ 13 | -replace_unk \ 14 | -model ${model} \ 15 | -output ${output} \ 16 | -block_ngram_repeat 3 \ 17 | -gpu 0 \ 18 | -min_length 280 \ 19 | -max_length 450 20 | 21 | sed -i 's/ <\/t>//g' ${output} 22 | sed -i 's/ //g' ${output} -------------------------------------------------------------------------------- /pgn/scripts/preprocess.sh: -------------------------------------------------------------------------------- 1 | data=data 2 | 3 | python preprocess.py -train_src ${data}/train.txt.src \ 4 | -train_tgt ${data}/train.txt.tgt \ 5 | -valid_src ${data}/valid.txt.src \ 6 | -valid_tgt ${data}/valid.txt.tgt \ 7 | -shard_size 100 \ 8 | -src_vocab_size 10000 \ 9 | -tgt_vocab_size 10000 \ 10 | -src_words_min_frequency 2 \ 11 | -tgt_words_min_frequency 2 \ 12 | -src_seq_length 15000 \ 13 | -src_seq_length_trunc 11000 \ 14 | -tgt_seq_length 700 \ 15 | -tgt_seq_length_trunc 700 \ 16 | -save_data ${data}/ami \ 17 | -dynamic_dict \ 18 | -share_vocab \ 19 | -lower \ 20 | -overwrite -------------------------------------------------------------------------------- /pgn/scripts/train.sh: -------------------------------------------------------------------------------- 1 | cuda=0 2 | name=ami 3 | data_prefix=data 4 | data=${data_prefix}/ami 5 | 6 | CUDA_VISIBLE_DEVICES=${cuda} python train.py -save_model ckpt/${name} \ 7 | -data ${data} \ 8 | -batch_size 1 \ 9 | -learning_rate 0.001 \ 10 | -share_embeddings \ 11 | -pre_word_vecs_enc ${data_prefix}/embeddings.enc.pt \ 12 | -pre_word_vecs_dec ${data_prefix}/embeddings.dec.pt \ 13 | -save_checkpoint_steps 100 \ 14 | -seed 777 \ 15 | -optim adam \ 16 | -max_grad_norm 2 \ 17 | -report_every 100 \ 18 | -word_vec_size 300 \ 19 | -encoder_type rnn \ 20 | -rnn_size 200 \ 21 | -gpu_ranks 0 \ 22 | -valid_steps 100 \ 23 | -copy_attn \ 24 | -reuse_copy_attn \ 25 | -log_file logs/${name}.txt \ 
26 | -save_config logs/${name}.txt 27 | 28 | 29 | CUDA_VISIBLE_DEVICES=${cuda} python train.py -save_model ckpt/${name} \ 30 | -data ${data} \ 31 | -batch_size 1 \ 32 | -learning_rate 0.001 \ 33 | -share_embeddings \ 34 | -pre_word_vecs_enc ${data_prefix}/embeddings.enc.pt \ 35 | -pre_word_vecs_dec ${data_prefix}/embeddings.dec.pt \ 36 | -save_checkpoint_steps 100 \ 37 | -seed 777 \ 38 | -optim adam \ 39 | -max_grad_norm 2 \ 40 | -report_every 100 \ 41 | -word_vec_size 300 \ 42 | -encoder_type rnn \ 43 | -rnn_size 200 \ 44 | -gpu_ranks 0 \ 45 | -valid_steps 100 \ 46 | -copy_attn \ 47 | -reuse_copy_attn \ 48 | -log_file logs/${name}.txt -------------------------------------------------------------------------------- /pgn/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from onmt.bin.train import main 3 | 4 | 5 | if __name__ == "__main__": 6 | main() 7 | -------------------------------------------------------------------------------- /pgn/translate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from onmt.bin.translate import main 3 | 4 | 5 | if __name__ == "__main__": 6 | main() 7 | -------------------------------------------------------------------------------- /pic/main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcfcode/PLM_annotator/2018eb887b2b2badd138ecfc33831582f29cd37b/pic/main.png --------------------------------------------------------------------------------