├── .github └── PULL_REQUEST_TEMPLATE.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── data ├── LICENSE-FACTCC.txt ├── LICENSE-XSUM-HAL.txt ├── create_dataset.sh ├── generate_consolidated_dataset.py ├── generate_dataset_factcc.py ├── generate_dataset_frank.py ├── generate_dataset_maynez.py ├── generate_dataset_qags.py ├── mappings │ ├── dev_ids.json │ └── test_ids.json ├── preprocess │ ├── 00_convert_edge_level_dataset.py │ ├── 01_generate_sents_for_graphs.py │ ├── 02.1.1_align_amrs.py │ ├── 02.1_get_amr_data.py │ ├── 02.2_create_amr_json_eval.py │ ├── 02.2_create_amr_json_nopar.py │ ├── 02.2_create_amr_json_nopar_edge_level.py │ ├── 02_generate_amrs.py │ ├── 3_check_files.py │ ├── LexRank.py │ ├── augmentation_ops.py │ ├── create_envs_preprocess.sh │ ├── preprocess_evaluate.py │ ├── process_dataset_for_edge_model.sh │ ├── process_dataset_for_model.sh │ ├── requirements-preprocess.txt │ ├── scripts │ │ └── predict_amrs_from_plaintext.py │ ├── update_envs_preprocess.sh │ └── utils.py └── utils.py ├── images ├── example.png └── factgraph.png ├── requirements.txt └── src ├── download_pretrained_adapters.sh ├── download_trained_models.sh ├── evaluate.py ├── evaluate.sh ├── main.py ├── main_edgelevel.py ├── models.py ├── predict.sh ├── predict_edgelevel.sh ├── preprocess.py ├── train.sh ├── train_edgelevel.sh ├── transformers ├── __init__.py ├── activations.py ├── activations_tf.py ├── benchmark │ ├── __init__.py │ ├── benchmark.py │ ├── benchmark_args.py │ ├── benchmark_args_tf.py │ ├── benchmark_args_utils.py │ ├── benchmark_tf.py │ └── benchmark_utils.py ├── commands │ ├── __init__.py │ ├── add_new_model.py │ ├── convert.py │ ├── download.py │ ├── env.py │ ├── lfs.py │ ├── run.py │ ├── serving.py │ ├── train.py │ ├── transformers_cli.py │ └── user.py ├── configuration_utils.py ├── convert_graph_to_onnx.py ├── convert_pytorch_checkpoint_to_tf2.py ├── convert_slow_tokenizer.py ├── convert_slow_tokenizers_checkpoints_to_fast.py ├── 
convert_tf_hub_seq_to_seq_bert_to_pytorch.py ├── data │ ├── __init__.py │ ├── data_collator.py │ ├── datasets │ │ ├── __init__.py │ │ ├── glue.py │ │ ├── language_modeling.py │ │ └── squad.py │ ├── metrics │ │ ├── __init__.py │ │ └── squad_metrics.py │ ├── processors │ │ ├── __init__.py │ │ ├── glue.py │ │ ├── squad.py │ │ ├── utils.py │ │ └── xnli.py │ └── test_generation_utils.py ├── debug_utils.py ├── dependency_versions_check.py ├── dependency_versions_table.py ├── feature_extraction_sequence_utils.py ├── feature_extraction_utils.py ├── file_utils.py ├── generation_beam_search.py ├── generation_logits_process.py ├── generation_stopping_criteria.py ├── generation_tf_utils.py ├── generation_utils.py ├── hf_api.py ├── hf_argparser.py ├── image_utils.py ├── integrations.py ├── modelcard.py ├── modeling_flax_outputs.py ├── modeling_flax_pytorch_utils.py ├── modeling_flax_utils.py ├── modeling_outputs.py ├── modeling_tf_outputs.py ├── modeling_tf_pytorch_utils.py ├── modeling_tf_utils.py ├── modeling_utils.py ├── models │ ├── __init__.py │ ├── albert │ │ ├── __init__.py │ │ ├── configuration_albert.py │ │ ├── convert_albert_original_tf_checkpoint_to_pytorch.py │ │ ├── modeling_albert.py │ │ ├── modeling_tf_albert.py │ │ ├── tokenization_albert.py │ │ └── tokenization_albert_fast.py │ ├── auto │ │ ├── __init__.py │ │ ├── auto_factory.py │ │ ├── configuration_auto.py │ │ ├── feature_extraction_auto.py │ │ ├── modeling_auto.py │ │ ├── modeling_flax_auto.py │ │ ├── modeling_tf_auto.py │ │ └── tokenization_auto.py │ ├── bart │ │ ├── __init__.py │ │ ├── configuration_bart.py │ │ ├── convert_bart_original_pytorch_checkpoint_to_pytorch.py │ │ ├── modeling_bart.py │ │ ├── modeling_tf_bart.py │ │ ├── tokenization_bart.py │ │ └── tokenization_bart_fast.py │ ├── barthez │ │ ├── __init__.py │ │ ├── tokenization_barthez.py │ │ └── tokenization_barthez_fast.py │ ├── bert │ │ ├── __init__.py │ │ ├── configuration_bert.py │ │ ├── convert_bert_original_tf2_checkpoint_to_pytorch.py │ │ 
├── convert_bert_original_tf_checkpoint_to_pytorch.py │ │ ├── convert_bert_pytorch_checkpoint_to_original_tf.py │ │ ├── modeling_bert.py │ │ ├── modeling_flax_bert.py │ │ ├── modeling_tf_bert.py │ │ ├── tokenization_bert.py │ │ └── tokenization_bert_fast.py │ ├── bert_adapter │ │ ├── __init__.py │ │ ├── configuration_bert.py │ │ ├── convert_bert_original_tf2_checkpoint_to_pytorch.py │ │ ├── convert_bert_original_tf_checkpoint_to_pytorch.py │ │ ├── convert_bert_pytorch_checkpoint_to_original_tf.py │ │ ├── modeling_bert.py │ │ ├── modeling_flax_bert.py │ │ ├── modeling_tf_bert.py │ │ ├── tokenization_bert.py │ │ └── tokenization_bert_fast.py │ ├── bert_generation │ │ ├── __init__.py │ │ ├── configuration_bert_generation.py │ │ ├── modeling_bert_generation.py │ │ └── tokenization_bert_generation.py │ ├── bert_japanese │ │ ├── __init__.py │ │ └── tokenization_bert_japanese.py │ ├── bertweet │ │ ├── __init__.py │ │ └── tokenization_bertweet.py │ ├── big_bird │ │ ├── __init__.py │ │ ├── configuration_big_bird.py │ │ ├── convert_bigbird_original_tf_checkpoint_to_pytorch.py │ │ ├── modeling_big_bird.py │ │ ├── tokenization_big_bird.py │ │ └── tokenization_big_bird_fast.py │ ├── bigbird_pegasus │ │ ├── __init__.py │ │ ├── configuration_bigbird_pegasus.py │ │ ├── convert_bigbird_pegasus_tf_to_pytorch.py │ │ └── modeling_bigbird_pegasus.py │ ├── blenderbot │ │ ├── __init__.py │ │ ├── configuration_blenderbot.py │ │ ├── convert_blenderbot_original_pytorch_checkpoint_to_pytorch.py │ │ ├── modeling_blenderbot.py │ │ ├── modeling_tf_blenderbot.py │ │ └── tokenization_blenderbot.py │ ├── blenderbot_small │ │ ├── __init__.py │ │ ├── configuration_blenderbot_small.py │ │ ├── modeling_blenderbot_small.py │ │ ├── modeling_tf_blenderbot_small.py │ │ ├── tokenization_blenderbot_small.py │ │ └── tokenization_blenderbot_small_fast.py │ ├── camembert │ │ ├── __init__.py │ │ ├── configuration_camembert.py │ │ ├── modeling_camembert.py │ │ ├── modeling_tf_camembert.py │ │ ├── 
tokenization_camembert.py │ │ └── tokenization_camembert_fast.py │ ├── clip │ │ ├── __init__.py │ │ ├── configuration_clip.py │ │ ├── convert_clip_original_pytorch_to_hf.py │ │ ├── feature_extraction_clip.py │ │ ├── modeling_clip.py │ │ ├── processing_clip.py │ │ ├── tokenization_clip.py │ │ └── tokenization_clip_fast.py │ ├── convbert │ │ ├── __init__.py │ │ ├── configuration_convbert.py │ │ ├── convert_convbert_original_tf1_checkpoint_to_pytorch_and_tf2.py │ │ ├── modeling_convbert.py │ │ ├── modeling_tf_convbert.py │ │ ├── tokenization_convbert.py │ │ └── tokenization_convbert_fast.py │ ├── cpm │ │ ├── __init__.py │ │ └── tokenization_cpm.py │ ├── ctrl │ │ ├── __init__.py │ │ ├── configuration_ctrl.py │ │ ├── modeling_ctrl.py │ │ ├── modeling_tf_ctrl.py │ │ └── tokenization_ctrl.py │ ├── deberta │ │ ├── __init__.py │ │ ├── configuration_deberta.py │ │ ├── modeling_deberta.py │ │ ├── tokenization_deberta.py │ │ └── tokenization_deberta_fast.py │ ├── deberta_v2 │ │ ├── __init__.py │ │ ├── configuration_deberta_v2.py │ │ ├── modeling_deberta_v2.py │ │ └── tokenization_deberta_v2.py │ ├── deit │ │ ├── __init__.py │ │ ├── configuration_deit.py │ │ ├── convert_deit_timm_to_pytorch.py │ │ ├── feature_extraction_deit.py │ │ └── modeling_deit.py │ ├── dialogpt │ │ ├── __init__.py │ │ └── convert_dialogpt_original_pytorch_checkpoint_to_pytorch.py │ ├── distilbert │ │ ├── __init__.py │ │ ├── configuration_distilbert.py │ │ ├── modeling_distilbert.py │ │ ├── modeling_tf_distilbert.py │ │ ├── tokenization_distilbert.py │ │ └── tokenization_distilbert_fast.py │ ├── dpr │ │ ├── __init__.py │ │ ├── configuration_dpr.py │ │ ├── convert_dpr_original_checkpoint_to_pytorch.py │ │ ├── modeling_dpr.py │ │ ├── modeling_tf_dpr.py │ │ ├── tokenization_dpr.py │ │ └── tokenization_dpr_fast.py │ ├── electra │ │ ├── __init__.py │ │ ├── configuration_electra.py │ │ ├── convert_electra_original_tf_checkpoint_to_pytorch.py │ │ ├── modeling_electra.py │ │ ├── modeling_flax_electra.py │ │ ├── 
modeling_tf_electra.py │ │ ├── tokenization_electra.py │ │ └── tokenization_electra_fast.py │ ├── electra_adapter │ │ ├── __init__.py │ │ └── modeling_electra.py │ ├── encoder_decoder │ │ ├── __init__.py │ │ ├── configuration_encoder_decoder.py │ │ └── modeling_encoder_decoder.py │ ├── flaubert │ │ ├── __init__.py │ │ ├── configuration_flaubert.py │ │ ├── modeling_flaubert.py │ │ ├── modeling_tf_flaubert.py │ │ └── tokenization_flaubert.py │ ├── fsmt │ │ ├── __init__.py │ │ ├── configuration_fsmt.py │ │ ├── convert_fsmt_original_pytorch_checkpoint_to_pytorch.py │ │ ├── modeling_fsmt.py │ │ └── tokenization_fsmt.py │ ├── funnel │ │ ├── __init__.py │ │ ├── configuration_funnel.py │ │ ├── convert_funnel_original_tf_checkpoint_to_pytorch.py │ │ ├── modeling_funnel.py │ │ ├── modeling_tf_funnel.py │ │ ├── tokenization_funnel.py │ │ └── tokenization_funnel_fast.py │ ├── gpt2 │ │ ├── __init__.py │ │ ├── configuration_gpt2.py │ │ ├── convert_gpt2_original_tf_checkpoint_to_pytorch.py │ │ ├── modeling_gpt2.py │ │ ├── modeling_tf_gpt2.py │ │ ├── tokenization_gpt2.py │ │ └── tokenization_gpt2_fast.py │ ├── gpt_neo │ │ ├── __init__.py │ │ ├── configuration_gpt_neo.py │ │ ├── convert_gpt_neo_mesh_tf_to_pytorch.py │ │ └── modeling_gpt_neo.py │ ├── herbert │ │ ├── __init__.py │ │ ├── tokenization_herbert.py │ │ └── tokenization_herbert_fast.py │ ├── ibert │ │ ├── __init__.py │ │ ├── configuration_ibert.py │ │ ├── modeling_ibert.py │ │ └── quant_modules.py │ ├── layoutlm │ │ ├── __init__.py │ │ ├── configuration_layoutlm.py │ │ ├── modeling_layoutlm.py │ │ ├── modeling_tf_layoutlm.py │ │ ├── tokenization_layoutlm.py │ │ └── tokenization_layoutlm_fast.py │ ├── led │ │ ├── __init__.py │ │ ├── configuration_led.py │ │ ├── modeling_led.py │ │ ├── modeling_tf_led.py │ │ ├── tokenization_led.py │ │ └── tokenization_led_fast.py │ ├── longformer │ │ ├── __init__.py │ │ ├── configuration_longformer.py │ │ ├── convert_longformer_original_pytorch_lightning_to_pytorch.py │ │ ├── 
modeling_longformer.py │ │ ├── modeling_tf_longformer.py │ │ ├── tokenization_longformer.py │ │ └── tokenization_longformer_fast.py │ ├── luke │ │ ├── __init__.py │ │ ├── configuration_luke.py │ │ ├── convert_luke_original_pytorch_checkpoint_to_pytorch.py │ │ ├── modeling_luke.py │ │ └── tokenization_luke.py │ ├── lxmert │ │ ├── __init__.py │ │ ├── configuration_lxmert.py │ │ ├── convert_lxmert_original_tf_checkpoint_to_pytorch.py │ │ ├── modeling_lxmert.py │ │ ├── modeling_tf_lxmert.py │ │ ├── tokenization_lxmert.py │ │ └── tokenization_lxmert_fast.py │ ├── m2m_100 │ │ ├── __init__.py │ │ ├── configuration_m2m_100.py │ │ ├── convert_m2m100_original_checkpoint_to_pytorch.py │ │ ├── modeling_m2m_100.py │ │ └── tokenization_m2m_100.py │ ├── marian │ │ ├── __init__.py │ │ ├── configuration_marian.py │ │ ├── convert_marian_tatoeba_to_pytorch.py │ │ ├── convert_marian_to_pytorch.py │ │ ├── modeling_marian.py │ │ ├── modeling_tf_marian.py │ │ └── tokenization_marian.py │ ├── mbart │ │ ├── __init__.py │ │ ├── configuration_mbart.py │ │ ├── convert_mbart_original_checkpoint_to_pytorch.py │ │ ├── modeling_mbart.py │ │ ├── modeling_tf_mbart.py │ │ ├── tokenization_mbart.py │ │ ├── tokenization_mbart50.py │ │ ├── tokenization_mbart50_fast.py │ │ └── tokenization_mbart_fast.py │ ├── megatron_bert │ │ ├── __init__.py │ │ ├── configuration_megatron_bert.py │ │ ├── convert_megatron_bert_checkpoint.py │ │ └── modeling_megatron_bert.py │ ├── mmbt │ │ ├── __init__.py │ │ ├── configuration_mmbt.py │ │ └── modeling_mmbt.py │ ├── mobilebert │ │ ├── __init__.py │ │ ├── configuration_mobilebert.py │ │ ├── convert_mobilebert_original_tf_checkpoint_to_pytorch.py │ │ ├── modeling_mobilebert.py │ │ ├── modeling_tf_mobilebert.py │ │ ├── tokenization_mobilebert.py │ │ └── tokenization_mobilebert_fast.py │ ├── mpnet │ │ ├── __init__.py │ │ ├── configuration_mpnet.py │ │ ├── modeling_mpnet.py │ │ ├── modeling_tf_mpnet.py │ │ ├── tokenization_mpnet.py │ │ └── tokenization_mpnet_fast.py │ ├── mt5 
│ │ ├── __init__.py │ │ ├── configuration_mt5.py │ │ ├── modeling_mt5.py │ │ └── modeling_tf_mt5.py │ ├── openai │ │ ├── __init__.py │ │ ├── configuration_openai.py │ │ ├── convert_openai_original_tf_checkpoint_to_pytorch.py │ │ ├── modeling_openai.py │ │ ├── modeling_tf_openai.py │ │ ├── tokenization_openai.py │ │ └── tokenization_openai_fast.py │ ├── pegasus │ │ ├── __init__.py │ │ ├── configuration_pegasus.py │ │ ├── convert_pegasus_tf_to_pytorch.py │ │ ├── modeling_pegasus.py │ │ ├── modeling_tf_pegasus.py │ │ ├── tokenization_pegasus.py │ │ └── tokenization_pegasus_fast.py │ ├── phobert │ │ ├── __init__.py │ │ └── tokenization_phobert.py │ ├── prophetnet │ │ ├── __init__.py │ │ ├── configuration_prophetnet.py │ │ ├── convert_prophetnet_original_pytorch_checkpoint_to_pytorch.py │ │ ├── modeling_prophetnet.py │ │ └── tokenization_prophetnet.py │ ├── rag │ │ ├── __init__.py │ │ ├── configuration_rag.py │ │ ├── modeling_rag.py │ │ ├── modeling_tf_rag.py │ │ ├── retrieval_rag.py │ │ └── tokenization_rag.py │ ├── reformer │ │ ├── __init__.py │ │ ├── configuration_reformer.py │ │ ├── convert_reformer_trax_checkpoint_to_pytorch.py │ │ ├── modeling_reformer.py │ │ ├── tokenization_reformer.py │ │ └── tokenization_reformer_fast.py │ ├── retribert │ │ ├── __init__.py │ │ ├── configuration_retribert.py │ │ ├── modeling_retribert.py │ │ ├── tokenization_retribert.py │ │ └── tokenization_retribert_fast.py │ ├── roberta │ │ ├── __init__.py │ │ ├── configuration_roberta.py │ │ ├── convert_roberta_original_pytorch_checkpoint_to_pytorch.py │ │ ├── modeling_flax_roberta.py │ │ ├── modeling_roberta.py │ │ ├── modeling_tf_roberta.py │ │ ├── tokenization_roberta.py │ │ └── tokenization_roberta_fast.py │ ├── speech_to_text │ │ ├── __init__.py │ │ ├── configuration_speech_to_text.py │ │ ├── convert_s2t_fairseq_to_tfms.py │ │ ├── feature_extraction_speech_to_text.py │ │ ├── modeling_speech_to_text.py │ │ ├── processing_speech_to_text.py │ │ └── tokenization_speech_to_text.py │ ├── 
squeezebert │ │ ├── __init__.py │ │ ├── configuration_squeezebert.py │ │ ├── modeling_squeezebert.py │ │ ├── tokenization_squeezebert.py │ │ └── tokenization_squeezebert_fast.py │ ├── t5 │ │ ├── __init__.py │ │ ├── configuration_t5.py │ │ ├── convert_t5_original_tf_checkpoint_to_pytorch.py │ │ ├── modeling_t5.py │ │ ├── modeling_tf_t5.py │ │ ├── tokenization_t5.py │ │ └── tokenization_t5_fast.py │ ├── tapas │ │ ├── __init__.py │ │ ├── configuration_tapas.py │ │ ├── convert_tapas_original_tf_checkpoint_to_pytorch.py │ │ ├── modeling_tapas.py │ │ └── tokenization_tapas.py │ ├── transfo_xl │ │ ├── __init__.py │ │ ├── configuration_transfo_xl.py │ │ ├── convert_transfo_xl_original_tf_checkpoint_to_pytorch.py │ │ ├── modeling_tf_transfo_xl.py │ │ ├── modeling_tf_transfo_xl_utilities.py │ │ ├── modeling_transfo_xl.py │ │ ├── modeling_transfo_xl_utilities.py │ │ └── tokenization_transfo_xl.py │ ├── vit │ │ ├── __init__.py │ │ ├── configuration_vit.py │ │ ├── convert_vit_timm_to_pytorch.py │ │ ├── feature_extraction_vit.py │ │ └── modeling_vit.py │ ├── wav2vec2 │ │ ├── __init__.py │ │ ├── configuration_wav2vec2.py │ │ ├── convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py │ │ ├── feature_extraction_wav2vec2.py │ │ ├── modeling_wav2vec2.py │ │ ├── processing_wav2vec2.py │ │ └── tokenization_wav2vec2.py │ ├── xlm │ │ ├── __init__.py │ │ ├── configuration_xlm.py │ │ ├── convert_xlm_original_pytorch_checkpoint_to_pytorch.py │ │ ├── modeling_tf_xlm.py │ │ ├── modeling_xlm.py │ │ └── tokenization_xlm.py │ ├── xlm_prophetnet │ │ ├── __init__.py │ │ ├── configuration_xlm_prophetnet.py │ │ ├── modeling_xlm_prophetnet.py │ │ └── tokenization_xlm_prophetnet.py │ ├── xlm_roberta │ │ ├── __init__.py │ │ ├── configuration_xlm_roberta.py │ │ ├── modeling_tf_xlm_roberta.py │ │ ├── modeling_xlm_roberta.py │ │ ├── tokenization_xlm_roberta.py │ │ └── tokenization_xlm_roberta_fast.py │ └── xlnet │ │ ├── __init__.py │ │ ├── configuration_xlnet.py │ │ ├── 
convert_xlnet_original_tf_checkpoint_to_pytorch.py │ │ ├── modeling_tf_xlnet.py │ │ ├── modeling_xlnet.py │ │ ├── tokenization_xlnet.py │ │ └── tokenization_xlnet_fast.py ├── optimization.py ├── optimization_tf.py ├── pipelines │ ├── __init__.py │ ├── automatic_speech_recognition.py │ ├── base.py │ ├── conversational.py │ ├── feature_extraction.py │ ├── fill_mask.py │ ├── image_classification.py │ ├── question_answering.py │ ├── table_question_answering.py │ ├── text2text_generation.py │ ├── text_classification.py │ ├── text_generation.py │ ├── token_classification.py │ └── zero_shot_classification.py ├── sagemaker │ ├── __init__.py │ ├── trainer_sm.py │ └── training_args_sm.py ├── testing_utils.py ├── tokenization_utils.py ├── tokenization_utils_base.py ├── tokenization_utils_fast.py ├── trainer.py ├── trainer_callback.py ├── trainer_pt_utils.py ├── trainer_seq2seq.py ├── trainer_tf.py ├── trainer_utils.py ├── training_args.py ├── training_args_seq2seq.py ├── training_args_tf.py └── utils │ ├── __init__.py │ ├── dummy_flax_objects.py │ ├── dummy_pt_objects.py │ ├── dummy_sentencepiece_and_speech_objects.py │ ├── dummy_sentencepiece_and_tokenizers_objects.py │ ├── dummy_sentencepiece_objects.py │ ├── dummy_speech_objects.py │ ├── dummy_tf_objects.py │ ├── dummy_tokenizers_objects.py │ ├── dummy_vision_objects.py │ ├── hp_naming.py │ ├── imagenet_classes.py │ ├── logging.py │ ├── model_parallel_utils.py │ ├── modeling_auto_mapping.py │ ├── notebook.py │ ├── sentencepiece_model_pb2.py │ └── versions.py ├── utils.py └── utils_evaluate.py /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. 
2 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /data/LICENSE-FACTCC.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019, Salesforce.com, Inc. 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 8 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 9 | 10 | * Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 11 | 12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /data/create_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | FILE_CNNDM_QAGS=https://github.com/W4ngatang/qags/raw/master/data/mturk_cnndm.jsonl 4 | FILE_XSUM_QAGS=https://github.com/W4ngatang/qags/raw/master/data/mturk_xsum.jsonl 5 | 6 | FILE_FACTCC=https://storage.googleapis.com/sfr-factcc-data-research/unpaired_annotated_data.tar.gz 7 | 8 | FILE_MAYNEZ=https://github.com/google-research-datasets/xsum_hallucination_annotations/raw/master/hallucination_annotations_xsum_summaries.csv 9 | 10 | FILE_FRANK=https://github.com/artidoro/frank/raw/main/data/human_annotations_sentence.json 11 | 12 | rm -rf qags 13 | mkdir -p qags 14 | wget ${FILE_CNNDM_QAGS} -P qags 15 | wget ${FILE_XSUM_QAGS} -P qags 16 | 17 | rm -rf factcc 18 | mkdir -p factcc 19 | wget ${FILE_FACTCC} -P factcc 20 | tar zxvf factcc/unpaired_annotated_data.tar.gz -C factcc/ 21 | 22 | rm -rf maynez 23 | mkdir -p maynez 24 | wget ${FILE_MAYNEZ} -P maynez 25 | 26 | rm -rf frank 27 | mkdir -p frank 28 | wget ${FILE_FRANK} -P frank 29 | 30 | 31 | python generate_dataset_qags.py qags/mturk_cnndm.jsonl qags/mturk_xsum.jsonl qags/processed.json 32 | 33 | python generate_dataset_factcc.py factcc/unpaired_annotated_data/ factcc/processed.json 34 | 35 | python generate_dataset_maynez.py maynez/hallucination_annotations_xsum_summaries.csv maynez/processed.json 36 | 
37 | python generate_dataset_frank.py frank/human_annotations_sentence.json frank/processed.json 38 | 39 | rm -rf processed_dataset 40 | mkdir -p processed_dataset 41 | 42 | python generate_consolidated_dataset.py 43 | -------------------------------------------------------------------------------- /data/generate_dataset_factcc.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | from datasets import load_dataset 4 | import sys 5 | 6 | folder_dataset = sys.argv[1] 7 | 8 | file_val = folder_dataset + '/val/data-dev.jsonl' 9 | file_test = folder_dataset + '/test/data-dev.jsonl' 10 | file_processed = sys.argv[2] 11 | 12 | dataset_hf = load_dataset("cnn_dailymail", "3.0.0") 13 | 14 | data_hf = {} 15 | for d in dataset_hf["validation"]: 16 | data_hf[d['id']] = d['article'] 17 | for d in dataset_hf["test"]: 18 | data_hf[d['id']] = d['article'] 19 | 20 | 21 | with open(file_val) as fd: 22 | dataset_val = [json.loads(line) for line in fd] 23 | 24 | with open(file_test) as fd: 25 | dataset_test = [json.loads(line) for line in fd] 26 | 27 | dataset = dataset_val + dataset_test 28 | 29 | for idx, example in tqdm(enumerate(dataset)): 30 | try: 31 | example['text'] = data_hf[example['id'].split("/")[-1]] 32 | except: 33 | example['text'] = data_hf[example['id'].split("-")[-1]] 34 | example['id_order'] = idx 35 | example['id'] = example['id'] + "_" + str(idx) 36 | 37 | with open(file_processed, 'w') as fd: 38 | for example in dataset: 39 | fd.write(json.dumps(example, ensure_ascii=False) + "\n") 40 | 41 | -------------------------------------------------------------------------------- /data/generate_dataset_frank.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | import sys 3 | import unidecode 4 | from utils import * 5 | 6 | def get_dict(data): 7 | dict_data = {} 8 | duplicates = {} 9 | count = 0 10 | for d in data: 11 | key = 
d['summary'] + '-sum-art-' + d['article'] 12 | if key in dict_data: 13 | count += 1 14 | key_dup = d['summary'] + '-sum-art-' + d['article'] 15 | duplicates[key_dup] = d 16 | dict_data[key] = d 17 | return dict_data, duplicates 18 | 19 | 20 | file = sys.argv[1] 21 | file_output = sys.argv[2] 22 | 23 | with open(file) as file: 24 | frank = json.load(file) 25 | 26 | dataset = load_dataset("cnn_dailymail", '3.0.0') 27 | hash_cnndm = set() 28 | 29 | for d in dataset['train']: 30 | hash_cnndm.add(d['id']) 31 | for d in dataset['test']: 32 | hash_cnndm.add(d['id']) 33 | for d in dataset['validation']: 34 | hash_cnndm.add(d['id']) 35 | 36 | dataset = load_dataset("xsum") 37 | hash_xsum = set() 38 | 39 | for d in dataset['train']: 40 | hash_xsum.add(d['id']) 41 | for d in dataset['test']: 42 | hash_xsum.add(d['id']) 43 | for d in dataset['validation']: 44 | hash_xsum.add(d['id']) 45 | 46 | 47 | labels_cont_cnndm = [] 48 | labels_cont_xsum = [] 49 | new_data = [] 50 | labels_cont = [] 51 | 52 | for idx, example in enumerate(frank): 53 | assert len(example['summary_sentences']) == len(example['summary_sentences_annotations']) 54 | 55 | for s, sa in zip(example['summary_sentences'], example['summary_sentences_annotations']): 56 | new_example = {} 57 | new_example['summary'] = unidecode.unidecode(s) 58 | new_example['article'] = unidecode.unidecode(example['article']) 59 | new_example['id'] = example['hash'] + "_" + example['model_name'] + "_" + str(len(new_data)) 60 | new_example['id_order'] = len(new_data) 61 | new_example['model_name'] = example['model_name'] 62 | new_example['split'] = example['split'] 63 | 64 | new_example['source'] = 'frank' 65 | 66 | if example['hash'] in hash_cnndm: 67 | new_example['domain'] = 'cnndm' 68 | elif example['hash'] in hash_xsum: 69 | new_example['domain'] = 'xsum' 70 | else: 71 | print(d['hash']) 72 | print('error') 73 | exit() 74 | 75 | labels = [] 76 | for k, v in sa.items(): 77 | if v[0] == 'NoE': 78 | new_label = 'CORRECT' 79 | else: 
80 | new_label = 'INCORRECT' 81 | labels.append(new_label) 82 | 83 | label = most_common(labels) 84 | 85 | new_example['label'] = label 86 | 87 | new_data.append(new_example) 88 | labels_cont.append(label) 89 | if new_example['domain'] == 'cnndm': 90 | labels_cont_cnndm.append(new_example['label']) 91 | if new_example['domain'] == 'xsum': 92 | labels_cont_xsum.append(new_example['label']) 93 | 94 | 95 | save_data(new_data, file_output) 96 | -------------------------------------------------------------------------------- /data/generate_dataset_maynez.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | import sys 3 | from datasets import load_dataset 4 | 5 | 6 | def consolidate_annotations(data): 7 | 8 | dataset_hf = load_dataset("xsum") 9 | 10 | data_hf = {} 11 | for example in dataset_hf["validation"]: 12 | data_hf[example['id']] = example['document'] 13 | for example in dataset_hf["test"]: 14 | data_hf[example['id']] = example['document'] 15 | 16 | dict_data = {} 17 | 18 | for example in data[1:]: 19 | id_sum = str(example[0]) + "_" + str(example[1]) 20 | summary = example[2] 21 | hal = example[3] 22 | if hal == 'extrinsic' or hal == 'intrinsic': 23 | hal = 'INCORRECT' 24 | else: 25 | hal = 'CORRECT' 26 | if id_sum not in dict_data: 27 | dict_data[id_sum] = {} 28 | dict_data[id_sum]['summary'] = summary 29 | if 'labels' not in dict_data[id_sum]: 30 | dict_data[id_sum]['labels'] = [] 31 | 32 | dict_data[id_sum]['labels'].append(hal) 33 | label = most_common(dict_data[id_sum]['labels']) 34 | dict_data[id_sum]['label'] = label 35 | dict_data[id_sum]['article'] = data_hf[id_sum.split("_")[0]] 36 | 37 | consolidated_data = [] 38 | for idx, k in enumerate(dict_data.keys()): 39 | del dict_data[k]['labels'] 40 | dict_data[k]['id_order'] = idx 41 | dict_data[k]['id'] = k 42 | consolidated_data.append(dict_data[k]) 43 | 44 | return consolidated_data 45 | 46 | 47 | file = sys.argv[1] 48 | output_file = sys.argv[2] 
49 | data = read_csv(file) 50 | data_maynez = consolidate_annotations(data) 51 | 52 | save_data(data_maynez, output_file) 53 | -------------------------------------------------------------------------------- /data/generate_dataset_qags.py: -------------------------------------------------------------------------------- 1 | import unidecode 2 | from utils import * 3 | import sys 4 | 5 | file = sys.argv[1] 6 | cnndm = load_source_docs(file) 7 | 8 | processed_data = [] 9 | labels_cont = [] 10 | id_order = 0 11 | 12 | for idx, example in enumerate(cnndm): 13 | 14 | for s in example['summary_sentences']: 15 | new_example = {} 16 | new_example['summary'] = unidecode.unidecode(s['sentence']) 17 | new_example['article'] = unidecode.unidecode(example['article']) 18 | new_example['domain'] = 'cnndm' 19 | new_example['source'] = 'qags' 20 | new_example['id'] = 'cnndm_qags_' + str(id_order) 21 | new_example['id_order'] = id_order 22 | id_order += 1 23 | 24 | reps = [r['response'] for r in s['responses']] 25 | label = most_common(reps) 26 | if label == 'yes': 27 | new_example['label'] = 'CORRECT' 28 | elif label == 'no': 29 | new_example['label'] = 'INCORRECT' 30 | else: 31 | print('error') 32 | exit() 33 | processed_data.append(new_example) 34 | labels_cont.append(new_example['label']) 35 | 36 | 37 | file = sys.argv[2] 38 | xsum = load_source_docs(file) 39 | 40 | 41 | labels_cont = [] 42 | for idx, example in enumerate(xsum): 43 | 44 | for s in example['summary_sentences']: 45 | new_example = {} 46 | new_example['summary'] = unidecode.unidecode(s['sentence']) 47 | new_example['article'] = unidecode.unidecode(example['article']) 48 | new_example['domain'] = 'xsum' 49 | new_example['source'] = 'qags' 50 | new_example['id'] = 'xsum_qags_' + str(id_order) 51 | new_example['id_order'] = id_order 52 | id_order += 1 53 | 54 | reps = [r['response'] for r in s['responses']] 55 | label = most_common(reps) 56 | if label == 'yes': 57 | new_example['label'] = 'CORRECT' 58 | elif label == 
'no': 59 | new_example['label'] = 'INCORRECT' 60 | else: 61 | print('error') 62 | exit() 63 | processed_data.append(new_example) 64 | labels_cont.append(new_example['label']) 65 | 66 | 67 | output_file = sys.argv[3] 68 | save_data(processed_data, output_file) 69 | -------------------------------------------------------------------------------- /data/preprocess/00_convert_edge_level_dataset.py: -------------------------------------------------------------------------------- 1 | import pandas 2 | import math 3 | import numbers 4 | import json 5 | import sys 6 | from datasets import load_dataset 7 | 8 | labels = {0: 'INCORRECT', 1: 'CORRECT'} 9 | 10 | 11 | def save_data(data, output_file): 12 | with open(output_file, "w", encoding="utf-8") as fd: 13 | for example in data: 14 | example = dict(example) 15 | fd.write(json.dumps(example, ensure_ascii=False) + "\n") 16 | 17 | 18 | dataset_hf = load_dataset("xsum") 19 | data_hf = {} 20 | for example in dataset_hf["test"]: 21 | data_hf[example['id']] = example['document'] 22 | 23 | def process_csv(csv_file, json_file): 24 | data = pandas.read_csv(csv_file, delimiter='\t') 25 | 26 | processed_data = [] 27 | 28 | for index, row in data.iterrows(): 29 | example = {} 30 | example['id'] = row['id'] 31 | example['summary'] = row['context'] 32 | example['label'] = labels[row['sentlabel']] 33 | 34 | id_xsum = example['id'].split("_")[0] 35 | 36 | example['article'] = data_hf[id_xsum] 37 | 38 | hals = [] 39 | for i in range(20): 40 | idx_words = row['dep_idx'+str(i)] 41 | 42 | if isinstance(idx_words, numbers.Number) and math.isnan(idx_words): 43 | continue 44 | 45 | words = row['dep_words' + str(i)] 46 | label = int(row['dep_label' + str(i)]) 47 | rel = row['dep'+ str(i)] 48 | 49 | hals.append((idx_words, words, label, rel)) 50 | 51 | example['hallucinations'] = json.dumps(hals) 52 | processed_data.append(example) 53 | save_data(processed_data, json_file) 54 | 55 | 56 | folder_edge_level_data = sys.argv[1] 57 | 
folder_preprocessed_edge_level_data = sys.argv[2] 58 | 59 | process_csv(folder_edge_level_data + '/train.tsv', folder_preprocessed_edge_level_data + '/train.json') 60 | process_csv(folder_edge_level_data + '/test.tsv', folder_preprocessed_edge_level_data + '/test.json') 61 | 62 | 63 | -------------------------------------------------------------------------------- /data/preprocess/02.1_get_amr_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | from tqdm import tqdm 4 | from concurrent.futures import ProcessPoolExecutor, as_completed 5 | 6 | def load_source_docs(file_path): 7 | with open(file_path, encoding="utf-8") as f: 8 | data = [json.loads(line) for line in f] 9 | return data 10 | 11 | 12 | def save_data(data, output_file): 13 | with open(output_file, "w", encoding="utf-8") as fd: 14 | for example in data: 15 | fd.write(example + "\n") 16 | 17 | 18 | import sys 19 | import os 20 | file = sys.argv[1] 21 | 22 | cnndm = load_source_docs(file) 23 | 24 | summaries = [] 25 | docs = {} 26 | negative = [] 27 | ids_neg = [] 28 | ids_pos = [] 29 | claims = [] 30 | idx_claims = [] 31 | for idx, d in enumerate(tqdm(cnndm)): 32 | for idx_sent, sent in enumerate(json.loads(d['sentences'])): 33 | sent = sent[0] 34 | if sent not in docs: 35 | docs[sent] = [] 36 | docs[sent].append((d['id'], idx_sent)) 37 | 38 | sent = d['summary'] 39 | if sent not in docs: 40 | docs[sent] = [] 41 | docs[sent].append((d['id'], idx_sent)) 42 | 43 | 44 | sents = [] 45 | ids_docs = [] 46 | for sent in docs.keys(): 47 | sents.append(sent) 48 | id_line = set() 49 | for id_, idx in docs[sent]: 50 | id_line.add(str(id_)+'-'+str(idx)) 51 | id_line = ' '.join(id_line) 52 | ids_docs.append(id_line) 53 | 54 | assert len(sents) == len(ids_docs) 55 | 56 | new_file = os.path.splitext(file)[0] + ".txt" 57 | save_data(sents, new_file) 58 | 59 | new_file = os.path.splitext(file)[0] + "-" + "idx_sents.txt" 60 | save_data(ids_docs, 
# data/preprocess/02_generate_amrs.py
# Parse the unique sentences of a dataset into AMR graphs with amrlib and
# append the PENMAN graphs to an output file.
import amrlib
import json
import sys
import os
import numpy as np
from tqdm import tqdm


def load_source_docs(file_path, to_dict=False):
    """Load a JSON-lines file; optionally index the examples by "id"."""
    with open(file_path, encoding="utf-8") as f:
        data = [json.loads(line) for line in f]
    if to_dict:
        data = {example["id"]: example for example in data}
    return data


def save_data(data, file, name_suffix):
    """Write examples as JSON lines to `<file stem>-<name_suffix>.json`."""
    output_file = os.path.splitext(file)[0] + "-" + name_suffix + ".json"

    with open(output_file, "w", encoding="utf-8") as fd:
        for example in data:
            example = dict(example)
            fd.write(json.dumps(example, ensure_ascii=False) + "\n")


def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]


stog = amrlib.load_stog_model(model_dir='model_parse_xfm_bart_large-v0_1_0')

input_file = sys.argv[1]
amr_file = sys.argv[2]

input_data = load_source_docs(input_file)

# Gather every document sentence plus the summary of each example.
sentences = []
for example in input_data:
    sentences.extend([sents[0] for sents in json.loads(example['sentences'])])
    sentences.append(example['summary'])

# Parse each unique sentence only once, in batches of 20.
sentences = list(set(sentences))
print("Total of sentences:", len(sentences))
sentences = [list(sents) for sents in chunks(sentences, 20)]

# FIX: use a context manager so the output file is closed (and flushed) even
# when parsing fails, and catch Exception instead of a bare `except:` so
# KeyboardInterrupt/SystemExit still propagate; the failure reason is now
# logged instead of being discarded.
with open(amr_file, 'w') as out:
    for sents in tqdm(sentences):
        try:
            graphs = stog.parse_sents(sents, add_metadata=True)

            for g in graphs:
                out.write(g + "\n\n")
        except Exception as e:
            print("Error during parsing:", e)
# data/preprocess/3_check_files.py
# Sanity-check processed dataset files: report the distribution of graph and
# sentence counts per example, and how many claim graphs came out empty.
import json
import collections


def load_source_docs(file_path):
    """Load a JSON-lines file into a list of dicts."""
    with open(file_path, encoding="utf-8") as f:
        data = [json.loads(line) for line in f]
    return data


def count_size(file):
    """Print (and return) per-example statistics for a loaded dataset.

    Returns a tuple ``(sizes, sizes_sent, empty)`` where ``sizes`` and
    ``sizes_sent`` are ordered mappings from a graph/sentence count to the
    number of examples having that count, and ``empty`` is the number of
    examples whose ``graph_claim['amr_simple']`` is falsy.
    """
    sizes = []
    sizes_sent = []
    empty = 0
    for d in file:
        sizes.append(len(d['graphs']))
        sizes_sent.append(len(d['sentences']))

        if not d['graph_claim']['amr_simple']:
            empty += 1

    sizes = collections.OrderedDict(sorted(collections.Counter(sizes).items()))
    sizes_sent = collections.OrderedDict(sorted(collections.Counter(sizes_sent).items()))
    print('sizes graph', sizes)
    print('sizes sent', sizes_sent)
    print('empty', empty)
    # FIX: return the stats so the function is usable programmatically;
    # existing call sites ignored the (previously None) return value.
    return sizes, sizes_sent, empty


if __name__ == "__main__":
    # FIX: guard the file loading so importing this module no longer tries
    # to read the author's hard-coded scratch paths (kept from the original).
    file_train = '/home/ubuntu/fact_project/code/test_models/qags/data/created_dataset/train-sents-amr.json'
    file_dev = '/home/ubuntu/fact_project/code/test_models/qags/data/created_dataset/dev-sents-amr.json'
    file_test = '/home/ubuntu/fact_project/code/test_models/qags/data/created_dataset/test-sents-amr.json'

    count_size(load_source_docs(file_train))
    count_size(load_source_docs(file_dev))
    count_size(load_source_docs(file_test))
#!/bin/bash
# data/preprocess/create_envs_preprocess.sh
# Create the two conda environments used by the preprocessing pipeline:
#   - preprocess-fatcgraph: AMR tooling (amrlib model, amr-utils, fast_align)
#   - spring: the SPRING AMR parser and its pretrained checkpoint
# NOTE(review): the env name "preprocess-fatcgraph" is misspelled, but the
# other pipeline scripts activate it by this exact name, so it is kept.

source ~/anaconda3/etc/profile.d/conda.sh

conda create -n preprocess-fatcgraph python=3.8
conda activate preprocess-fatcgraph
pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install -r requirements-preprocess.txt
# FIX: the original clone URL contained a doubled slash (github.com//ablodge).
git clone https://github.com/ablodge/amr-utils
pip install penman
pip install ./amr-utils
wget https://github.com/bjascob/amrlib-models/releases/download/parse_xfm_bart_large-v0_1_0/model_parse_xfm_bart_large-v0_1_0.tar.gz
tar zxvf model_parse_xfm_bart_large-v0_1_0.tar.gz

# Build fast_align (used later for AMR-to-text alignment).
git clone https://github.com/clab/fast_align.git
cd fast_align
mkdir build
cd build
cmake ..
make
cd ../../

conda deactivate

conda create -n spring python=3.8
conda activate spring
conda install pytorch==1.5.0 torchvision cudatoolkit=10.2 -c pytorch
git clone https://github.com/SapienzaNLP/spring.git
cd spring
wget http://nlp.uniroma1.it/AMR/AMR3.parsing-1.0.tar.bz2
tar -xf AMR3.parsing-1.0.tar.bz2
pip install -r requirements.txt
cp ../scripts/predict_amrs_from_plaintext.py bin/
pip install -e .

conda deactivate
#!/bin/bash
# data/preprocess/process_dataset_for_model.sh
# End-to-end preprocessing for the document-level model: extract sentences,
# parse them to AMR with SPRING, then assemble the final JSON files.
# FIX: every pipeline step was copy-pasted three times (train/dev/test);
# looping over the splits keeps the three invocations in sync.

source ~/anaconda3/etc/profile.d/conda.sh

GPU_ID=$1
export CUDA_VISIBLE_DEVICES=${GPU_ID}

NUMBER_GRAPHS=5
FOLDER=../processed_dataset
SPLITS="train dev test"

conda deactivate
conda activate preprocess-fatcgraph

## CREATE SENTS
for SPLIT in ${SPLITS}; do
    python 01_generate_sents_for_graphs.py ${FOLDER}/${SPLIT}.json ${NUMBER_GRAPHS}
done

## EXTRACT SENTS
for SPLIT in ${SPLITS}; do
    python 02.1_get_amr_data.py ${FOLDER}/${SPLIT}-sents.json
done

conda deactivate
conda activate spring

### GENERATE AMRS
FOLDER_SPRING=spring
PATH_MODEL=${FOLDER_SPRING}/AMR3.parsing.pt

for SPLIT in ${SPLITS}; do
    FILE_VAL=${FOLDER}/${SPLIT}-sents.txt
    python -u ${FOLDER_SPRING}/bin/predict_amrs_from_plaintext.py --checkpoint ${PATH_MODEL} --texts ${FILE_VAL} --penman-linearization \
        --use-pointer-tokens > ${FILE_VAL}.amr
done

conda deactivate

## GENERATE DATA FILES
conda activate preprocess-fatcgraph

for SPLIT in ${SPLITS}; do
    python 02.2_create_amr_json_nopar.py ${FOLDER}/${SPLIT}-sents.json ${FOLDER}/${SPLIT}-sents.txt.amr ${NUMBER_GRAPHS}
done

conda deactivate
--use-pointer-tokens > ${FILE_VAL}.amr 57 | 58 | conda deactivate 59 | 60 | ## GENERATE DATA FILES 61 | conda activate preprocess-fatcgraph 62 | 63 | AMR_DATA=${FOLDER}/train-sents.txt.amr 64 | FILE_DATA=${FOLDER}/train-sents.json 65 | python 02.2_create_amr_json_nopar.py ${FILE_DATA} ${AMR_DATA} ${NUMBER_GRAPHS} 66 | 67 | AMR_DATA=${FOLDER}/dev-sents.txt.amr 68 | FILE_DATA=${FOLDER}/dev-sents.json 69 | python 02.2_create_amr_json_nopar.py ${FILE_DATA} ${AMR_DATA} ${NUMBER_GRAPHS} 70 | 71 | AMR_DATA=${FOLDER}/test-sents.txt.amr 72 | FILE_DATA=${FOLDER}/test-sents.json 73 | python 02.2_create_amr_json_nopar.py ${FILE_DATA} ${AMR_DATA} ${NUMBER_GRAPHS} 74 | 75 | conda deactivate 76 | 77 | 78 | -------------------------------------------------------------------------------- /data/preprocess/requirements-preprocess.txt: -------------------------------------------------------------------------------- 1 | datasets==1.8.0 2 | dill==0.3.4 3 | en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0-py3-none-any.whl 4 | numpy==1.22.0 5 | packaging==20.9 6 | pandas==1.2.4 7 | Penman==1.2.0 8 | plac==1.1.3 9 | pluggy==0.13.1 10 | preshed==3.0.5 11 | protobuf==3.17.3 12 | py==1.10.0 13 | pyarrow==3.0.0 14 | pyparsing==2.4.7 15 | pytest==6.2.4 16 | python-dateutil==2.8.1 17 | pytz==2021.1 18 | requests==2.25.1 19 | s3transfer==0.4.2 20 | scipy==1.5.2 21 | six==1.16.0 22 | smatch==1.0.4 23 | spacy==3.0.6 24 | stanza==1.2 25 | torch==1.9.0 26 | sentence-transformers==1.2.0 27 | unidecode==1.2.0 28 | amrlib==0.7.1 29 | -------------------------------------------------------------------------------- /data/preprocess/update_envs_preprocess.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | git clone https://github.com/clab/fast_align.git 4 | cd fast_align 5 | mkdir build 6 | cd build 7 | cmake .. 
# data/utils.py
# Small shared I/O helpers for the dataset-generation scripts.
import json
import csv


def read_csv(file):
    """Read a comma-separated file and return its rows as a list of lists.

    FIX: the file is now opened with an explicit UTF-8 encoding and with
    ``newline=''`` as required by the csv module, so quoted fields that
    contain embedded newlines parse correctly on every platform.
    """
    with open(file, encoding="utf-8", newline="") as fp:
        reader = csv.reader(fp, delimiter=",", quotechar='"')
        # next(reader, None)  # skip the headers
        data_read = [row for row in reader]

    return data_read


def save_data(data, output_file):
    """Write an iterable of dict-like examples as JSON lines (UTF-8)."""
    with open(output_file, "w", encoding="utf-8") as fd:
        for example in data:
            example = dict(example)
            fd.write(json.dumps(example, ensure_ascii=False) + "\n")


def load_source_docs(file_path, to_dict=False):
    """Load a JSON-lines file; optionally index the examples by "id"."""
    with open(file_path, encoding="utf-8") as f:
        data = [json.loads(line) for line in f]

    if to_dict:
        data = {example["id"]: example for example in data}
    return data


def most_common(lst):
    """Return the most frequent element of lst (ties broken arbitrarily).

    Raises ValueError on an empty list.
    """
    return max(set(lst), key=lst.count)
#!/bin/bash
# src/download_pretrained_adapters.sh
# Download the pretrained graph/text adapter weights into ../checkpoints.

GRAPH_ADAPTERS=https://public.ukp.informatik.tu-darmstadt.de/ribeiro/factgraph/graph_adapters.bin
TEXT_ADAPTERS=https://public.ukp.informatik.tu-darmstadt.de/ribeiro/factgraph/text_adapters.bin

mkdir -p ../checkpoints
wget ${GRAPH_ADAPTERS} -P ../checkpoints
wget ${TEXT_ADAPTERS} -P ../checkpoints

# --------------------------------------------------------------------------
#!/bin/bash
# src/download_trained_models.sh
# Download and unpack the trained FactGraph checkpoints, plus the amrlib
# parsing model used by the preprocessing pipeline.

# FIX: these variables were copy-pasted from the adapter-download script and
# misleadingly named GRAPH_ADAPTERS/TEXT_ADAPTERS; they actually point at
# the full trained model archives.
FACTGRAPH_MODEL=https://public.ukp.informatik.tu-darmstadt.de/ribeiro/factgraph/factgraph.tar.gz
FACTGRAPH_EDGE_MODEL=https://public.ukp.informatik.tu-darmstadt.de/ribeiro/factgraph/factgraph-edge.tar.gz

mkdir -p ../checkpoints
wget ${FACTGRAPH_MODEL} -P ../checkpoints
wget ${FACTGRAPH_EDGE_MODEL} -P ../checkpoints

tar zxvf ../checkpoints/factgraph.tar.gz -C ../checkpoints/
tar zxvf ../checkpoints/factgraph-edge.tar.gz -C ../checkpoints/

wget https://github.com/bjascob/amrlib-models/releases/download/parse_xfm_bart_large-v0_1_0/model_parse_xfm_bart_large-v0_1_0.tar.gz -P ../data/preprocess/
tar zxvf ../data/preprocess/model_parse_xfm_bart_large-v0_1_0.tar.gz -C ../data/preprocess/
| 11 | MODEL_TYPE=$1 12 | JSON_FILE_VAL=$2 13 | JSON_FILE_VAL=${JSON_FILE_VAL} 14 | 15 | if [ "${MODEL_TYPE}" = "factgraph" ]; then 16 | PATH_MODEL='../checkpoints/factgraph' 17 | else 18 | PATH_MODEL='../checkpoints/factgraph-edge' 19 | fi 20 | 21 | 22 | source ~/anaconda3/etc/profile.d/conda.sh 23 | 24 | conda deactivate 25 | 26 | conda activate preprocess-fatcgraph 27 | python -u ${PREPROCESS_FOLDER}/preprocess_evaluate.py ${JSON_FILE_VAL} ${NUMBER_GRAPHS} 28 | conda deactivate 29 | 30 | conda activate factgraph 31 | python -u evaluate.py --model_type ${MODEL_TYPE} --model_dir ${PATH_MODEL} --model_name_or_path ${MODEL_NAME} \ 32 | --test_data_file ${JSON_FILE_VAL}.processed 33 | conda deactivate -------------------------------------------------------------------------------- /src/predict.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export OMP_NUM_THREADS=3 4 | 5 | GPU_ID=$2 6 | export CUDA_VISIBLE_DEVICES=${GPU_ID} 7 | export MODEL_NAME=google/electra-base-discriminator 8 | 9 | PATH_MODEL=$1 10 | FILE_VAL='../data/processed_dataset/test-sents-amr.json' 11 | 12 | CUDA_LAUNCH_BLOCKING=1 python -u main.py --test --save_dir ${PATH_MODEL} --model_name_or_path ${MODEL_NAME} \ 13 | --test_data_file ${FILE_VAL} -------------------------------------------------------------------------------- /src/predict_edgelevel.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | GPU_ID=$2 4 | export CUDA_VISIBLE_DEVICES=${GPU_ID} 5 | export MODEL_NAME=google/electra-base-discriminator 6 | export OMP_NUM_THREADS=3 7 | 8 | PATH_MODEL=$1 9 | FILE_VAL='../data/processed_dataset_edge_level/test-sents-amr.json' 10 | 11 | 12 | python -u main_edgelevel.py --test --model_name_or_path ${MODEL_NAME} \ 13 | --test_data_file ${FILE_VAL} \ 14 | --batch_size 8 \ 15 | --save_dir ${PATH_MODEL} 16 | 17 | 
-------------------------------------------------------------------------------- /src/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | FILE_TRAIN='../data/processed_dataset/train-sents-amr.json' 4 | FILE_VAL='../data/processed_dataset/dev-sents-amr.json' 5 | FILE_TEST='../data/processed_dataset/test-sents-amr.json' 6 | 7 | GPU_ID=$1 8 | export CUDA_VISIBLE_DEVICES=${GPU_ID} 9 | export MODEL_NAME=google/electra-base-discriminator 10 | export OMP_NUM_THREADS=3 11 | 12 | NAME_EXECUTION=$MODEL_NAME-$RANDOM 13 | PATH_MODEL=../checkpoints/${NAME_EXECUTION} 14 | PRE_MODEL_GRAPH_ADAPT=../checkpoints/graph_adapters.bin 15 | PRE_MODEL_TEXT_ADAPT=../checkpoints/text_adapters.bin 16 | rm -rf ${PATH_MODEL} 17 | mkdir -p ${PATH_MODEL} 18 | python -u main.py --model_name_or_path ${MODEL_NAME} \ 19 | --train_data_file ${FILE_TRAIN} \ 20 | --val_data_file ${FILE_VAL} \ 21 | --test_data_file ${FILE_TEST} \ 22 | --save_every_k_step 300 \ 23 | --batch_size 8 \ 24 | --adapter_size 32 \ 25 | --num_epoch 4 \ 26 | --pretrained_model_adapters ${PRE_MODEL_TEXT_ADAPT} \ 27 | --pretrained_model_graph_adapters ${PRE_MODEL_GRAPH_ADAPT} \ 28 | --save_dir ${PATH_MODEL} 29 | 30 | -------------------------------------------------------------------------------- /src/train_edgelevel.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | FILE_TRAIN='../data/processed_dataset_edge_level/train-sents-amr.json' 4 | FILE_VAL='../data/processed_dataset_edge_level/test-sents-amr.json' 5 | 6 | GPU_ID=$1 7 | export CUDA_VISIBLE_DEVICES=${GPU_ID} 8 | export MODEL_NAME=google/electra-base-discriminator 9 | export OMP_NUM_THREADS=3 10 | 11 | NAME_EXECUTION=$MODEL_NAME-$RANDOM 12 | PATH_MODEL=../checkpoints/${NAME_EXECUTION} 13 | 14 | rm -rf ${PATH_MODEL} 15 | mkdir -p ${PATH_MODEL} 16 | python -u main_edgelevel.py --model_name_or_path ${MODEL_NAME} \ 17 | --train_data_file ${FILE_TRAIN} \ 18 | 
--val_data_file ${FILE_VAL} \ 19 | --save_every_k_step 100 \ 20 | --batch_size 8 \ 21 | --adapter_size 32 \ 22 | --num_epoch 2 \ 23 | --save_dir ${PATH_MODEL} 24 | 25 | -------------------------------------------------------------------------------- /src/transformers/activations_tf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | 17 | import tensorflow as tf 18 | from packaging import version 19 | 20 | 21 | def _gelu(x): 22 | """ 23 | Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when 24 | initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 25 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) Also see 26 | https://arxiv.org/abs/1606.08415 27 | """ 28 | x = tf.convert_to_tensor(x) 29 | cdf = 0.5 * (1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype))) 30 | 31 | return x * cdf 32 | 33 | 34 | def _gelu_new(x): 35 | """ 36 | Gaussian Error Linear Unit. This is a smoother version of the GELU. Original paper: https://arxiv.org/abs/1606.0841 37 | 38 | Args: 39 | x: float Tensor to perform activation 40 | 41 | Returns: 42 | `x` with the GELU activation applied. 
43 | """ 44 | x = tf.convert_to_tensor(x) 45 | pi = tf.cast(math.pi, x.dtype) 46 | coeff = tf.cast(0.044715, x.dtype) 47 | cdf = 0.5 * (1.0 + tf.tanh(tf.sqrt(2.0 / pi) * (x + coeff * tf.pow(x, 3)))) 48 | 49 | return x * cdf 50 | 51 | 52 | def mish(x): 53 | x = tf.convert_to_tensor(x) 54 | 55 | return x * tf.tanh(tf.math.softplus(x)) 56 | 57 | 58 | def gelu_fast(x): 59 | x = tf.convert_to_tensor(x) 60 | coeff1 = tf.cast(0.044715, x.dtype) 61 | coeff2 = tf.cast(0.7978845608, x.dtype) 62 | 63 | return 0.5 * x * (1.0 + tf.tanh(x * coeff2 * (1.0 + coeff1 * x * x))) 64 | 65 | 66 | if version.parse(tf.version.VERSION) >= version.parse("2.4"): 67 | 68 | def approximate_gelu_wrap(x): 69 | return tf.keras.activations.gelu(x, approximate=True) 70 | 71 | gelu = tf.keras.activations.gelu 72 | gelu_new = approximate_gelu_wrap 73 | else: 74 | gelu = _gelu 75 | gelu_new = _gelu_new 76 | 77 | 78 | ACT2FN = { 79 | "gelu": gelu, 80 | "relu": tf.keras.activations.relu, 81 | "swish": tf.keras.activations.swish, 82 | "silu": tf.keras.activations.swish, 83 | "gelu_new": gelu_new, 84 | "mish": mish, 85 | "tanh": tf.keras.activations.tanh, 86 | "gelu_fast": gelu_fast, 87 | } 88 | 89 | 90 | def get_tf_activation(activation_string): 91 | if activation_string in ACT2FN: 92 | return ACT2FN[activation_string] 93 | else: 94 | raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}") 95 | -------------------------------------------------------------------------------- /src/transformers/benchmark/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/fact-graph/165285a483ebf2fc4ffe2202c65349b99dea9626/src/transformers/benchmark/__init__.py -------------------------------------------------------------------------------- /src/transformers/commands/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The 
HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from abc import ABC, abstractmethod 16 | from argparse import ArgumentParser 17 | 18 | 19 | class BaseTransformersCLICommand(ABC): 20 | @staticmethod 21 | @abstractmethod 22 | def register_subcommand(parser: ArgumentParser): 23 | raise NotImplementedError() 24 | 25 | @abstractmethod 26 | def run(self): 27 | raise NotImplementedError() 28 | -------------------------------------------------------------------------------- /src/transformers/commands/download.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from argparse import ArgumentParser 16 | 17 | from . 
import BaseTransformersCLICommand 18 | 19 | 20 | def download_command_factory(args): 21 | return DownloadCommand(args.model, args.cache_dir, args.force) 22 | 23 | 24 | class DownloadCommand(BaseTransformersCLICommand): 25 | @staticmethod 26 | def register_subcommand(parser: ArgumentParser): 27 | download_parser = parser.add_parser("download") 28 | download_parser.add_argument( 29 | "--cache-dir", type=str, default=None, help="Path to location to store the models" 30 | ) 31 | download_parser.add_argument( 32 | "--force", action="store_true", help="Force the model to be download even if already in cache-dir" 33 | ) 34 | download_parser.add_argument("model", type=str, help="Name of the model to download") 35 | download_parser.set_defaults(func=download_command_factory) 36 | 37 | def __init__(self, model: str, cache: str, force: bool): 38 | self._model = model 39 | self._cache = cache 40 | self._force = force 41 | 42 | def run(self): 43 | from ..models.auto import AutoModel, AutoTokenizer 44 | 45 | AutoModel.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force) 46 | AutoTokenizer.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force) 47 | -------------------------------------------------------------------------------- /src/transformers/commands/env.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import platform 16 | from argparse import ArgumentParser 17 | 18 | from .. import __version__ as version 19 | from ..file_utils import is_tf_available, is_torch_available 20 | from . import BaseTransformersCLICommand 21 | 22 | 23 | def info_command_factory(_): 24 | return EnvironmentCommand() 25 | 26 | 27 | class EnvironmentCommand(BaseTransformersCLICommand): 28 | @staticmethod 29 | def register_subcommand(parser: ArgumentParser): 30 | download_parser = parser.add_parser("env") 31 | download_parser.set_defaults(func=info_command_factory) 32 | 33 | def run(self): 34 | pt_version = "not installed" 35 | pt_cuda_available = "NA" 36 | if is_torch_available(): 37 | import torch 38 | 39 | pt_version = torch.__version__ 40 | pt_cuda_available = torch.cuda.is_available() 41 | 42 | tf_version = "not installed" 43 | tf_cuda_available = "NA" 44 | if is_tf_available(): 45 | import tensorflow as tf 46 | 47 | tf_version = tf.__version__ 48 | try: 49 | # deprecated in v2.1 50 | tf_cuda_available = tf.test.is_gpu_available() 51 | except AttributeError: 52 | # returns list of devices, convert to bool 53 | tf_cuda_available = bool(tf.config.list_physical_devices("GPU")) 54 | 55 | info = { 56 | "`transformers` version": version, 57 | "Platform": platform.platform(), 58 | "Python version": platform.python_version(), 59 | "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})", 60 | "Tensorflow version (GPU?)": f"{tf_version} ({tf_cuda_available})", 61 | "Using GPU in script?": "", 62 | "Using distributed or parallel set-up in script?": "", 63 | } 64 | 65 | print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n") 66 | print(self.format_dict(info)) 67 | 68 | return info 69 | 70 | @staticmethod 71 | def format_dict(d): 72 | return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n" 73 | 
-------------------------------------------------------------------------------- /src/transformers/commands/transformers_cli.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright 2020 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from argparse import ArgumentParser 17 | 18 | from .add_new_model import AddNewModelCommand 19 | from .convert import ConvertCommand 20 | from .download import DownloadCommand 21 | from .env import EnvironmentCommand 22 | from .lfs import LfsCommands 23 | from .run import RunCommand 24 | from .serving import ServeCommand 25 | from .user import UserCommands 26 | 27 | 28 | def main(): 29 | parser = ArgumentParser("Transformers CLI tool", usage="transformers-cli []") 30 | commands_parser = parser.add_subparsers(help="transformers-cli command helpers") 31 | 32 | # Register commands 33 | ConvertCommand.register_subcommand(commands_parser) 34 | DownloadCommand.register_subcommand(commands_parser) 35 | EnvironmentCommand.register_subcommand(commands_parser) 36 | RunCommand.register_subcommand(commands_parser) 37 | ServeCommand.register_subcommand(commands_parser) 38 | UserCommands.register_subcommand(commands_parser) 39 | AddNewModelCommand.register_subcommand(commands_parser) 40 | LfsCommands.register_subcommand(commands_parser) 41 | 42 | # Let's go 43 | args = parser.parse_args() 44 
| 45 | if not hasattr(args, "func"): 46 | parser.print_help() 47 | exit(1) 48 | 49 | # Run 50 | service = args.func(args) 51 | service.run() 52 | 53 | 54 | if __name__ == "__main__": 55 | main() 56 | -------------------------------------------------------------------------------- /src/transformers/data/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | 19 | from .metrics import glue_compute_metrics, xnli_compute_metrics 20 | from .processors import ( 21 | DataProcessor, 22 | InputExample, 23 | InputFeatures, 24 | SingleSentenceClassificationProcessor, 25 | SquadExample, 26 | SquadFeatures, 27 | SquadV1Processor, 28 | SquadV2Processor, 29 | glue_convert_examples_to_features, 30 | glue_output_modes, 31 | glue_processors, 32 | glue_tasks_num_labels, 33 | squad_convert_examples_to_features, 34 | xnli_output_modes, 35 | xnli_processors, 36 | xnli_tasks_num_labels, 37 | ) 38 | -------------------------------------------------------------------------------- /src/transformers/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | 19 | from .glue import GlueDataset, GlueDataTrainingArguments 20 | from .language_modeling import ( 21 | LineByLineTextDataset, 22 | LineByLineWithRefDataset, 23 | LineByLineWithSOPTextDataset, 24 | TextDataset, 25 | TextDatasetForNextSentencePrediction, 26 | ) 27 | from .squad import SquadDataset, SquadDataTrainingArguments 28 | -------------------------------------------------------------------------------- /src/transformers/data/processors/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | 19 | from .glue import glue_convert_examples_to_features, glue_output_modes, glue_processors, glue_tasks_num_labels 20 | from .squad import SquadExample, SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features 21 | from .utils import DataProcessor, InputExample, InputFeatures, SingleSentenceClassificationProcessor 22 | from .xnli import xnli_output_modes, xnli_processors, xnli_tasks_num_labels 23 | -------------------------------------------------------------------------------- /src/transformers/dependency_versions_check.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import sys 15 | 16 | from .dependency_versions_table import deps 17 | from .utils.versions import require_version, require_version_core 18 | 19 | 20 | # define which module versions we always want to check at run time 21 | # (usually the ones defined in `install_requires` in setup.py) 22 | # 23 | # order specific notes: 24 | # - tqdm must be checked before tokenizers 25 | 26 | pkgs_to_check_at_runtime = "python tqdm regex sacremoses requests packaging filelock numpy tokenizers".split() 27 | if sys.version_info < (3, 7): 28 | pkgs_to_check_at_runtime.append("dataclasses") 29 | if sys.version_info < (3, 8): 30 | pkgs_to_check_at_runtime.append("importlib_metadata") 31 | 32 | for pkg in pkgs_to_check_at_runtime: 33 | if pkg in deps: 34 | if pkg == "tokenizers": 35 | # must be loaded here, or else tqdm check may fail 36 | from .file_utils import is_tokenizers_available 37 | 38 | if not is_tokenizers_available(): 39 | continue # not required, check version only if installed 40 | 41 | require_version_core(deps[pkg]) 42 | else: 43 | raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py") 44 | 45 | 46 | def dep_version_check(pkg, hint=None): 47 | require_version(deps[pkg], hint) 48 | -------------------------------------------------------------------------------- /src/transformers/dependency_versions_table.py: -------------------------------------------------------------------------------- 1 | # THIS FILE HAS BEEN AUTOGENERATED. To update: 2 | # 1. modify the `_deps` dict in setup.py 3 | # 2. 
run `make deps_table_update`
"timeout-decorator": "timeout-decorator", 60 | "tokenizers": "tokenizers>=0.10.1,<0.11", 61 | "torch": "torch>=1.0", 62 | "torchaudio": "torchaudio", 63 | "tqdm": "tqdm>=4.27", 64 | "unidic": "unidic>=1.0.2", 65 | "unidic_lite": "unidic_lite>=1.0.7", 66 | "uvicorn": "uvicorn", 67 | } 68 | -------------------------------------------------------------------------------- /src/transformers/models/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | from . 
import ( 20 | albert, 21 | auto, 22 | bart, 23 | barthez, 24 | bert, 25 | bert_generation, 26 | bert_japanese, 27 | bertweet, 28 | big_bird, 29 | bigbird_pegasus, 30 | blenderbot, 31 | blenderbot_small, 32 | camembert, 33 | clip, 34 | convbert, 35 | cpm, 36 | ctrl, 37 | deberta, 38 | deit, 39 | dialogpt, 40 | distilbert, 41 | dpr, 42 | electra, 43 | electra_adapter, 44 | encoder_decoder, 45 | flaubert, 46 | fsmt, 47 | funnel, 48 | gpt2, 49 | gpt_neo, 50 | herbert, 51 | layoutlm, 52 | led, 53 | longformer, 54 | luke, 55 | lxmert, 56 | m2m_100, 57 | marian, 58 | mbart, 59 | megatron_bert, 60 | mmbt, 61 | mobilebert, 62 | mpnet, 63 | mt5, 64 | openai, 65 | pegasus, 66 | phobert, 67 | prophetnet, 68 | rag, 69 | reformer, 70 | retribert, 71 | roberta, 72 | speech_to_text, 73 | squeezebert, 74 | t5, 75 | tapas, 76 | transfo_xl, 77 | vit, 78 | wav2vec2, 79 | xlm, 80 | xlm_roberta, 81 | xlnet, 82 | ) 83 | -------------------------------------------------------------------------------- /src/transformers/models/albert/convert_albert_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert ALBERT checkpoint.""" 16 | 17 | 18 | import argparse 19 | 20 | import torch 21 | 22 | from transformers import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert 23 | from transformers.utils import logging 24 | 25 | 26 | logging.set_verbosity_info() 27 | 28 | 29 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pytorch_dump_path): 30 | # Initialise PyTorch model 31 | config = AlbertConfig.from_json_file(albert_config_file) 32 | print(f"Building PyTorch model from configuration: {config}") 33 | model = AlbertForPreTraining(config) 34 | 35 | # Load weights from tf checkpoint 36 | load_tf_weights_in_albert(model, config, tf_checkpoint_path) 37 | 38 | # Save pytorch-model 39 | print(f"Save PyTorch model to {pytorch_dump_path}") 40 | torch.save(model.state_dict(), pytorch_dump_path) 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser() 45 | # Required parameters 46 | parser.add_argument( 47 | "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." 48 | ) 49 | parser.add_argument( 50 | "--albert_config_file", 51 | default=None, 52 | type=str, 53 | required=True, 54 | help="The config json file corresponding to the pre-trained ALBERT model. \n" 55 | "This specifies the model architecture.", 56 | ) 57 | parser.add_argument( 58 | "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." 59 | ) 60 | args = parser.parse_args() 61 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.albert_config_file, args.pytorch_dump_path) 62 | -------------------------------------------------------------------------------- /src/transformers/models/barthez/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. 
So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | from typing import TYPE_CHECKING 20 | 21 | from ...file_utils import _BaseLazyModule, is_sentencepiece_available, is_tokenizers_available 22 | 23 | 24 | _import_structure = {} 25 | 26 | if is_sentencepiece_available(): 27 | _import_structure["tokenization_barthez"] = ["BarthezTokenizer"] 28 | 29 | if is_tokenizers_available(): 30 | _import_structure["tokenization_barthez_fast"] = ["BarthezTokenizerFast"] 31 | 32 | 33 | if TYPE_CHECKING: 34 | 35 | if is_sentencepiece_available(): 36 | from .tokenization_barthez import BarthezTokenizer 37 | 38 | if is_tokenizers_available(): 39 | from .tokenization_barthez_fast import BarthezTokenizerFast 40 | 41 | else: 42 | import importlib 43 | import os 44 | import sys 45 | 46 | class _LazyModule(_BaseLazyModule): 47 | """ 48 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 49 | """ 50 | 51 | __file__ = globals()["__file__"] 52 | __path__ = [os.path.dirname(__file__)] 53 | 54 | def _get_module(self, module_name: str): 55 | return importlib.import_module("." 
+ module_name, self.__name__) 56 | 57 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 58 | -------------------------------------------------------------------------------- /src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert BERT checkpoint.""" 16 | 17 | 18 | import argparse 19 | 20 | import torch 21 | 22 | from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert 23 | from transformers.utils import logging 24 | 25 | 26 | logging.set_verbosity_info() 27 | 28 | 29 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 30 | # Initialise PyTorch model 31 | config = BertConfig.from_json_file(bert_config_file) 32 | print(f"Building PyTorch model from configuration: {config}") 33 | model = BertForPreTraining(config) 34 | 35 | # Load weights from tf checkpoint 36 | load_tf_weights_in_bert(model, config, tf_checkpoint_path) 37 | 38 | # Save pytorch-model 39 | print(f"Save PyTorch model to {pytorch_dump_path}") 40 | torch.save(model.state_dict(), pytorch_dump_path) 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser() 45 | # Required parameters 46 | parser.add_argument( 47 | "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." 48 | ) 49 | parser.add_argument( 50 | "--bert_config_file", 51 | default=None, 52 | type=str, 53 | required=True, 54 | help="The config json file corresponding to the pre-trained BERT model. \n" 55 | "This specifies the model architecture.", 56 | ) 57 | parser.add_argument( 58 | "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." 59 | ) 60 | args = parser.parse_args() 61 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path) 62 | -------------------------------------------------------------------------------- /src/transformers/models/bert_adapter/convert_bert_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert BERT checkpoint.""" 16 | 17 | 18 | import argparse 19 | 20 | import torch 21 | 22 | from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert 23 | from transformers.utils import logging 24 | 25 | 26 | logging.set_verbosity_info() 27 | 28 | 29 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 30 | # Initialise PyTorch model 31 | config = BertConfig.from_json_file(bert_config_file) 32 | print(f"Building PyTorch model from configuration: {config}") 33 | model = BertForPreTraining(config) 34 | 35 | # Load weights from tf checkpoint 36 | load_tf_weights_in_bert(model, config, tf_checkpoint_path) 37 | 38 | # Save pytorch-model 39 | print(f"Save PyTorch model to {pytorch_dump_path}") 40 | torch.save(model.state_dict(), pytorch_dump_path) 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser() 45 | # Required parameters 46 | parser.add_argument( 47 | "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." 48 | ) 49 | parser.add_argument( 50 | "--bert_config_file", 51 | default=None, 52 | type=str, 53 | required=True, 54 | help="The config json file corresponding to the pre-trained BERT model. 
\n" 55 | "This specifies the model architecture.", 56 | ) 57 | parser.add_argument( 58 | "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." 59 | ) 60 | args = parser.parse_args() 61 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path) 62 | -------------------------------------------------------------------------------- /src/transformers/models/bert_generation/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | 19 | from typing import TYPE_CHECKING 20 | 21 | from ...file_utils import _BaseLazyModule, is_sentencepiece_available, is_torch_available 22 | 23 | 24 | _import_structure = { 25 | "configuration_bert_generation": ["BertGenerationConfig"], 26 | } 27 | 28 | if is_sentencepiece_available(): 29 | _import_structure["tokenization_bert_generation"] = ["BertGenerationTokenizer"] 30 | 31 | if is_torch_available(): 32 | _import_structure["modeling_bert_generation"] = [ 33 | "BertGenerationDecoder", 34 | "BertGenerationEncoder", 35 | "load_tf_weights_in_bert_generation", 36 | ] 37 | 38 | 39 | if TYPE_CHECKING: 40 | from .configuration_bert_generation import BertGenerationConfig 41 | 42 | if is_sentencepiece_available(): 43 | from .tokenization_bert_generation import BertGenerationTokenizer 44 | 45 | if is_torch_available(): 46 | from .modeling_bert_generation import ( 47 | BertGenerationDecoder, 48 | BertGenerationEncoder, 49 | load_tf_weights_in_bert_generation, 50 | ) 51 | 52 | else: 53 | import importlib 54 | import os 55 | import sys 56 | 57 | class _LazyModule(_BaseLazyModule): 58 | """ 59 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 60 | """ 61 | 62 | __file__ = globals()["__file__"] 63 | __path__ = [os.path.dirname(__file__)] 64 | 65 | def _get_module(self, module_name: str): 66 | return importlib.import_module("." + module_name, self.__name__) 67 | 68 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 69 | -------------------------------------------------------------------------------- /src/transformers/models/bert_japanese/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 
6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | from typing import TYPE_CHECKING 20 | 21 | from ...file_utils import _BaseLazyModule 22 | 23 | 24 | _import_structure = { 25 | "tokenization_bert_japanese": ["BertJapaneseTokenizer", "CharacterTokenizer", "MecabTokenizer"], 26 | } 27 | 28 | 29 | if TYPE_CHECKING: 30 | from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer 31 | 32 | else: 33 | import importlib 34 | import os 35 | import sys 36 | 37 | class _LazyModule(_BaseLazyModule): 38 | """ 39 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 40 | """ 41 | 42 | __file__ = globals()["__file__"] 43 | __path__ = [os.path.dirname(__file__)] 44 | 45 | def _get_module(self, module_name: str): 46 | return importlib.import_module("." + module_name, self.__name__) 47 | 48 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 49 | -------------------------------------------------------------------------------- /src/transformers/models/bertweet/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 
6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | from typing import TYPE_CHECKING 20 | 21 | from ...file_utils import _BaseLazyModule 22 | 23 | 24 | _import_structure = { 25 | "tokenization_bertweet": ["BertweetTokenizer"], 26 | } 27 | 28 | 29 | if TYPE_CHECKING: 30 | from .tokenization_bertweet import BertweetTokenizer 31 | 32 | else: 33 | import importlib 34 | import os 35 | import sys 36 | 37 | class _LazyModule(_BaseLazyModule): 38 | """ 39 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 40 | """ 41 | 42 | __file__ = globals()["__file__"] 43 | __path__ = [os.path.dirname(__file__)] 44 | 45 | def _get_module(self, module_name: str): 46 | return importlib.import_module("." + module_name, self.__name__) 47 | 48 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 49 | -------------------------------------------------------------------------------- /src/transformers/models/big_bird/convert_bigbird_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert BigBird checkpoint.""" 16 | 17 | 18 | import argparse 19 | 20 | from transformers import BigBirdConfig, BigBirdForPreTraining, BigBirdForQuestionAnswering, load_tf_weights_in_big_bird 21 | from transformers.utils import logging 22 | 23 | 24 | logging.set_verbosity_info() 25 | 26 | 27 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, big_bird_config_file, pytorch_dump_path, is_trivia_qa): 28 | # Initialise PyTorch model 29 | config = BigBirdConfig.from_json_file(big_bird_config_file) 30 | print(f"Building PyTorch model from configuration: {config}") 31 | 32 | if is_trivia_qa: 33 | model = BigBirdForQuestionAnswering(config) 34 | else: 35 | model = BigBirdForPreTraining(config) 36 | 37 | # Load weights from tf checkpoint 38 | load_tf_weights_in_big_bird(model, tf_checkpoint_path, is_trivia_qa=is_trivia_qa) 39 | 40 | # Save pytorch-model 41 | print(f"Save PyTorch model to {pytorch_dump_path}") 42 | model.save_pretrained(pytorch_dump_path) 43 | 44 | 45 | if __name__ == "__main__": 46 | 47 | parser = argparse.ArgumentParser() 48 | # Required parameters 49 | parser.add_argument( 50 | "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." 51 | ) 52 | parser.add_argument( 53 | "--big_bird_config_file", 54 | default=None, 55 | type=str, 56 | required=True, 57 | help="The config json file corresponding to the pre-trained BERT model. 
\n" 58 | "This specifies the model architecture.", 59 | ) 60 | parser.add_argument( 61 | "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." 62 | ) 63 | parser.add_argument( 64 | "--is_trivia_qa", action="store_true", help="Whether to convert a model with a trivia_qa head." 65 | ) 66 | args = parser.parse_args() 67 | convert_tf_checkpoint_to_pytorch( 68 | args.tf_checkpoint_path, args.big_bird_config_file, args.pytorch_dump_path, args.is_trivia_qa 69 | ) 70 | -------------------------------------------------------------------------------- /src/transformers/models/bigbird_pegasus/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2021 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | from typing import TYPE_CHECKING 19 | 20 | from ...file_utils import _BaseLazyModule, is_torch_available 21 | 22 | 23 | _import_structure = { 24 | "configuration_bigbird_pegasus": ["BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "BigBirdPegasusConfig"], 25 | } 26 | 27 | if is_torch_available(): 28 | _import_structure["modeling_bigbird_pegasus"] = [ 29 | "BIGBIRD_PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST", 30 | "BigBirdPegasusForCausalLM", 31 | "BigBirdPegasusForConditionalGeneration", 32 | "BigBirdPegasusForQuestionAnswering", 33 | "BigBirdPegasusForSequenceClassification", 34 | "BigBirdPegasusModel", 35 | "BigBirdPegasusPreTrainedModel", 36 | ] 37 | 38 | 39 | if TYPE_CHECKING: 40 | from .configuration_bigbird_pegasus import BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdPegasusConfig 41 | 42 | if is_torch_available(): 43 | from .modeling_bigbird_pegasus import ( 44 | BIGBIRD_PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST, 45 | BigBirdPegasusForCausalLM, 46 | BigBirdPegasusForConditionalGeneration, 47 | BigBirdPegasusForQuestionAnswering, 48 | BigBirdPegasusForSequenceClassification, 49 | BigBirdPegasusModel, 50 | BigBirdPegasusPreTrainedModel, 51 | ) 52 | 53 | 54 | else: 55 | import importlib 56 | import os 57 | import sys 58 | 59 | class _LazyModule(_BaseLazyModule): 60 | """ 61 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 62 | """ 63 | 64 | __file__ = globals()["__file__"] 65 | __path__ = [os.path.dirname(__file__)] 66 | 67 | def _get_module(self, module_name: str): 68 | return importlib.import_module("." + module_name, self.__name__) 69 | 70 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 71 | -------------------------------------------------------------------------------- /src/transformers/models/blenderbot/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' 
imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | from typing import TYPE_CHECKING 20 | 21 | from ...file_utils import _BaseLazyModule, is_tf_available, is_torch_available 22 | 23 | 24 | _import_structure = { 25 | "configuration_blenderbot": ["BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BlenderbotConfig"], 26 | "tokenization_blenderbot": ["BlenderbotTokenizer"], 27 | } 28 | 29 | if is_torch_available(): 30 | _import_structure["modeling_blenderbot"] = [ 31 | "BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST", 32 | "BlenderbotForCausalLM", 33 | "BlenderbotForConditionalGeneration", 34 | "BlenderbotModel", 35 | "BlenderbotPreTrainedModel", 36 | ] 37 | 38 | 39 | if is_tf_available(): 40 | _import_structure["modeling_tf_blenderbot"] = ["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel"] 41 | 42 | 43 | if TYPE_CHECKING: 44 | from .configuration_blenderbot import BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotConfig 45 | from .tokenization_blenderbot import BlenderbotTokenizer 46 | 47 | if is_torch_available(): 48 | from .modeling_blenderbot import ( 49 | BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST, 50 | BlenderbotForCausalLM, 51 | BlenderbotForConditionalGeneration, 52 | BlenderbotModel, 53 | BlenderbotPreTrainedModel, 54 | ) 55 | 56 | 
if is_tf_available(): 57 | from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel 58 | 59 | else: 60 | import importlib 61 | import os 62 | import sys 63 | 64 | class _LazyModule(_BaseLazyModule): 65 | """ 66 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 67 | """ 68 | 69 | __file__ = globals()["__file__"] 70 | __path__ = [os.path.dirname(__file__)] 71 | 72 | def _get_module(self, module_name: str): 73 | return importlib.import_module("." + module_name, self.__name__) 74 | 75 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 76 | -------------------------------------------------------------------------------- /src/transformers/models/blenderbot_small/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...file_utils import _BaseLazyModule, is_tf_available, is_torch_available


# Public API of this package keyed by submodule name; imports are deferred
# until an attribute is first requested (lazy-module pattern below).
_import_structure = {
    "configuration_blenderbot_small": ["BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP", "BlenderbotSmallConfig"],
    "tokenization_blenderbot_small": ["BlenderbotSmallTokenizer"],
}

if is_torch_available():
    _import_structure["modeling_blenderbot_small"] = [
        "BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST",
        "BlenderbotSmallForCausalLM",
        "BlenderbotSmallForConditionalGeneration",
        "BlenderbotSmallModel",
        "BlenderbotSmallPreTrainedModel",
    ]

if is_tf_available():
    _import_structure["modeling_tf_blenderbot_small"] = [
        "TFBlenderbotSmallForConditionalGeneration",
        "TFBlenderbotSmallModel",
    ]

if TYPE_CHECKING:
    # Eager imports for static type checkers only; never executed at runtime.
    from .configuration_blenderbot_small import BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotSmallConfig
    from .tokenization_blenderbot_small import BlenderbotSmallTokenizer

    if is_torch_available():
        from .modeling_blenderbot_small import (
            BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST,
            BlenderbotSmallForCausalLM,
            BlenderbotSmallForConditionalGeneration,
            BlenderbotSmallModel,
            BlenderbotSmallPreTrainedModel,
        )

    if is_tf_available():
        from .modeling_tf_blenderbot_small import TFBlenderbotSmallForConditionalGeneration, TFBlenderbotSmallModel

else:
    import importlib
    import os
    import sys

    class _LazyModule(_BaseLazyModule):
        """
        Module class that surfaces all objects but only performs associated imports when the objects are requested.
        """

        __file__ = globals()["__file__"]
        __path__ = [os.path.dirname(__file__)]

        def _get_module(self, module_name: str):
            # Resolve "<this package>.<module_name>" relative to this package.
            return importlib.import_module("." + module_name, self.__name__)

    # Replace this package's module object so attribute access triggers lazy imports.
    sys.modules[__name__] = _LazyModule(__name__, _import_structure)
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" CamemBERT configuration """

from ...utils import logging
from ..roberta.configuration_roberta import RobertaConfig


logger = logging.get_logger(__name__)

# Map of shortcut model names to the hosted config.json for each checkpoint.
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "camembert-base": "https://huggingface.co/camembert-base/resolve/main/config.json",
    "umberto-commoncrawl-cased-v1": "https://huggingface.co/Musixmatch/umberto-commoncrawl-cased-v1/resolve/main/config.json",
    "umberto-wikipedia-uncased-v1": "https://huggingface.co/Musixmatch/umberto-wikipedia-uncased-v1/resolve/main/config.json",
}


class CamembertConfig(RobertaConfig):
    """
    This class overrides :class:`~transformers.RobertaConfig`. Please check the superclass for the appropriate
    documentation alongside usage examples.
    """

    # Only the model type differs from RoBERTa; all hyperparameters are inherited.
    model_type = "camembert"
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available, is_vision_available


# Public API of this package keyed by submodule name; imports are deferred
# until an attribute is first requested (lazy-module pattern below).
_import_structure = {
    "configuration_clip": ["CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP", "CLIPConfig", "CLIPTextConfig", "CLIPVisionConfig"],
    "tokenization_clip": ["CLIPTokenizer"],
}

if is_tokenizers_available():
    _import_structure["tokenization_clip_fast"] = ["CLIPTokenizerFast"]

if is_vision_available():
    _import_structure["feature_extraction_clip"] = ["CLIPFeatureExtractor"]
    _import_structure["processing_clip"] = ["CLIPProcessor"]

if is_torch_available():
    _import_structure["modeling_clip"] = [
        "CLIP_PRETRAINED_MODEL_ARCHIVE_LIST",
        "CLIPModel",
        "CLIPPreTrainedModel",
        "CLIPTextModel",
        "CLIPVisionModel",
    ]


if TYPE_CHECKING:
    # Eager imports for static type checkers only; never executed at runtime.
    from .configuration_clip import CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP, CLIPConfig, CLIPTextConfig, CLIPVisionConfig
    from .tokenization_clip import CLIPTokenizer

    if is_tokenizers_available():
        from .tokenization_clip_fast import CLIPTokenizerFast

    if is_vision_available():
        from .feature_extraction_clip import CLIPFeatureExtractor
        from .processing_clip import CLIPProcessor

    if is_torch_available():
        from .modeling_clip import (
            CLIP_PRETRAINED_MODEL_ARCHIVE_LIST,
            CLIPModel,
            CLIPPreTrainedModel,
            CLIPTextModel,
            CLIPVisionModel,
        )


else:
    import importlib
    import os
    import sys

    class _LazyModule(_BaseLazyModule):
        """
        Module class that surfaces all objects but only performs associated imports when the objects are requested.
        """

        __file__ = globals()["__file__"]
        __path__ = [os.path.dirname(__file__)]

        def _get_module(self, module_name: str):
            # Resolve "<this package>.<module_name>" relative to this package.
            return importlib.import_module("." + module_name, self.__name__)

    # Replace this package's module object so attribute access triggers lazy imports.
    sys.modules[__name__] = _LazyModule(__name__, _import_structure)
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ConvBERT checkpoint."""

import argparse

from transformers import ConvBertConfig, ConvBertModel, TFConvBertModel, load_tf_weights_in_convbert
from transformers.utils import logging


logging.set_verbosity_info()


def convert_orig_tf1_checkpoint_to_pytorch(tf_checkpoint_path, convbert_config_file, pytorch_dump_path):
    """Load a TF1 ConvBERT checkpoint and save it as both PyTorch and TF2 checkpoints.

    Args:
        tf_checkpoint_path: Path to the original TensorFlow 1.x checkpoint.
        convbert_config_file: JSON config describing the model architecture.
        pytorch_dump_path: Output directory; receives both the PyTorch and TF2 weights.
    """
    conf = ConvBertConfig.from_json_file(convbert_config_file)
    model = ConvBertModel(conf)

    model = load_tf_weights_in_convbert(model, conf, tf_checkpoint_path)
    model.save_pretrained(pytorch_dump_path)

    # Round-trip through the freshly saved PyTorch weights to emit a TF2 checkpoint too.
    tf_model = TFConvBertModel.from_pretrained(pytorch_dump_path, from_pt=True)
    tf_model.save_pretrained(pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--convbert_config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the pre-trained ConvBERT model. \n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_orig_tf1_checkpoint_to_pytorch(args.tf_checkpoint_path, args.convbert_config_file, args.pytorch_dump_path)
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for ConvBERT."""
from ...utils import logging
from ..bert.tokenization_bert import BertTokenizer


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}

# Hosted vocab files for each published ConvBERT checkpoint.
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "YituTech/conv-bert-base": "https://huggingface.co/YituTech/conv-bert-base/resolve/main/vocab.txt",
        "YituTech/conv-bert-medium-small": "https://huggingface.co/YituTech/conv-bert-medium-small/resolve/main/vocab.txt",
        "YituTech/conv-bert-small": "https://huggingface.co/YituTech/conv-bert-small/resolve/main/vocab.txt",
    }
}

# Maximum input length (in tokens) supported by each checkpoint.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "YituTech/conv-bert-base": 512,
    "YituTech/conv-bert-medium-small": 512,
    "YituTech/conv-bert-small": 512,
}


# Tokenizer __init__ overrides applied when loading each checkpoint.
PRETRAINED_INIT_CONFIGURATION = {
    "YituTech/conv-bert-base": {"do_lower_case": True},
    "YituTech/conv-bert-medium-small": {"do_lower_case": True},
    "YituTech/conv-bert-small": {"do_lower_case": True},
}


class ConvBertTokenizer(BertTokenizer):
    r"""
    Construct a ConvBERT tokenizer. :class:`~transformers.ConvBertTokenizer` is identical to
    :class:`~transformers.BertTokenizer` and runs end-to-end tokenization: punctuation splitting and wordpiece. Refer
    to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning parameters.
    """

    # Only the checkpoint metadata differs from BertTokenizer; behavior is inherited.
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
15 | """Tokenization classes for ConvBERT.""" 16 | from ...utils import logging 17 | from ..bert.tokenization_bert_fast import BertTokenizerFast 18 | from .tokenization_convbert import ConvBertTokenizer 19 | 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} 24 | 25 | PRETRAINED_VOCAB_FILES_MAP = { 26 | "vocab_file": { 27 | "YituTech/conv-bert-base": "https://huggingface.co/YituTech/conv-bert-base/resolve/main/vocab.txt", 28 | "YituTech/conv-bert-medium-small": "https://huggingface.co/YituTech/conv-bert-medium-small/resolve/main/vocab.txt", 29 | "YituTech/conv-bert-small": "https://huggingface.co/YituTech/conv-bert-small/resolve/main/vocab.txt", 30 | } 31 | } 32 | 33 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 34 | "YituTech/conv-bert-base": 512, 35 | "YituTech/conv-bert-medium-small": 512, 36 | "YituTech/conv-bert-small": 512, 37 | } 38 | 39 | 40 | PRETRAINED_INIT_CONFIGURATION = { 41 | "YituTech/conv-bert-base": {"do_lower_case": True}, 42 | "YituTech/conv-bert-medium-small": {"do_lower_case": True}, 43 | "YituTech/conv-bert-small": {"do_lower_case": True}, 44 | } 45 | 46 | 47 | class ConvBertTokenizerFast(BertTokenizerFast): 48 | r""" 49 | Construct a "fast" ConvBERT tokenizer (backed by HuggingFace's `tokenizers` library). 50 | 51 | :class:`~transformers.ConvBertTokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs 52 | end-to-end tokenization: punctuation splitting and wordpiece. 53 | 54 | Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning 55 | parameters. 
56 | """ 57 | vocab_files_names = VOCAB_FILES_NAMES 58 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 59 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 60 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 61 | slow_tokenizer_class = ConvBertTokenizer 62 | -------------------------------------------------------------------------------- /src/transformers/models/cpm/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | from typing import TYPE_CHECKING 20 | 21 | from ...file_utils import _BaseLazyModule 22 | 23 | 24 | _import_structure = { 25 | "tokenization_cpm": ["CpmTokenizer"], 26 | } 27 | 28 | 29 | if TYPE_CHECKING: 30 | from .tokenization_cpm import CpmTokenizer 31 | 32 | else: 33 | import importlib 34 | import os 35 | import sys 36 | 37 | class _LazyModule(_BaseLazyModule): 38 | """ 39 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 
40 | """ 41 | 42 | __file__ = globals()["__file__"] 43 | __path__ = [os.path.dirname(__file__)] 44 | 45 | def _get_module(self, module_name: str): 46 | return importlib.import_module("." + module_name, self.__name__) 47 | 48 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 49 | -------------------------------------------------------------------------------- /src/transformers/models/ctrl/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | 19 | from typing import TYPE_CHECKING 20 | 21 | from ...file_utils import _BaseLazyModule, is_tf_available, is_torch_available 22 | 23 | 24 | _import_structure = { 25 | "configuration_ctrl": ["CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CTRLConfig"], 26 | "tokenization_ctrl": ["CTRLTokenizer"], 27 | } 28 | 29 | if is_torch_available(): 30 | _import_structure["modeling_ctrl"] = [ 31 | "CTRL_PRETRAINED_MODEL_ARCHIVE_LIST", 32 | "CTRLForSequenceClassification", 33 | "CTRLLMHeadModel", 34 | "CTRLModel", 35 | "CTRLPreTrainedModel", 36 | ] 37 | 38 | if is_tf_available(): 39 | _import_structure["modeling_tf_ctrl"] = [ 40 | "TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST", 41 | "TFCTRLForSequenceClassification", 42 | "TFCTRLLMHeadModel", 43 | "TFCTRLModel", 44 | "TFCTRLPreTrainedModel", 45 | ] 46 | 47 | 48 | if TYPE_CHECKING: 49 | from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig 50 | from .tokenization_ctrl import CTRLTokenizer 51 | 52 | if is_torch_available(): 53 | from .modeling_ctrl import ( 54 | CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, 55 | CTRLForSequenceClassification, 56 | CTRLLMHeadModel, 57 | CTRLModel, 58 | CTRLPreTrainedModel, 59 | ) 60 | 61 | if is_tf_available(): 62 | from .modeling_tf_ctrl import ( 63 | TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, 64 | TFCTRLForSequenceClassification, 65 | TFCTRLLMHeadModel, 66 | TFCTRLModel, 67 | TFCTRLPreTrainedModel, 68 | ) 69 | 70 | else: 71 | import importlib 72 | import os 73 | import sys 74 | 75 | class _LazyModule(_BaseLazyModule): 76 | """ 77 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 78 | """ 79 | 80 | __file__ = globals()["__file__"] 81 | __path__ = [os.path.dirname(__file__)] 82 | 83 | def _get_module(self, module_name: str): 84 | return importlib.import_module("." 
+ module_name, self.__name__) 85 | 86 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 87 | -------------------------------------------------------------------------------- /src/transformers/models/deberta/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available


# Public API of this package keyed by submodule name; imports are deferred
# until an attribute is first requested (lazy-module pattern below).
_import_structure = {
    "configuration_deberta": ["DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaConfig"],
    "tokenization_deberta": ["DebertaTokenizer"],
}

if is_tokenizers_available():
    _import_structure["tokenization_deberta_fast"] = ["DebertaTokenizerFast"]

if is_torch_available():
    _import_structure["modeling_deberta"] = [
        "DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
        "DebertaForMaskedLM",
        "DebertaForQuestionAnswering",
        "DebertaForSequenceClassification",
        "DebertaForTokenClassification",
        "DebertaModel",
        "DebertaPreTrainedModel",
    ]


if TYPE_CHECKING:
    # Eager imports for static type checkers only; never executed at runtime.
    from .configuration_deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig
    from .tokenization_deberta import DebertaTokenizer

    if is_tokenizers_available():
        from .tokenization_deberta_fast import DebertaTokenizerFast

    if is_torch_available():
        from .modeling_deberta import (
            DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
            DebertaForMaskedLM,
            DebertaForQuestionAnswering,
            DebertaForSequenceClassification,
            DebertaForTokenClassification,
            DebertaModel,
            DebertaPreTrainedModel,
        )

else:
    import importlib
    import os
    import sys

    class _LazyModule(_BaseLazyModule):
        """
        Module class that surfaces all objects but only performs associated imports when the objects are requested.
        """

        __file__ = globals()["__file__"]
        __path__ = [os.path.dirname(__file__)]

        def _get_module(self, module_name: str):
            # Resolve "<this package>.<module_name>" relative to this package.
            return importlib.import_module("." + module_name, self.__name__)

    # Replace this package's module object so attribute access triggers lazy imports.
    sys.modules[__name__] = _LazyModule(__name__, _import_structure)
18 | 19 | from typing import TYPE_CHECKING 20 | 21 | from ...file_utils import _BaseLazyModule, is_torch_available 22 | 23 | 24 | _import_structure = { 25 | "configuration_deberta_v2": ["DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaV2Config"], 26 | "tokenization_deberta_v2": ["DebertaV2Tokenizer"], 27 | } 28 | 29 | if is_torch_available(): 30 | _import_structure["modeling_deberta_v2"] = [ 31 | "DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST", 32 | "DebertaV2ForMaskedLM", 33 | "DebertaV2ForQuestionAnswering", 34 | "DebertaV2ForSequenceClassification", 35 | "DebertaV2ForTokenClassification", 36 | "DebertaV2Model", 37 | "DebertaV2PreTrainedModel", 38 | ] 39 | 40 | 41 | if TYPE_CHECKING: 42 | from .configuration_deberta_v2 import DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaV2Config 43 | from .tokenization_deberta_v2 import DebertaV2Tokenizer 44 | 45 | if is_torch_available(): 46 | from .modeling_deberta_v2 import ( 47 | DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST, 48 | DebertaV2ForMaskedLM, 49 | DebertaV2ForQuestionAnswering, 50 | DebertaV2ForSequenceClassification, 51 | DebertaV2ForTokenClassification, 52 | DebertaV2Model, 53 | DebertaV2PreTrainedModel, 54 | ) 55 | 56 | else: 57 | import importlib 58 | import os 59 | import sys 60 | 61 | class _LazyModule(_BaseLazyModule): 62 | """ 63 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 64 | """ 65 | 66 | __file__ = globals()["__file__"] 67 | __path__ = [os.path.dirname(__file__)] 68 | 69 | def _get_module(self, module_name: str): 70 | return importlib.import_module("." + module_name, self.__name__) 71 | 72 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 73 | -------------------------------------------------------------------------------- /src/transformers/models/deit/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' 
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...file_utils import _BaseLazyModule, is_torch_available, is_vision_available


# Public API of this package keyed by submodule name; imports are deferred
# until an attribute is first requested (lazy-module pattern below).
_import_structure = {
    "configuration_deit": ["DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DeiTConfig"],
}

if is_vision_available():
    _import_structure["feature_extraction_deit"] = ["DeiTFeatureExtractor"]

if is_torch_available():
    _import_structure["modeling_deit"] = [
        "DEIT_PRETRAINED_MODEL_ARCHIVE_LIST",
        "DeiTForImageClassification",
        "DeiTForImageClassificationWithTeacher",
        "DeiTModel",
        "DeiTPreTrainedModel",
    ]


if TYPE_CHECKING:
    # Eager imports for static type checkers only; never executed at runtime.
    from .configuration_deit import DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, DeiTConfig

    if is_vision_available():
        from .feature_extraction_deit import DeiTFeatureExtractor

    if is_torch_available():
        from .modeling_deit import (
            DEIT_PRETRAINED_MODEL_ARCHIVE_LIST,
            DeiTForImageClassification,
            DeiTForImageClassificationWithTeacher,
            DeiTModel,
            DeiTPreTrainedModel,
        )


else:
    import importlib
    import os
    import sys

    class _LazyModule(_BaseLazyModule):
        """
        Module class that surfaces all objects but only performs associated imports when the objects are requested.
        """

        __file__ = globals()["__file__"]
        __path__ = [os.path.dirname(__file__)]

        def _get_module(self, module_name: str):
            # Resolve "<this package>.<module_name>" relative to this package.
            return importlib.import_module("." + module_name, self.__name__)

    # Replace this package's module object so attribute access triggers lazy imports.
    sys.modules[__name__] = _LazyModule(__name__, _import_structure)

# NOTE(review): in the original dump, src/transformers/models/dialogpt/__init__.py
# is recorded only as an external reference:
# https://raw.githubusercontent.com/amazon-science/fact-graph/165285a483ebf2fc4ffe2202c65349b99dea9626/src/transformers/models/dialogpt/__init__.py
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import torch

from transformers.file_utils import WEIGHTS_NAME


DIALOGPT_MODELS = ["small", "medium", "large"]

# The original DialoGPT checkpoints name the LM head differently from
# transformers' GPT-2 implementation; the conversion is a single key rename.
OLD_KEY = "lm_head.decoder.weight"
NEW_KEY = "lm_head.weight"


def convert_dialogpt_checkpoint(checkpoint_path: str, pytorch_dump_folder_path: str):
    """Rename the LM-head key of an original DialoGPT checkpoint and save it
    under ``pytorch_dump_folder_path`` with transformers' standard weights name.
    """
    d = torch.load(checkpoint_path)
    d[NEW_KEY] = d.pop(OLD_KEY)
    os.makedirs(pytorch_dump_folder_path, exist_ok=True)
    torch.save(d, os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dialogpt_path", default=".", type=str)
    args = parser.parse_args()
    for MODEL in DIALOGPT_MODELS:
        checkpoint_path = os.path.join(args.dialogpt_path, f"{MODEL}_ft.pkl")
        pytorch_dump_folder_path = f"./DialoGPT-{MODEL}"
        convert_dialogpt_checkpoint(
            checkpoint_path,
            pytorch_dump_folder_path,
        )
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...file_utils import (
    _BaseLazyModule,
    is_flax_available,
    is_tf_available,
    is_tokenizers_available,
    is_torch_available,
)


# Public API of this package keyed by submodule name; imports are deferred
# until an attribute is first requested (lazy-module pattern below).
_import_structure = {
    "configuration_electra": ["ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP", "ElectraConfig"],
    "tokenization_electra": ["ElectraTokenizer"],
}

if is_tokenizers_available():
    _import_structure["tokenization_electra_fast"] = ["ElectraTokenizerFast"]

if is_torch_available():
    _import_structure["modeling_electra"] = [
        "ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST",
        "ElectraAdapterModel",
        "ElectraPreTrainedModel",
        "ElectraAdapterForMaskedLM",
        "load_tf_weights_in_electra",
    ]

if TYPE_CHECKING:
    # Eager imports for static type checkers only; never executed at runtime.
    # FIX: the configuration/tokenization imports below were missing even though
    # they are declared in `_import_structure`, leaving ElectraConfig and
    # ElectraTokenizer unresolvable for type checkers importing this package.
    from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig
    from .tokenization_electra import ElectraTokenizer

    if is_tokenizers_available():
        from .tokenization_electra_fast import ElectraTokenizerFast

    if is_torch_available():
        from .modeling_electra import (
            ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
            ElectraAdapterForMaskedLM,
            ElectraAdapterModel,
            ElectraPreTrainedModel,
            load_tf_weights_in_electra,
        )

else:
    import importlib
    import os
    import sys

    class _LazyModule(_BaseLazyModule):
        """
        Module class that surfaces all objects but only performs associated imports when the objects are requested.
        """

        __file__ = globals()["__file__"]
        __path__ = [os.path.dirname(__file__)]

        def _get_module(self, module_name: str):
            # Resolve "<this package>.<module_name>" relative to this package.
            return importlib.import_module("." + module_name, self.__name__)

    # Replace this package's module object so attribute access triggers lazy imports.
    sys.modules[__name__] = _LazyModule(__name__, _import_structure)
# imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...file_utils import _BaseLazyModule, is_torch_available


# Maps each submodule to the public names it provides; the lazy module below
# only imports a submodule when one of those names is first accessed.
_import_structure = {
    "configuration_encoder_decoder": ["EncoderDecoderConfig"],
}

if is_torch_available():
    _import_structure["modeling_encoder_decoder"] = ["EncoderDecoderModel"]


if TYPE_CHECKING:
    # Static type checkers get direct imports; at runtime the names are
    # resolved lazily through _LazyModule instead.
    from .configuration_encoder_decoder import EncoderDecoderConfig

    if is_torch_available():
        from .modeling_encoder_decoder import EncoderDecoderModel

else:
    import importlib
    import os
    import sys

    class _LazyModule(_BaseLazyModule):
        """Lazy stand-in for this package: exposes every name declared in
        ``_import_structure`` but defers each submodule import until a name
        from it is actually requested.
        """

        __file__ = globals()["__file__"]
        __path__ = [os.path.dirname(__file__)]

        def _get_module(self, module_name: str):
            # Resolve the submodule relative to this package.
            return importlib.import_module(f".{module_name}", self.__name__)

    sys.modules[__name__] = _LazyModule(__name__, _import_structure)
+ module_name, self.__name__) 53 | 54 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 55 | -------------------------------------------------------------------------------- /src/transformers/models/fsmt/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
from typing import TYPE_CHECKING

from ...file_utils import _BaseLazyModule, is_torch_available


# Maps each submodule to the public names it provides; the lazy module below
# only imports a submodule when one of those names is first accessed.
_import_structure = {
    "configuration_fsmt": ["FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP", "FSMTConfig"],
    "tokenization_fsmt": ["FSMTTokenizer"],
}

if is_torch_available():
    _import_structure["modeling_fsmt"] = ["FSMTForConditionalGeneration", "FSMTModel", "PretrainedFSMTModel"]


if TYPE_CHECKING:
    # Static type checkers get direct imports; at runtime the names are
    # resolved lazily through _LazyModule instead.
    from .configuration_fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTConfig
    from .tokenization_fsmt import FSMTTokenizer

    if is_torch_available():
        from .modeling_fsmt import FSMTForConditionalGeneration, FSMTModel, PretrainedFSMTModel

else:
    import importlib
    import os
    import sys

    class _LazyModule(_BaseLazyModule):
        """Lazy stand-in for this package: exposes every name declared in
        ``_import_structure`` but defers each submodule import until a name
        from it is actually requested.
        """

        __file__ = globals()["__file__"]
        __path__ = [os.path.dirname(__file__)]

        def _get_module(self, module_name: str):
            # Resolve the submodule relative to this package.
            return importlib.import_module(f".{module_name}", self.__name__)

    sys.modules[__name__] = _LazyModule(__name__, _import_structure)
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Funnel checkpoint."""


import argparse

import torch

from transformers import FunnelBaseModel, FunnelConfig, FunnelModel, load_tf_weights_in_funnel
from transformers.utils import logging


logging.set_verbosity_info()


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path, base_model):
    """Build a Funnel model from ``config_file``, load the TensorFlow weights
    into it, and write the resulting state dict to ``pytorch_dump_path``.
    """
    config = FunnelConfig.from_json_file(config_file)
    print(f"Building PyTorch model from configuration: {config}")
    # ``base_model`` selects the encoder-only variant (no decoder).
    if base_model:
        model = FunnelBaseModel(config)
    else:
        model = FunnelModel(config)

    # Copy the TensorFlow variables into the freshly initialised model.
    load_tf_weights_in_funnel(model, config, tf_checkpoint_path)

    # Persist the converted weights.
    print(f"Save PyTorch model to {pytorch_dump_path}")
    torch.save(model.state_dict(), pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the pre-trained model. \n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--base_model", action="store_true", help="Whether you want just the base model (no decoder) or not."
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(
        args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path, args.base_model
    )
"""Convert OpenAI GPT checkpoint."""


import argparse
import os

import torch

from transformers import GPT2Config, GPT2Model, load_tf_weights_in_gpt2
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
from transformers.utils import logging


logging.set_verbosity_info()


def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path):
    """Convert a TensorFlow GPT-2 checkpoint into PyTorch weight/config files.

    Args:
        gpt2_checkpoint_path: Path to the TensorFlow checkpoint.
        gpt2_config_file: Path to a JSON config; an empty string selects the
            default ``GPT2Config``.
        pytorch_dump_folder_path: Output directory (created if missing) that
            receives ``WEIGHTS_NAME`` and ``CONFIG_NAME``.
    """
    # Construct model
    if gpt2_config_file == "":
        config = GPT2Config()
    else:
        config = GPT2Config.from_json_file(gpt2_config_file)
    model = GPT2Model(config)

    # Load weights from numpy
    load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path)

    # Create the output folder up front, and build paths with os.path.join
    # (portable, robust to a trailing separator) instead of "+" concatenation.
    os.makedirs(pytorch_dump_folder_path, exist_ok=True)
    pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
    pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)
    print(f"Save PyTorch model to {pytorch_weights_dump_path}")
    torch.save(model.state_dict(), pytorch_weights_dump_path)
    print(f"Save configuration file to {pytorch_config_dump_path}")
    with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
        f.write(config.to_json_string())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--gpt2_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--gpt2_config_file",
        default="",
        type=str,
        help="An optional config json file corresponding to the pre-trained OpenAI model. \n"
        "This specifies the model architecture.",
    )
    args = parser.parse_args()
    convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, args.gpt2_config_file, args.pytorch_dump_folder_path)
from typing import TYPE_CHECKING

from ...file_utils import _BaseLazyModule, is_torch_available


# Maps each submodule to the public names it provides; the lazy module below
# only imports a submodule when one of those names is first accessed.
_import_structure = {
    "configuration_gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig"],
}

if is_torch_available():
    _import_structure["modeling_gpt_neo"] = [
        "GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST",
        "GPTNeoForCausalLM",
        "GPTNeoModel",
        "GPTNeoPreTrainedModel",
        "load_tf_weights_in_gpt_neo",
    ]


if TYPE_CHECKING:
    # Static type checkers get direct imports; at runtime the names are
    # resolved lazily through _LazyModule instead.
    from .configuration_gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig

    if is_torch_available():
        from .modeling_gpt_neo import (
            GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST,
            GPTNeoForCausalLM,
            GPTNeoModel,
            GPTNeoPreTrainedModel,
            load_tf_weights_in_gpt_neo,
        )


else:
    import importlib
    import os
    import sys

    class _LazyModule(_BaseLazyModule):
        """Lazy stand-in for this package: exposes every name declared in
        ``_import_structure`` but defers each submodule import until a name
        from it is actually requested.
        """

        __file__ = globals()["__file__"]
        __path__ = [os.path.dirname(__file__)]

        def _get_module(self, module_name: str):
            # Resolve the submodule relative to this package.
            return importlib.import_module(f".{module_name}", self.__name__)

    sys.modules[__name__] = _LazyModule(__name__, _import_structure)
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert GPT Neo checkpoint."""


import argparse
import json

from transformers import GPTNeoConfig, GPTNeoForCausalLM, load_tf_weights_in_gpt_neo
from transformers.utils import logging


logging.set_verbosity_info()


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
    """Build a GPT-Neo model from the mesh-tf JSON config, load the TF
    weights into it, and save it with ``save_pretrained``.

    Args:
        tf_checkpoint_path: Path to the TensorFlow checkpoint.
        config_file: Path to the mesh-tf JSON configuration.
        pytorch_dump_path: Output directory for the converted model.
    """
    # Use a context manager so the config file is closed deterministically
    # (the original json.load(open(...)) leaked the handle until GC).
    with open(config_file, "r", encoding="utf-8") as f:
        config_json = json.load(f)
    # Translate mesh-tf config keys into GPTNeoConfig arguments.
    config = GPTNeoConfig(
        hidden_size=config_json["n_embd"],
        num_layers=config_json["n_layer"],
        num_heads=config_json["n_head"],
        attention_types=config_json["attention_types"],
        max_position_embeddings=config_json["n_ctx"],
        resid_dropout=config_json["res_dropout"],
        embed_dropout=config_json["embed_dropout"],
        attention_dropout=config_json["attn_dropout"],
    )
    print(f"Building PyTorch model from configuration: {config}")
    model = GPTNeoForCausalLM(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_gpt_neo(model, config, tf_checkpoint_path)

    # Save pytorch-model
    print(f"Save PyTorch model to {pytorch_dump_path}")
    model.save_pretrained(pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the pre-trained mesh-tf model. \n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)
from typing import TYPE_CHECKING

from ...file_utils import _BaseLazyModule, is_tokenizers_available


# Maps each submodule to the public names it provides; the lazy module below
# only imports a submodule when one of those names is first accessed.
_import_structure = {
    "tokenization_herbert": ["HerbertTokenizer"],
}

if is_tokenizers_available():
    _import_structure["tokenization_herbert_fast"] = ["HerbertTokenizerFast"]


if TYPE_CHECKING:
    # Static type checkers get direct imports; at runtime the names are
    # resolved lazily through _LazyModule instead.
    from .tokenization_herbert import HerbertTokenizer

    if is_tokenizers_available():
        from .tokenization_herbert_fast import HerbertTokenizerFast

else:
    import importlib
    import os
    import sys

    class _LazyModule(_BaseLazyModule):
        """Lazy stand-in for this package: exposes every name declared in
        ``_import_structure`` but defers each submodule import until a name
        from it is actually requested.
        """

        __file__ = globals()["__file__"]
        __path__ = [os.path.dirname(__file__)]

        def _get_module(self, module_name: str):
            # Resolve the submodule relative to this package.
            return importlib.import_module(f".{module_name}", self.__name__)

    sys.modules[__name__] = _LazyModule(__name__, _import_structure)
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...file_utils import _BaseLazyModule, is_torch_available


# Maps each submodule to the public names it provides; the lazy module below
# only imports a submodule when one of those names is first accessed.
_import_structure = {
    "configuration_ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig"],
}

if is_torch_available():
    _import_structure["modeling_ibert"] = [
        "IBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
        "IBertForMaskedLM",
        "IBertForMultipleChoice",
        "IBertForQuestionAnswering",
        "IBertForSequenceClassification",
        "IBertForTokenClassification",
        "IBertModel",
        "IBertPreTrainedModel",
    ]

if TYPE_CHECKING:
    # Static type checkers get direct imports; at runtime the names are
    # resolved lazily through _LazyModule instead.
    from .configuration_ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig

    if is_torch_available():
        from .modeling_ibert import (
            IBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
            IBertForMaskedLM,
            IBertForMultipleChoice,
            IBertForQuestionAnswering,
            IBertForSequenceClassification,
            IBertForTokenClassification,
            IBertModel,
            IBertPreTrainedModel,
        )

else:
    import importlib
    import os
    import sys

    class _LazyModule(_BaseLazyModule):
        """Lazy stand-in for this package: exposes every name declared in
        ``_import_structure`` but defers each submodule import until a name
        from it is actually requested.
        """

        __file__ = globals()["__file__"]
        __path__ = [os.path.dirname(__file__)]

        def _get_module(self, module_name: str):
            # Resolve the submodule relative to this package.
            return importlib.import_module(f".{module_name}", self.__name__)

    sys.modules[__name__] = _LazyModule(__name__, _import_structure)
""" Tokenization class for model LayoutLM."""


from ...utils import logging
from ..bert.tokenization_bert import BertTokenizer


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}

# Download locations of the pretrained vocabularies on the Hugging Face hub.
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "microsoft/layoutlm-base-uncased": "https://huggingface.co/microsoft/layoutlm-base-uncased/resolve/main/vocab.txt",
        "microsoft/layoutlm-large-uncased": "https://huggingface.co/microsoft/layoutlm-large-uncased/resolve/main/vocab.txt",
    }
}


# Maximum input length supported by each pretrained checkpoint.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "microsoft/layoutlm-base-uncased": 512,
    "microsoft/layoutlm-large-uncased": 512,
}


# Tokenizer keyword arguments to use with each pretrained checkpoint.
PRETRAINED_INIT_CONFIGURATION = {
    "microsoft/layoutlm-base-uncased": {"do_lower_case": True},
    "microsoft/layoutlm-large-uncased": {"do_lower_case": True},
}


class LayoutLMTokenizer(BertTokenizer):
    r"""
    Constructs a LayoutLM tokenizer.

    :class:`~transformers.LayoutLMTokenizer` is identical to :class:`~transformers.BertTokenizer` and runs end-to-end
    tokenization: punctuation splitting + wordpiece.

    Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning
    parameters.
    """
    # Fixed above: the original docstring was missing the closing backtick on
    # the first :class: role, which broke the rendered documentation.

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization class for model LayoutLM."""


from ...utils import logging
from ..bert.tokenization_bert_fast import BertTokenizerFast
from .tokenization_layoutlm import LayoutLMTokenizer


logger = logging.get_logger(__name__)

# The fast tokenizer can load either a plain vocab file or a pre-built
# tokenizers-library JSON file.
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}

# Download locations of the pretrained vocabulary/tokenizer files on the hub.
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "microsoft/layoutlm-base-uncased": "https://huggingface.co/microsoft/layoutlm-base-uncased/resolve/main/vocab.txt",
        "microsoft/layoutlm-large-uncased": "https://huggingface.co/microsoft/layoutlm-large-uncased/resolve/main/vocab.txt",
    },
    "tokenizer_file": {
        "microsoft/layoutlm-base-uncased": "https://huggingface.co/microsoft/layoutlm-base-uncased/resolve/main/tokenizer.json",
        "microsoft/layoutlm-large-uncased": "https://huggingface.co/microsoft/layoutlm-large-uncased/resolve/main/tokenizer.json",
    },
}


# Maximum input length supported by each pretrained checkpoint.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "microsoft/layoutlm-base-uncased": 512,
    "microsoft/layoutlm-large-uncased": 512,
}


# Tokenizer keyword arguments to use with each pretrained checkpoint.
PRETRAINED_INIT_CONFIGURATION = {
    "microsoft/layoutlm-base-uncased": {"do_lower_case": True},
    "microsoft/layoutlm-large-uncased": {"do_lower_case": True},
}


class LayoutLMTokenizerFast(BertTokenizerFast):
    r"""
    Constructs a "Fast" LayoutLMTokenizer.

    :class:`~transformers.LayoutLMTokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs
    end-to-end tokenization: punctuation splitting + wordpiece.

    Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning
    parameters.
    """

    # Class-level registry attributes consumed by PreTrainedTokenizerBase.
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    # Slow counterpart used when converting to/from the non-fast tokenizer.
    slow_tokenizer_class = LayoutLMTokenizer
from typing import TYPE_CHECKING

from ...file_utils import _BaseLazyModule, is_tf_available, is_tokenizers_available, is_torch_available


# Maps each submodule to the public names it provides; the lazy module below
# only imports a submodule when one of those names is first accessed.
_import_structure = {
    "configuration_led": ["LED_PRETRAINED_CONFIG_ARCHIVE_MAP", "LEDConfig"],
    "tokenization_led": ["LEDTokenizer"],
}

if is_tokenizers_available():
    _import_structure["tokenization_led_fast"] = ["LEDTokenizerFast"]

if is_torch_available():
    _import_structure["modeling_led"] = [
        "LED_PRETRAINED_MODEL_ARCHIVE_LIST",
        "LEDForConditionalGeneration",
        "LEDForQuestionAnswering",
        "LEDForSequenceClassification",
        "LEDModel",
        "LEDPreTrainedModel",
    ]


if is_tf_available():
    _import_structure["modeling_tf_led"] = ["TFLEDForConditionalGeneration", "TFLEDModel", "TFLEDPreTrainedModel"]


if TYPE_CHECKING:
    # Static type checkers get direct imports; at runtime the names are
    # resolved lazily through _LazyModule instead.
    from .configuration_led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig
    from .tokenization_led import LEDTokenizer

    if is_tokenizers_available():
        from .tokenization_led_fast import LEDTokenizerFast

    if is_torch_available():
        from .modeling_led import (
            LED_PRETRAINED_MODEL_ARCHIVE_LIST,
            LEDForConditionalGeneration,
            LEDForQuestionAnswering,
            LEDForSequenceClassification,
            LEDModel,
            LEDPreTrainedModel,
        )

    if is_tf_available():
        from .modeling_tf_led import TFLEDForConditionalGeneration, TFLEDModel, TFLEDPreTrainedModel

else:
    import importlib
    import os
    import sys

    class _LazyModule(_BaseLazyModule):
        """Lazy stand-in for this package: exposes every name declared in
        ``_import_structure`` but defers each submodule import until a name
        from it is actually requested.
        """

        __file__ = globals()["__file__"]
        __path__ = [os.path.dirname(__file__)]

        def _get_module(self, module_name: str):
            # Resolve the submodule relative to this package.
            return importlib.import_module(f".{module_name}", self.__name__)

    sys.modules[__name__] = _LazyModule(__name__, _import_structure)
# coding=utf-8
# Copyright 2021 Iz Beltagy, Matthew E. Peters, Arman Cohan and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for LED."""
from ...utils import logging
from ..bart.tokenization_bart import BartTokenizer


logger = logging.get_logger(__name__)

# Download locations of the pretrained tokenizer files on the Hugging Face hub.
# NOTE(review): "tokenizer_file" is typically consumed only by the *fast*
# tokenizer; confirm the entry is intentional on this slow class.
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/vocab.json",
    },
    "merges_file": {
        "allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/merges.txt",
    },
    "tokenizer_file": {
        "allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/tokenizer.json",
    },
}

# Maximum input length supported by each pretrained checkpoint.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "allenai/led-base-16384": 16384,
}


class LEDTokenizer(BartTokenizer):
    """
    Construct a LED tokenizer.

    :class:`~transformers.LEDTokenizer` is identical to :class:`~transformers.BartTokenizer` and runs the same
    end-to-end tokenization procedure as its superclass.

    Refer to superclass :class:`~transformers.BartTokenizer` for usage examples and documentation concerning
    parameters.
    """

    # Class-level registry attributes consumed by PreTrainedTokenizerBase.
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
"""Tokenization classes for LED."""
from ...utils import logging
from ..bart.tokenization_bart_fast import BartTokenizerFast
from .tokenization_led import LEDTokenizer


logger = logging.get_logger(__name__)

# Download locations of the pretrained tokenizer files on the Hugging Face hub.
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/vocab.json",
    },
    "merges_file": {
        "allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/merges.txt",
    },
    "tokenizer_file": {
        "allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/tokenizer.json",
    },
}

# Maximum input length supported by each pretrained checkpoint.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "allenai/led-base-16384": 16384,
}


class LEDTokenizerFast(BartTokenizerFast):
    r"""
    Construct a "fast" LED tokenizer (backed by HuggingFace's `tokenizers` library).

    :class:`~transformers.LEDTokenizerFast` is identical to :class:`~transformers.BartTokenizerFast` and runs the
    same end-to-end tokenization procedure as its superclass.

    Refer to superclass :class:`~transformers.BartTokenizerFast` for usage examples and documentation concerning
    parameters.
    """

    # Class-level registry attributes consumed by PreTrainedTokenizerBase.
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    # Slow counterpart used when converting to/from the non-fast tokenizer.
    slow_tokenizer_class = LEDTokenizer
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...file_utils import _BaseLazyModule, is_torch_available


# Public API of this package, declared per submodule; nothing is imported
# eagerly — the lazy-module machinery below resolves names on first access.
_import_structure = {
    "configuration_luke": ["LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP", "LukeConfig"],
    "tokenization_luke": ["LukeTokenizer"],
}

# Torch-backed modeling classes are only advertised when torch is installed.
if is_torch_available():
    _import_structure["modeling_luke"] = [
        "LUKE_PRETRAINED_MODEL_ARCHIVE_LIST",
        "LukeForEntityClassification",
        "LukeForEntityPairClassification",
        "LukeForEntitySpanClassification",
        "LukeModel",
        "LukePreTrainedModel",
    ]


if TYPE_CHECKING:
    # Real imports for static type checkers only; never executed at runtime.
    from .configuration_luke import LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP, LukeConfig
    from .tokenization_luke import LukeTokenizer

    if is_torch_available():
        from .modeling_luke import (
            LUKE_PRETRAINED_MODEL_ARCHIVE_LIST,
            LukeForEntityClassification,
            LukeForEntityPairClassification,
            LukeForEntitySpanClassification,
            LukeModel,
            LukePreTrainedModel,
        )

else:
    import importlib
    import os
    import sys

    class _LazyModule(_BaseLazyModule):
        """
        Module class that surfaces all objects but only performs associated imports when the objects are requested.
        """

        __file__ = globals()["__file__"]
        __path__ = [os.path.dirname(__file__)]

        def _get_module(self, module_name: str):
            # Relative import within this package, resolved on demand.
            return importlib.import_module("." + module_name, self.__name__)

    # Replace this module in sys.modules with the lazy proxy.
    sys.modules[__name__] = _LazyModule(__name__, _import_structure)
--------------------------------------------------------------------------------
/src/transformers/models/lxmert/convert_lxmert_original_tf_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15 | """Convert LXMERT checkpoint.""" 16 | 17 | 18 | import argparse 19 | 20 | import torch 21 | 22 | from transformers import LxmertConfig, LxmertForPreTraining, load_tf_weights_in_lxmert 23 | from transformers.utils import logging 24 | 25 | 26 | logging.set_verbosity_info() 27 | 28 | 29 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): 30 | # Initialise PyTorch model 31 | config = LxmertConfig.from_json_file(config_file) 32 | print(f"Building PyTorch model from configuration: {config}") 33 | model = LxmertForPreTraining(config) 34 | 35 | # Load weights from tf checkpoint 36 | load_tf_weights_in_lxmert(model, config, tf_checkpoint_path) 37 | 38 | # Save pytorch-model 39 | print(f"Save PyTorch model to {pytorch_dump_path}") 40 | torch.save(model.state_dict(), pytorch_dump_path) 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser() 45 | # Required parameters 46 | parser.add_argument( 47 | "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." 48 | ) 49 | parser.add_argument( 50 | "--config_file", 51 | default=None, 52 | type=str, 53 | required=True, 54 | help="The config json file corresponding to the pre-trained model. \n" 55 | "This specifies the model architecture.", 56 | ) 57 | parser.add_argument( 58 | "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." 59 | ) 60 | args = parser.parse_args() 61 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path) 62 | -------------------------------------------------------------------------------- /src/transformers/models/lxmert/tokenization_lxmert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google AI Team, Stanford University and The HuggingFace Inc. team. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..bert.tokenization_bert import BertTokenizer


# Local file names expected inside a saved tokenizer directory.
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}

# Hosted vocabulary file for each published checkpoint.
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "unc-nlp/lxmert-base-uncased": "https://huggingface.co/unc-nlp/lxmert-base-uncased/resolve/main/vocab.txt",
    }
}

# Maximum input length (positional embeddings) per checkpoint.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "unc-nlp/lxmert-base-uncased": 512,
}

# Tokenizer init kwargs baked into each checkpoint.
PRETRAINED_INIT_CONFIGURATION = {
    "unc-nlp/lxmert-base-uncased": {"do_lower_case": True},
}


class LxmertTokenizer(BertTokenizer):
    r"""
    Construct an LXMERT tokenizer.

    :class:`~transformers.LxmertTokenizer` is identical to :class:`~transformers.BertTokenizer` and runs end-to-end
    tokenization: punctuation splitting and wordpiece.

    Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning
    parameters.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
--------------------------------------------------------------------------------
/src/transformers/models/lxmert/tokenization_lxmert_fast.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2020 The Google AI Team, Stanford University and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..bert.tokenization_bert_fast import BertTokenizerFast
from .tokenization_lxmert import LxmertTokenizer


# Local file names expected inside a saved tokenizer directory.
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}

# Hosted tokenizer artifacts for each published checkpoint.
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "unc-nlp/lxmert-base-uncased": "https://huggingface.co/unc-nlp/lxmert-base-uncased/resolve/main/vocab.txt",
    },
    "tokenizer_file": {
        "unc-nlp/lxmert-base-uncased": "https://huggingface.co/unc-nlp/lxmert-base-uncased/resolve/main/tokenizer.json",
    },
}

# Maximum input length (positional embeddings) per checkpoint.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "unc-nlp/lxmert-base-uncased": 512,
}

# Tokenizer init kwargs baked into each checkpoint.
PRETRAINED_INIT_CONFIGURATION = {
    "unc-nlp/lxmert-base-uncased": {"do_lower_case": True},
}


class LxmertTokenizerFast(BertTokenizerFast):
    r"""
    Construct a "fast" LXMERT tokenizer (backed by HuggingFace's `tokenizers` library).

    :class:`~transformers.LxmertTokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs
    end-to-end tokenization: punctuation splitting and wordpiece.

    Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning
    parameters.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    # Slow (pure-Python) counterpart, used when converting to/from the fast tokenizer.
    slow_tokenizer_class = LxmertTokenizer
--------------------------------------------------------------------------------
/src/transformers/models/m2m_100/__init__.py:
--------------------------------------------------------------------------------
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

# NOTE(review): is_tokenizers_available is imported but never used below
# (this module is excluded from flake8, so the F401 is not reported).
from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available


# Public API of this package, declared per submodule; resolved lazily below.
_import_structure = {
    "configuration_m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"],
    "tokenization_m2m_100": ["M2M100Tokenizer"],
}


# Torch-backed modeling classes are only advertised when torch is installed.
if is_torch_available():
    _import_structure["modeling_m2m_100"] = [
        "M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST",
        "M2M100ForConditionalGeneration",
        "M2M100Model",
        "M2M100PreTrainedModel",
    ]


if TYPE_CHECKING:
    # Real imports for static type checkers only; never executed at runtime.
    from .configuration_m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config
    from .tokenization_m2m_100 import M2M100Tokenizer

    if is_torch_available():
        from .modeling_m2m_100 import (
            M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST,
            M2M100ForConditionalGeneration,
            M2M100Model,
            M2M100PreTrainedModel,
        )


else:
    import importlib
    import os
    import sys

    class _LazyModule(_BaseLazyModule):
        """
        Module class that surfaces all objects but only performs associated imports when the objects are requested.
        """

        __file__ = globals()["__file__"]
        __path__ = [os.path.dirname(__file__)]

        def _get_module(self, module_name: str):
            # Relative import within this package, resolved on demand.
            return importlib.import_module("." + module_name, self.__name__)

    # Replace this module in sys.modules with the lazy proxy.
    sys.modules[__name__] = _LazyModule(__name__, _import_structure)
--------------------------------------------------------------------------------
/src/transformers/models/marian/__init__.py:
--------------------------------------------------------------------------------
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

# NOTE(review): is_tokenizers_available is imported but never used below
# (this module is excluded from flake8, so the F401 is not reported).
from ...file_utils import (
    _BaseLazyModule,
    is_sentencepiece_available,
    is_tf_available,
    is_tokenizers_available,
    is_torch_available,
)


# Public API of this package, declared per submodule; resolved lazily below.
_import_structure = {
    "configuration_marian": ["MARIAN_PRETRAINED_CONFIG_ARCHIVE_MAP", "MarianConfig"],
}

# The Marian tokenizer requires sentencepiece.
if is_sentencepiece_available():
    _import_structure["tokenization_marian"] = ["MarianTokenizer"]

# Torch-backed modeling classes are only advertised when torch is installed.
if is_torch_available():
    _import_structure["modeling_marian"] = [
        "MARIAN_PRETRAINED_MODEL_ARCHIVE_LIST",
        "MarianForCausalLM",
        "MarianModel",
        "MarianMTModel",
        "MarianPreTrainedModel",
    ]

# TensorFlow counterparts, gated on a TF install.
if is_tf_available():
    _import_structure["modeling_tf_marian"] = ["TFMarianModel", "TFMarianMTModel"]


if TYPE_CHECKING:
    # Real imports for static type checkers only; never executed at runtime.
    from .configuration_marian import MARIAN_PRETRAINED_CONFIG_ARCHIVE_MAP, MarianConfig

    if is_sentencepiece_available():
        from .tokenization_marian import MarianTokenizer

    if is_torch_available():
        from .modeling_marian import (
            MARIAN_PRETRAINED_MODEL_ARCHIVE_LIST,
            MarianForCausalLM,
            MarianModel,
            MarianMTModel,
            MarianPreTrainedModel,
        )

    if is_tf_available():
        from .modeling_tf_marian import TFMarianModel, TFMarianMTModel

else:
    import importlib
    import os
    import sys

    class _LazyModule(_BaseLazyModule):
        """
        Module class that surfaces all objects but only performs associated imports when the objects are requested.
        """

        __file__ = globals()["__file__"]
        __path__ = [os.path.dirname(__file__)]

        def _get_module(self, module_name: str):
            # Relative import within this package, resolved on demand.
            return importlib.import_module("." + module_name, self.__name__)

    # Replace this module in sys.modules with the lazy proxy.
    sys.modules[__name__] = _LazyModule(__name__, _import_structure)
--------------------------------------------------------------------------------
/src/transformers/models/megatron_bert/__init__.py:
--------------------------------------------------------------------------------
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2021 NVIDIA Corporation and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...file_utils import _BaseLazyModule, is_torch_available


# Public API of this package, declared per submodule; resolved lazily below.
_import_structure = {
    "configuration_megatron_bert": ["MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegatronBertConfig"],
}

# Torch-backed modeling classes are only advertised when torch is installed.
if is_torch_available():
    _import_structure["modeling_megatron_bert"] = [
        "MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST",
        "MegatronBertForCausalLM",
        "MegatronBertForMaskedLM",
        "MegatronBertForMultipleChoice",
        "MegatronBertForNextSentencePrediction",
        "MegatronBertForPreTraining",
        "MegatronBertForQuestionAnswering",
        "MegatronBertForSequenceClassification",
        "MegatronBertForTokenClassification",
        "MegatronBertModel",
    ]

if TYPE_CHECKING:
    # Real imports for static type checkers only; never executed at runtime.
    from .configuration_megatron_bert import MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MegatronBertConfig

    if is_torch_available():
        from .modeling_megatron_bert import (
            MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
            MegatronBertForCausalLM,
            MegatronBertForMaskedLM,
            MegatronBertForMultipleChoice,
            MegatronBertForNextSentencePrediction,
            MegatronBertForPreTraining,
            MegatronBertForQuestionAnswering,
            MegatronBertForSequenceClassification,
            MegatronBertForTokenClassification,
            MegatronBertModel,
        )

else:
    import importlib
    import os
    import sys

    class _LazyModule(_BaseLazyModule):
        """
        Module class that surfaces all objects but only performs associated imports when the objects are requested.
        """

        __file__ = globals()["__file__"]
        __path__ = [os.path.dirname(__file__)]

        def _get_module(self, module_name: str):
            # Relative import within this package, resolved on demand.
            return importlib.import_module("." + module_name, self.__name__)

    # Replace this module in sys.modules with the lazy proxy.
    sys.modules[__name__] = _LazyModule(__name__, _import_structure)
--------------------------------------------------------------------------------
/src/transformers/models/mmbt/__init__.py:
--------------------------------------------------------------------------------
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...file_utils import _BaseLazyModule, is_torch_available


# Public API of this package, declared per submodule; resolved lazily below.
_import_structure = {
    "configuration_mmbt": ["MMBTConfig"],
}

# Torch-backed modeling classes are only advertised when torch is installed.
if is_torch_available():
    _import_structure["modeling_mmbt"] = ["MMBTForClassification", "MMBTModel", "ModalEmbeddings"]


if TYPE_CHECKING:
    # Real imports for static type checkers only; never executed at runtime.
    from .configuration_mmbt import MMBTConfig

    if is_torch_available():
        from .modeling_mmbt import MMBTForClassification, MMBTModel, ModalEmbeddings

else:
    import importlib
    import os
    import sys

    class _LazyModule(_BaseLazyModule):
        """
        Module class that surfaces all objects but only performs associated imports when the objects are requested.
46 | """ 47 | 48 | __file__ = globals()["__file__"] 49 | __path__ = [os.path.dirname(__file__)] 50 | 51 | def _get_module(self, module_name: str): 52 | return importlib.import_module("." + module_name, self.__name__) 53 | 54 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 55 | -------------------------------------------------------------------------------- /src/transformers/models/mmbt/configuration_mmbt.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # Copyright (c) HuggingFace Inc. team. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ MMBT configuration """ 17 | 18 | from ...utils import logging 19 | 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | 24 | class MMBTConfig(object): 25 | """ 26 | This is the configuration class to store the configuration of a :class:`~transformers.MMBTModel`. It is used to 27 | instantiate a MMBT model according to the specified arguments, defining the model architecture. 28 | 29 | Args: 30 | config (:class:`~transformers.PreTrainedConfig`): 31 | Config of the underlying Transformer models. Its values are copied over to use a single config. 32 | num_labels (:obj:`int`, `optional`): 33 | Size of final Linear layer for classification. 
34 | modal_hidden_size (:obj:`int`, `optional`, defaults to 2048): 35 | Embedding dimension of the non-text modality encoder. 36 | """ 37 | 38 | def __init__(self, config, num_labels=None, modal_hidden_size=2048): 39 | self.__dict__ = config.__dict__ 40 | self.modal_hidden_size = modal_hidden_size 41 | if num_labels: 42 | self.num_labels = num_labels 43 | -------------------------------------------------------------------------------- /src/transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import argparse 16 | 17 | import torch 18 | 19 | from transformers import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert 20 | from transformers.utils import logging 21 | 22 | 23 | logging.set_verbosity_info() 24 | 25 | 26 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, mobilebert_config_file, pytorch_dump_path): 27 | # Initialise PyTorch model 28 | config = MobileBertConfig.from_json_file(mobilebert_config_file) 29 | print(f"Building PyTorch model from configuration: {config}") 30 | model = MobileBertForPreTraining(config) 31 | # Load weights from tf checkpoint 32 | model = load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path) 33 | # Save pytorch-model 34 | print(f"Save PyTorch model to {pytorch_dump_path}") 35 | torch.save(model.state_dict(), pytorch_dump_path) 36 | 37 | 38 | if __name__ == "__main__": 39 | parser = argparse.ArgumentParser() 40 | # Required parameters 41 | parser.add_argument( 42 | "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." 43 | ) 44 | parser.add_argument( 45 | "--mobilebert_config_file", 46 | default=None, 47 | type=str, 48 | required=True, 49 | help="The config json file corresponding to the pre-trained MobileBERT model. \n" 50 | "This specifies the model architecture.", 51 | ) 52 | parser.add_argument( 53 | "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." 54 | ) 55 | args = parser.parse_args() 56 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.mobilebert_config_file, args.pytorch_dump_path) 57 | -------------------------------------------------------------------------------- /src/transformers/models/mobilebert/tokenization_mobilebert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # Copyright 2020 The HuggingFace Team. All rights reserved. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for MobileBERT."""

from ...utils import logging
from ..bert.tokenization_bert import BertTokenizer


logger = logging.get_logger(__name__)

# Local file names expected inside a saved tokenizer directory.
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}

# Hosted vocabulary file for each published checkpoint.
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {"mobilebert-uncased": "https://huggingface.co/google/mobilebert-uncased/resolve/main/vocab.txt"}
}

# Maximum input length (positional embeddings) per checkpoint.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"mobilebert-uncased": 512}


# No checkpoint-specific init kwargs for MobileBERT.
PRETRAINED_INIT_CONFIGURATION = {}


class MobileBertTokenizer(BertTokenizer):
    r"""
    Construct a MobileBERT tokenizer.

    :class:`~transformers.MobileBertTokenizer` is identical to :class:`~transformers.BertTokenizer` and runs
    end-to-end tokenization: punctuation splitting and wordpiece.

    Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning
    parameters.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
--------------------------------------------------------------------------------
/src/transformers/models/mobilebert/tokenization_mobilebert_fast.py:
--------------------------------------------------------------------------------
# coding=utf-8
#
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16 | """Tokenization classes for MobileBERT.""" 17 | 18 | from ...utils import logging 19 | from ..bert.tokenization_bert_fast import BertTokenizerFast 20 | from .tokenization_mobilebert import MobileBertTokenizer 21 | 22 | 23 | logger = logging.get_logger(__name__) 24 | 25 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"} 26 | 27 | PRETRAINED_VOCAB_FILES_MAP = { 28 | "vocab_file": {"mobilebert-uncased": "https://huggingface.co/google/mobilebert-uncased/resolve/main/vocab.txt"}, 29 | "tokenizer_file": { 30 | "mobilebert-uncased": "https://huggingface.co/google/mobilebert-uncased/resolve/main/tokenizer.json" 31 | }, 32 | } 33 | 34 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"mobilebert-uncased": 512} 35 | 36 | 37 | PRETRAINED_INIT_CONFIGURATION = {} 38 | 39 | 40 | class MobileBertTokenizerFast(BertTokenizerFast): 41 | r""" 42 | Construct a "fast" MobileBERT tokenizer (backed by HuggingFace's `tokenizers` library). 43 | 44 | :class:`~transformers.MobileBertTokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs 45 | end-to-end tokenization: punctuation splitting and wordpiece. 46 | 47 | Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning 48 | parameters. 49 | """ 50 | 51 | vocab_files_names = VOCAB_FILES_NAMES 52 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 53 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 54 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 55 | slow_tokenizer_class = MobileBertTokenizer 56 | -------------------------------------------------------------------------------- /src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | 18 | import argparse 19 | 20 | import torch 21 | 22 | from transformers import OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt 23 | from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME 24 | from transformers.utils import logging 25 | 26 | 27 | logging.set_verbosity_info() 28 | 29 | 30 | def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): 31 | # Construct model 32 | if openai_config_file == "": 33 | config = OpenAIGPTConfig() 34 | else: 35 | config = OpenAIGPTConfig.from_json_file(openai_config_file) 36 | model = OpenAIGPTModel(config) 37 | 38 | # Load weights from numpy 39 | load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path) 40 | 41 | # Save pytorch-model 42 | pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME 43 | pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME 44 | print(f"Save PyTorch model to {pytorch_weights_dump_path}") 45 | torch.save(model.state_dict(), pytorch_weights_dump_path) 46 | print(f"Save configuration file to {pytorch_config_dump_path}") 47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 48 | f.write(config.to_json_string()) 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | # Required parameters 54 
| parser.add_argument( 55 | "--openai_checkpoint_folder_path", 56 | default=None, 57 | type=str, 58 | required=True, 59 | help="Path to the TensorFlow checkpoint path.", 60 | ) 61 | parser.add_argument( 62 | "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model." 63 | ) 64 | parser.add_argument( 65 | "--openai_config_file", 66 | default="", 67 | type=str, 68 | help="An optional config json file corresponding to the pre-trained OpenAI model. \n" 69 | "This specifies the model architecture.", 70 | ) 71 | args = parser.parse_args() 72 | convert_openai_checkpoint_to_pytorch( 73 | args.openai_checkpoint_folder_path, args.openai_config_file, args.pytorch_dump_folder_path 74 | ) 75 | -------------------------------------------------------------------------------- /src/transformers/models/phobert/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
from typing import TYPE_CHECKING

from ...file_utils import _BaseLazyModule


# Submodule name -> public symbols it exports; imports happen lazily on
# first attribute access via _LazyModule below.
_import_structure = {
    "tokenization_phobert": ["PhobertTokenizer"],
}


if TYPE_CHECKING:
    # Eager imports only for static type checkers; skipped at runtime.
    from .tokenization_phobert import PhobertTokenizer

else:
    import importlib
    import os
    import sys

    class _LazyModule(_BaseLazyModule):
        """
        Module class that surfaces all objects but only performs associated imports when the objects are requested.
        """

        __file__ = globals()["__file__"]
        __path__ = [os.path.dirname(__file__)]

        def _get_module(self, module_name: str):
            return importlib.import_module("." + module_name, self.__name__)

    # Replace this package module with the lazy proxy in sys.modules.
    sys.modules[__name__] = _LazyModule(__name__, _import_structure)
--------------------------------------------------------------------------------
 /src/transformers/models/prophetnet/__init__.py:
--------------------------------------------------------------------------------
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...file_utils import _BaseLazyModule, is_torch_available


# Submodule name -> public symbols it exports; imports happen lazily on
# first attribute access via _LazyModule below.
_import_structure = {
    "configuration_prophetnet": ["PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ProphetNetConfig"],
    "tokenization_prophetnet": ["ProphetNetTokenizer"],
}

# Model classes are only exposed when torch is installed.
if is_torch_available():
    _import_structure["modeling_prophetnet"] = [
        "PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST",
        "ProphetNetDecoder",
        "ProphetNetEncoder",
        "ProphetNetForCausalLM",
        "ProphetNetForConditionalGeneration",
        "ProphetNetModel",
        "ProphetNetPreTrainedModel",
    ]


if TYPE_CHECKING:
    # Eager imports only for static type checkers; skipped at runtime.
    from .configuration_prophetnet import PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ProphetNetConfig
    from .tokenization_prophetnet import ProphetNetTokenizer

    if is_torch_available():
        from .modeling_prophetnet import (
            PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST,
            ProphetNetDecoder,
            ProphetNetEncoder,
            ProphetNetForCausalLM,
            ProphetNetForConditionalGeneration,
            ProphetNetModel,
            ProphetNetPreTrainedModel,
        )

else:
    import importlib
    import os
    import sys

    class _LazyModule(_BaseLazyModule):
        """
        Module class that surfaces all objects but only performs associated imports when the objects are requested.
        """

        __file__ = globals()["__file__"]
        __path__ = [os.path.dirname(__file__)]

        def _get_module(self, module_name: str):
            return importlib.import_module("." + module_name, self.__name__)

    # Replace this package module with the lazy proxy in sys.modules.
    sys.modules[__name__] = _LazyModule(__name__, _import_structure)
--------------------------------------------------------------------------------
 /src/transformers/models/rag/__init__.py:
--------------------------------------------------------------------------------
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...file_utils import _BaseLazyModule, is_tf_available, is_torch_available


# Submodule name -> public symbols it exports; imports happen lazily.
_import_structure = {
    "configuration_rag": ["RagConfig"],
    "retrieval_rag": ["RagRetriever"],
    "tokenization_rag": ["RagTokenizer"],
}

# PyTorch and TensorFlow model classes are gated on their backends.
if is_torch_available():
    _import_structure["modeling_rag"] = ["RagModel", "RagSequenceForGeneration", "RagTokenForGeneration"]

if is_tf_available():
    _import_structure["modeling_tf_rag"] = ["TFRagModel", "TFRagSequenceForGeneration", "TFRagTokenForGeneration"]


if TYPE_CHECKING:
    # Eager imports only for static type checkers; skipped at runtime.
    from .configuration_rag import RagConfig
    from .retrieval_rag import RagRetriever
    from .tokenization_rag import RagTokenizer

    if is_torch_available():
        from .modeling_rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration

    if is_tf_available():
        from .modeling_tf_rag import TFRagModel, TFRagSequenceForGeneration, TFRagTokenForGeneration

else:
    import importlib
    import os
    import sys

    class _LazyModule(_BaseLazyModule):
        """
        Module class that surfaces all objects but only performs associated imports when the objects are requested.
        """

        __file__ = globals()["__file__"]
        __path__ = [os.path.dirname(__file__)]

        def _get_module(self, module_name: str):
            return importlib.import_module("." + module_name, self.__name__)

    # Replace this package module with the lazy proxy in sys.modules.
    sys.modules[__name__] = _LazyModule(__name__, _import_structure)
--------------------------------------------------------------------------------
 /src/transformers/models/retribert/__init__.py:
--------------------------------------------------------------------------------
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available


# Submodule name -> public symbols it exports; imports happen lazily.
_import_structure = {
    "configuration_retribert": ["RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RetriBertConfig"],
    "tokenization_retribert": ["RetriBertTokenizer"],
}

# Fast tokenizer and model classes are gated on their optional backends.
if is_tokenizers_available():
    _import_structure["tokenization_retribert_fast"] = ["RetriBertTokenizerFast"]

if is_torch_available():
    _import_structure["modeling_retribert"] = [
        "RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
        "RetriBertModel",
        "RetriBertPreTrainedModel",
    ]


if TYPE_CHECKING:
    # Eager imports only for static type checkers; skipped at runtime.
    from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
    from .tokenization_retribert import RetriBertTokenizer

    if is_tokenizers_available():
        from .tokenization_retribert_fast import RetriBertTokenizerFast

    if is_torch_available():
        from .modeling_retribert import (
            RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
            RetriBertModel,
            RetriBertPreTrainedModel,
        )

else:
    import importlib
    import os
    import sys

    class _LazyModule(_BaseLazyModule):
        """
        Module class that surfaces all objects but only performs associated imports when the objects are requested.
        """

        __file__ = globals()["__file__"]
        __path__ = [os.path.dirname(__file__)]

        def _get_module(self, module_name: str):
            return importlib.import_module("." + module_name, self.__name__)

    # Replace this package module with the lazy proxy in sys.modules.
    sys.modules[__name__] = _LazyModule(__name__, _import_structure)
--------------------------------------------------------------------------------
 /src/transformers/models/retribert/tokenization_retribert.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for RetriBERT."""

from ...utils import logging
from ..bert.tokenization_bert import BertTokenizer


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "yjernite/retribert-base-uncased": "https://huggingface.co/yjernite/retribert-base-uncased/resolve/main/vocab.txt",
    }
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "yjernite/retribert-base-uncased": 512,
}


PRETRAINED_INIT_CONFIGURATION = {
    "yjernite/retribert-base-uncased": {"do_lower_case": True},
}


class RetriBertTokenizer(BertTokenizer):
    r"""
    Constructs a RetriBERT tokenizer.

    :class:`~transformers.RetriBertTokenizer` is identical to :class:`~transformers.BertTokenizer` and runs end-to-end
    tokenization: punctuation splitting and wordpiece.

    Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning
    parameters.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    model_input_names = ["input_ids", "attention_mask"]
--------------------------------------------------------------------------------
 /src/transformers/models/retribert/tokenization_retribert_fast.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for RetriBERT."""

from ...utils import logging
from ..bert.tokenization_bert_fast import BertTokenizerFast
from .tokenization_retribert import RetriBertTokenizer


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "yjernite/retribert-base-uncased": "https://huggingface.co/yjernite/retribert-base-uncased/resolve/main/vocab.txt",
    },
    "tokenizer_file": {
        "yjernite/retribert-base-uncased": "https://huggingface.co/yjernite/retribert-base-uncased/resolve/main/tokenizer.json",
    },
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "yjernite/retribert-base-uncased": 512,
}


PRETRAINED_INIT_CONFIGURATION = {
    "yjernite/retribert-base-uncased": {"do_lower_case": True},
}


class RetriBertTokenizerFast(BertTokenizerFast):
    r"""
    Construct a "fast" RetriBERT tokenizer (backed by HuggingFace's `tokenizers` library).

    :class:`~transformers.RetriBertTokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs
    end-to-end tokenization: punctuation splitting and wordpiece.

    Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning
    parameters.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    # Slow-tokenizer counterpart used when converting between the two.
    slow_tokenizer_class = RetriBertTokenizer
    model_input_names = ["input_ids", "attention_mask"]
--------------------------------------------------------------------------------
 /src/transformers/models/roberta/configuration_roberta.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" RoBERTa configuration """

from ...utils import logging
from ..bert.configuration_bert import BertConfig


logger = logging.get_logger(__name__)

ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "roberta-base": "https://huggingface.co/roberta-base/resolve/main/config.json",
    "roberta-large": "https://huggingface.co/roberta-large/resolve/main/config.json",
    "roberta-large-mnli": "https://huggingface.co/roberta-large-mnli/resolve/main/config.json",
    "distilroberta-base": "https://huggingface.co/distilroberta-base/resolve/main/config.json",
    "roberta-base-openai-detector": "https://huggingface.co/roberta-base-openai-detector/resolve/main/config.json",
    "roberta-large-openai-detector": "https://huggingface.co/roberta-large-openai-detector/resolve/main/config.json",
}


class RobertaConfig(BertConfig):
    r"""
    This is the configuration class to store the configuration of a :class:`~transformers.RobertaModel` or a
    :class:`~transformers.TFRobertaModel`. It is used to instantiate a RoBERTa model according to the specified
    arguments, defining the model architecture.


    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
    outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.

    The :class:`~transformers.RobertaConfig` class directly inherits :class:`~transformers.BertConfig`. It reuses the
    same defaults. Please check the parent class for more information.

    Examples::

        >>> from transformers import RobertaConfig, RobertaModel

        >>> # Initializing a RoBERTa configuration
        >>> configuration = RobertaConfig()

        >>> # Initializing a model from the configuration
        >>> model = RobertaModel(configuration)

        >>> # Accessing the model configuration
        >>> configuration = model.config
    """
    # Identifier used for config serialization / AutoConfig lookup.
    model_type = "roberta"

    def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2, **kwargs):
        """Constructs RobertaConfig."""
        # RoBERTa uses different special-token ids than BERT; all other
        # defaults are inherited unchanged from BertConfig.
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
--------------------------------------------------------------------------------
 /src/transformers/models/squeezebert/tokenization_squeezebert.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2020 The SqueezeBert authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for SqueezeBERT."""

from ...utils import logging
from ..bert.tokenization_bert import BertTokenizer


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "squeezebert/squeezebert-uncased": "https://huggingface.co/squeezebert/squeezebert-uncased/resolve/main/vocab.txt",
        "squeezebert/squeezebert-mnli": "https://huggingface.co/squeezebert/squeezebert-mnli/resolve/main/vocab.txt",
        "squeezebert/squeezebert-mnli-headless": "https://huggingface.co/squeezebert/squeezebert-mnli-headless/resolve/main/vocab.txt",
    }
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "squeezebert/squeezebert-uncased": 512,
    "squeezebert/squeezebert-mnli": 512,
    "squeezebert/squeezebert-mnli-headless": 512,
}


PRETRAINED_INIT_CONFIGURATION = {
    "squeezebert/squeezebert-uncased": {"do_lower_case": True},
    "squeezebert/squeezebert-mnli": {"do_lower_case": True},
    "squeezebert/squeezebert-mnli-headless": {"do_lower_case": True},
}


class SqueezeBertTokenizer(BertTokenizer):
    r"""
    Constructs a SqueezeBert tokenizer.

    :class:`~transformers.SqueezeBertTokenizer` is identical to :class:`~transformers.BertTokenizer` and runs end-to-end
    tokenization: punctuation splitting + wordpiece.

    Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning
    parameters.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
--------------------------------------------------------------------------------
 /src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2018 The T5 authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert T5 checkpoint."""


import argparse

from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5
from transformers.utils import logging


logging.set_verbosity_info()


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
    """Convert an original T5 TensorFlow checkpoint to a PyTorch checkpoint.

    Args:
        tf_checkpoint_path: Path to the original TF checkpoint.
        config_file: JSON config describing the model architecture.
        pytorch_dump_path: Output folder for the converted PyTorch model.
    """
    # Initialise PyTorch model
    config = T5Config.from_json_file(config_file)
    print(f"Building PyTorch model from configuration: {config}")
    model = T5ForConditionalGeneration(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_t5(model, config, tf_checkpoint_path)

    # Save pytorch-model (writes both weights and config via save_pretrained)
    print(f"Save PyTorch model to {pytorch_dump_path}")
    model.save_pretrained(pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the pre-trained T5 model. \n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)
--------------------------------------------------------------------------------
 /src/transformers/models/tapas/__init__.py:
--------------------------------------------------------------------------------
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...file_utils import _BaseLazyModule, is_torch_available


# Submodule name -> public symbols it exports; imports happen lazily.
_import_structure = {
    "configuration_tapas": ["TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP", "TapasConfig"],
    "tokenization_tapas": ["TapasTokenizer"],
}

# Model classes are only exposed when torch is installed.
if is_torch_available():
    _import_structure["modeling_tapas"] = [
        "TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST",
        "TapasForMaskedLM",
        "TapasForQuestionAnswering",
        "TapasForSequenceClassification",
        "TapasModel",
    ]


if TYPE_CHECKING:
    # Eager imports only for static type checkers; skipped at runtime.
    from .configuration_tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig
    from .tokenization_tapas import TapasTokenizer

    if is_torch_available():
        from .modeling_tapas import (
            TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST,
            TapasForMaskedLM,
            TapasForQuestionAnswering,
            TapasForSequenceClassification,
            TapasModel,
        )

else:
    import importlib
    import os
    import sys

    class _LazyModule(_BaseLazyModule):
        """
        Module class that surfaces all objects but only performs associated imports when the objects are requested.
        """

        __file__ = globals()["__file__"]
        __path__ = [os.path.dirname(__file__)]

        def _get_module(self, module_name: str):
            return importlib.import_module("." + module_name, self.__name__)

    # Replace this package module with the lazy proxy in sys.modules.
    sys.modules[__name__] = _LazyModule(__name__, _import_structure)
--------------------------------------------------------------------------------
 /src/transformers/models/vit/__init__.py:
--------------------------------------------------------------------------------
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...file_utils import _BaseLazyModule, is_torch_available, is_vision_available


# Submodule name -> public symbols it exports; imports happen lazily.
_import_structure = {
    "configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
}

# Feature extractor needs the vision extra; models need torch.
if is_vision_available():
    _import_structure["feature_extraction_vit"] = ["ViTFeatureExtractor"]

if is_torch_available():
    _import_structure["modeling_vit"] = [
        "VIT_PRETRAINED_MODEL_ARCHIVE_LIST",
        "ViTForImageClassification",
        "ViTModel",
        "ViTPreTrainedModel",
    ]


if TYPE_CHECKING:
    # Eager imports only for static type checkers; skipped at runtime.
    from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig

    if is_vision_available():
        from .feature_extraction_vit import ViTFeatureExtractor

    if is_torch_available():
        from .modeling_vit import (
            VIT_PRETRAINED_MODEL_ARCHIVE_LIST,
            ViTForImageClassification,
            ViTModel,
            ViTPreTrainedModel,
        )


else:
    import importlib
    import os
    import sys

    class _LazyModule(_BaseLazyModule):
        """
        Module class that surfaces all objects but only performs associated imports when the objects are requested.
        """

        __file__ = globals()["__file__"]
        __path__ = [os.path.dirname(__file__)]

        def _get_module(self, module_name: str):
            return importlib.import_module("." + module_name, self.__name__)

    # Replace this package module with the lazy proxy in sys.modules.
    sys.modules[__name__] = _LazyModule(__name__, _import_structure)
--------------------------------------------------------------------------------
 /src/transformers/models/wav2vec2/__init__.py:
--------------------------------------------------------------------------------
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available


# Submodule name -> public symbols it exports; imports happen lazily.
_import_structure = {
    "configuration_wav2vec2": ["WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Wav2Vec2Config"],
    "feature_extraction_wav2vec2": ["Wav2Vec2FeatureExtractor"],
    "processing_wav2vec2": ["Wav2Vec2Processor"],
    "tokenization_wav2vec2": ["Wav2Vec2CTCTokenizer", "Wav2Vec2Tokenizer"],
}

# Model classes are only exposed when torch is installed.
if is_torch_available():
    _import_structure["modeling_wav2vec2"] = [
        "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
        "Wav2Vec2ForCTC",
        "Wav2Vec2ForMaskedLM",
        "Wav2Vec2Model",
        "Wav2Vec2PreTrainedModel",
    ]


if TYPE_CHECKING:
    # Eager imports only for static type checkers; skipped at runtime.
    from .configuration_wav2vec2 import WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config
    from .feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
    from .processing_wav2vec2 import Wav2Vec2Processor
    from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2Tokenizer

    if is_torch_available():
        from .modeling_wav2vec2 import (
            WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
            Wav2Vec2ForCTC,
            Wav2Vec2ForMaskedLM,
            Wav2Vec2Model,
            Wav2Vec2PreTrainedModel,
        )


else:
    import importlib
    import os
    import sys

    class _LazyModule(_BaseLazyModule):
        """
        Module class that surfaces all objects but only performs associated imports when the objects are requested.
        """

        __file__ = globals()["__file__"]
        __path__ = [os.path.dirname(__file__)]

        def _get_module(self, module_name: str):
            return importlib.import_module("." + module_name, self.__name__)

    # Replace this package module with the lazy proxy in sys.modules.
    sys.modules[__name__] = _LazyModule(__name__, _import_structure)
--------------------------------------------------------------------------------
 /src/transformers/models/xlm_prophetnet/__init__.py:
--------------------------------------------------------------------------------
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# NOTE: unlike sibling model packages, this __init__ imports eagerly
# (no _LazyModule); optional pieces are gated on their backends.
from ...file_utils import is_sentencepiece_available, is_torch_available
from .configuration_xlm_prophetnet import XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMProphetNetConfig


if is_sentencepiece_available():
    from .tokenization_xlm_prophetnet import XLMProphetNetTokenizer

if is_torch_available():
    from .modeling_xlm_prophetnet import (
        XLM_PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST,
        XLMProphetNetDecoder,
        XLMProphetNetEncoder,
        XLMProphetNetForCausalLM,
        XLMProphetNetForConditionalGeneration,
        XLMProphetNetModel,
    )
--------------------------------------------------------------------------------
 /src/transformers/models/xlm_prophetnet/configuration_xlm_prophetnet.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2020 The Microsoft Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15 | """ XLM-ProphetNet model configuration """ 16 | 17 | 18 | from ...utils import logging 19 | from ..prophetnet.configuration_prophetnet import ProphetNetConfig 20 | 21 | 22 | logger = logging.get_logger(__name__) 23 | 24 | XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP = { 25 | "microsoft/xprophetnet-large-wiki100-cased": "https://huggingface.co/microsoft/xprophetnet-large-wiki100-cased/resolve/main/config.json", 26 | } 27 | 28 | 29 | class XLMProphetNetConfig(ProphetNetConfig): 30 | """ 31 | This class overrides :class:`~transformers.ProphetNetConfig`. Please check the superclass for the appropriate 32 | documentation alongside usage examples. 33 | """ 34 | 35 | model_type = "xlm-prophetnet" 36 | -------------------------------------------------------------------------------- /src/transformers/models/xlm_roberta/configuration_xlm_roberta.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """ XLM-RoBERTa configuration """ 17 | 18 | from ...utils import logging 19 | from ..roberta.configuration_roberta import RobertaConfig 20 | 21 | 22 | logger = logging.get_logger(__name__) 23 | 24 | XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = { 25 | "xlm-roberta-base": "https://huggingface.co/xlm-roberta-base/resolve/main/config.json", 26 | "xlm-roberta-large": "https://huggingface.co/xlm-roberta-large/resolve/main/config.json", 27 | "xlm-roberta-large-finetuned-conll02-dutch": "https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/config.json", 28 | "xlm-roberta-large-finetuned-conll02-spanish": "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/config.json", 29 | "xlm-roberta-large-finetuned-conll03-english": "https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/config.json", 30 | "xlm-roberta-large-finetuned-conll03-german": "https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/config.json", 31 | } 32 | 33 | 34 | class XLMRobertaConfig(RobertaConfig): 35 | """ 36 | This class overrides :class:`~transformers.RobertaConfig`. Please check the superclass for the appropriate 37 | documentation alongside usage examples. 38 | """ 39 | 40 | model_type = "xlm-roberta" 41 | -------------------------------------------------------------------------------- /src/transformers/sagemaker/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2021 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 
import warnings

from ..trainer import Trainer
from ..utils import logging


logger = logging.get_logger(__name__)


class SageMakerTrainer(Trainer):
    """
    Deprecated subclass of :class:`~transformers.Trainer`.

    It emits a :class:`FutureWarning` on construction and otherwise delegates
    everything to the parent class unchanged.
    """

    def __init__(self, args=None, **kwargs):
        # Warn before delegating so the notice fires even if Trainer.__init__ raises.
        deprecation_message = (
            "`SageMakerTrainer` is deprecated and will be removed in v5 of Transformers. You can use `Trainer` "
            "instead."
        )
        warnings.warn(deprecation_message, FutureWarning)
        super().__init__(args=args, **kwargs)
37 | """ 38 | 39 | sortish_sampler: bool = field(default=False, metadata={"help": "Whether to use SortishSampler or not."}) 40 | predict_with_generate: bool = field( 41 | default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."} 42 | ) 43 | -------------------------------------------------------------------------------- /src/transformers/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2021 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | from packaging import version 18 | 19 | from .. import __version__ 20 | 21 | 22 | def check_min_version(min_version): 23 | if version.parse(__version__) < version.parse(min_version): 24 | if "dev" in min_version: 25 | error_message = ( 26 | "This example requires a source install from HuggingFace Transformers (see " 27 | "`https://huggingface.co/transformers/installation.html#installing-from-source`)," 28 | ) 29 | else: 30 | error_message = f"This example requires a minimum version of {min_version}," 31 | error_message += f" but the version found is {__version__}.\n" 32 | raise ImportError( 33 | error_message 34 | + ( 35 | "Check out https://huggingface.co/transformers/examples.html for the examples corresponding to other " 36 | "versions of HuggingFace Transformers." 
37 | ) 38 | ) 39 | -------------------------------------------------------------------------------- /src/transformers/utils/dummy_sentencepiece_and_speech_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 2 | from ..file_utils import requires_backends 3 | 4 | 5 | class Speech2TextProcessor: 6 | def __init__(self, *args, **kwargs): 7 | requires_backends(self, ["sentencepiece", "speech"]) 8 | -------------------------------------------------------------------------------- /src/transformers/utils/dummy_sentencepiece_and_tokenizers_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 2 | from ..file_utils import requires_backends 3 | 4 | 5 | SLOW_TO_FAST_CONVERTERS = None 6 | 7 | 8 | def convert_slow_tokenizer(*args, **kwargs): 9 | requires_backends(convert_slow_tokenizer, ["sentencepiece", "tokenizers"]) 10 | -------------------------------------------------------------------------------- /src/transformers/utils/dummy_speech_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 2 | from ..file_utils import requires_backends 3 | 4 | 5 | class Speech2TextFeatureExtractor: 6 | def __init__(self, *args, **kwargs): 7 | requires_backends(self, ["speech"]) 8 | -------------------------------------------------------------------------------- /src/transformers/utils/dummy_vision_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 
from ..file_utils import requires_backends


# Placeholder classes installed when the optional "vision" backend is not
# available. Each constructor defers to `requires_backends`, which presumably
# reports the missing backend to the user — see file_utils. File is
# autogenerated by `make fix-copies`; keep edits to comments only.
class ImageFeatureExtractionMixin:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["vision"])


class CLIPFeatureExtractor:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["vision"])


class CLIPProcessor:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["vision"])


class DeiTFeatureExtractor:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["vision"])


class ViTFeatureExtractor:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["vision"])
from collections import Counter
from math import ceil


def assert_device_map(device_map, num_blocks):
    """
    Validate that ``device_map`` assigns every block ``0..num_blocks - 1`` to
    exactly one device.

    Args:
        device_map: mapping of device -> list of block indices assigned to it.
        num_blocks: total number of attention blocks the model has.

    Raises:
        AssertionError: if any block is assigned twice, missing, or out of range.
    """
    blocks = list(range(0, num_blocks))

    device_map_blocks = [item for sublist in device_map.values() for item in sublist]

    # Counter gives O(n) duplicate detection (the previous list.count-in-a-loop
    # was O(n^2)); Counter preserves first-occurrence order, matching the old
    # ordering of the error message.
    counts = Counter(device_map_blocks)
    duplicate_blocks = [block for block, count in counts.items() if count > 1]

    # Use sets for O(1) membership tests instead of repeated list scans.
    assigned = set(device_map_blocks)
    valid = set(blocks)
    missing_blocks = [i for i in blocks if i not in assigned]
    extra_blocks = [i for i in device_map_blocks if i not in valid]

    assert len(duplicate_blocks) == 0, (
        "Duplicate attention blocks specified in device_map. Attention blocks must be specified to one device. These "
        "attention blocks were specified more than once: " + str(duplicate_blocks)
    )
    assert len(missing_blocks) == 0, (
        "There are attention blocks for this model that are not specified in the device_map. Add these attention "
        "blocks to a device on the device_map: " + str(missing_blocks)
    )
    assert (
        len(extra_blocks) == 0
    ), "The device_map contains more attention blocks than this model has. Remove these from the device_map:" + str(
        extra_blocks
    )


def get_device_map(n_layers, devices):
    """Returns a dictionary of layers distributed evenly across all devices.

    Layers are split into ``ceil(n_layers / len(devices))``-sized chunks; note
    that when the split is uneven, trailing devices may receive no layers
    (``zip`` truncates to the shorter sequence).
    """
    layers = list(range(n_layers))
    n_blocks = int(ceil(n_layers / len(devices)))
    # List comprehension instead of list(generator) — same chunks, clearer intent.
    layers_list = [layers[i : i + n_blocks] for i in range(0, n_layers, n_blocks)]

    return dict(zip(devices, layers_list))
from collections import OrderedDict


# Maps a config class name to the corresponding question-answering model class
# name. NOTE(review): ordering looks deliberate (e.g. BigBirdPegasusConfig is
# listed before BigBirdConfig) — preserve entry order when regenerating. File
# is autogenerated; update via models/auto/modeling_auto.py as the header says.
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        ("BigBirdPegasusConfig", "BigBirdPegasusForQuestionAnswering"),
        ("BigBirdConfig", "BigBirdForQuestionAnswering"),
        ("ConvBertConfig", "ConvBertForQuestionAnswering"),
        ("LEDConfig", "LEDForQuestionAnswering"),
        ("DistilBertConfig", "DistilBertForQuestionAnswering"),
        ("AlbertConfig", "AlbertForQuestionAnswering"),
        ("CamembertConfig", "CamembertForQuestionAnswering"),
        ("BartConfig", "BartForQuestionAnswering"),
        ("MBartConfig", "MBartForQuestionAnswering"),
        ("LongformerConfig", "LongformerForQuestionAnswering"),
        ("XLMRobertaConfig", "XLMRobertaForQuestionAnswering"),
        ("RobertaConfig", "RobertaForQuestionAnswering"),
        ("SqueezeBertConfig", "SqueezeBertForQuestionAnswering"),
        ("BertConfig", "BertForQuestionAnswering"),
        ("XLNetConfig", "XLNetForQuestionAnsweringSimple"),
        ("FlaubertConfig", "FlaubertForQuestionAnsweringSimple"),
        ("MegatronBertConfig", "MegatronBertForQuestionAnswering"),
        ("MobileBertConfig", "MobileBertForQuestionAnswering"),
        ("XLMConfig", "XLMForQuestionAnsweringSimple"),
        ("ElectraConfig", "ElectraForQuestionAnswering"),
        ("ReformerConfig", "ReformerForQuestionAnswering"),
        ("FunnelConfig", "FunnelForQuestionAnswering"),
        ("LxmertConfig", "LxmertForQuestionAnswering"),
        ("MPNetConfig", "MPNetForQuestionAnswering"),
        ("DebertaConfig", "DebertaForQuestionAnswering"),
        ("DebertaV2Config", "DebertaV2ForQuestionAnswering"),
        ("IBertConfig", "IBertForQuestionAnswering"),
    ]
)