├── LAMOL.sh ├── LAMOL_myadaptor.sh ├── LICENSE ├── README.md ├── data_attrs.json ├── dev.py ├── environment ├── fp16.py ├── fp16util.py ├── loss_scaler.py ├── metrics.py ├── mytransformers ├── __init__.py ├── activations.py ├── activations_tf.py ├── adapters │ ├── __init__.py │ ├── composition.py │ ├── configuration.py │ ├── heads.py │ ├── layer.py │ ├── loading.py │ ├── model_mixin.py │ ├── modeling.py │ ├── models │ │ ├── __init__.py │ │ ├── bart.py │ │ ├── bert.py │ │ ├── distilbert.py │ │ └── gpt2.py │ ├── training.py │ └── utils.py ├── benchmark │ ├── __init__.py │ ├── benchmark.py │ ├── benchmark_args.py │ ├── benchmark_args_tf.py │ ├── benchmark_args_utils.py │ ├── benchmark_tf.py │ └── benchmark_utils.py ├── commands │ ├── __init__.py │ ├── add_new_model.py │ ├── convert.py │ ├── download.py │ ├── env.py │ ├── lfs.py │ ├── run.py │ ├── serving.py │ ├── train.py │ ├── transformers_cli.py │ └── user.py ├── configuration_utils.py ├── convert_graph_to_onnx.py ├── convert_pytorch_checkpoint_to_tf2.py ├── convert_slow_tokenizer.py ├── convert_slow_tokenizers_checkpoints_to_fast.py ├── convert_tf_hub_seq_to_seq_bert_to_pytorch.py ├── data │ ├── __init__.py │ ├── data_collator.py │ ├── datasets │ │ ├── __init__.py │ │ ├── glue.py │ │ ├── language_modeling.py │ │ └── squad.py │ ├── metrics │ │ ├── __init__.py │ │ └── squad_metrics.py │ ├── processors │ │ ├── __init__.py │ │ ├── glue.py │ │ ├── squad.py │ │ ├── utils.py │ │ └── xnli.py │ └── test_generation_utils.py ├── dependency_versions_check.py ├── dependency_versions_table.py ├── feature_extraction_sequence_utils.py ├── feature_extraction_utils.py ├── file_utils.py ├── generation_beam_search.py ├── generation_logits_process.py ├── generation_stopping_criteria.py ├── generation_tf_utils.py ├── generation_utils.py ├── hf_api.py ├── hf_argparser.py ├── image_utils.py ├── integrations.py ├── modelcard.py ├── modeling_flax_pytorch_utils.py ├── modeling_flax_utils.py ├── modeling_outputs.py ├── modeling_tf_outputs.py ├── modeling_tf_pytorch_utils.py ├── modeling_tf_utils.py ├── modeling_utils.py ├── models │ ├── __init__.py │ ├── albert │ │ ├── __init__.py │ │ ├── configuration_albert.py │ │ ├── convert_albert_original_tf_checkpoint_to_pytorch.py │ │ ├── modeling_albert.py │ │ ├── modeling_tf_albert.py │ │ ├── tokenization_albert.py │ │ └── tokenization_albert_fast.py │ ├── auto │ │ ├── __init__.py │ │ ├── auto_factory.py │ │ ├── configuration_auto.py │ │ ├── modeling_auto.py │ │ ├── modeling_flax_auto.py │ │ ├── modeling_tf_auto.py │ │ └── tokenization_auto.py │ ├── bart │ │ ├── __init__.py │ │ ├── configuration_bart.py │ │ ├── convert_bart_original_pytorch_checkpoint_to_pytorch.py │ │ ├── modeling_bart.py │ │ ├── modeling_tf_bart.py │ │ ├── tokenization_bart.py │ │ └── tokenization_bart_fast.py │ ├── barthez │ │ ├── __init__.py │ │ ├── tokenization_barthez.py │ │ └── tokenization_barthez_fast.py │ ├── bert │ │ ├── __init__.py │ │ ├── configuration_bert.py │ │ ├── convert_bert_original_tf2_checkpoint_to_pytorch.py │ │ ├── convert_bert_original_tf_checkpoint_to_pytorch.py │ │ ├── convert_bert_pytorch_checkpoint_to_original_tf.py │ │ ├── modeling_bert.py │ │ ├── modeling_flax_bert.py │ │ ├── modeling_tf_bert.py │ │ ├── tokenization_bert.py │ │ └── tokenization_bert_fast.py │ ├── bert_generation │ │ ├── __init__.py │ │ ├── configuration_bert_generation.py │ │ ├── modeling_bert_generation.py │ │ └── tokenization_bert_generation.py │ ├── bert_japanese │ │ ├── __init__.py │ │ └── tokenization_bert_japanese.py │ ├── bertweet │ │ ├── __init__.py │ 
│ └── tokenization_bertweet.py │ ├── big_bird │ │ ├── __init__.py │ │ ├── configuration_big_bird.py │ │ ├── convert_bigbird_original_tf_checkpoint_to_pytorch.py │ │ ├── modeling_big_bird.py │ │ └── tokenization_big_bird.py │ ├── blenderbot │ │ ├── __init__.py │ │ ├── configuration_blenderbot.py │ │ ├── convert_blenderbot_original_pytorch_checkpoint_to_pytorch.py │ │ ├── modeling_blenderbot.py │ │ ├── modeling_tf_blenderbot.py │ │ └── tokenization_blenderbot.py │ ├── blenderbot_small │ │ ├── __init__.py │ │ ├── configuration_blenderbot_small.py │ │ ├── modeling_blenderbot_small.py │ │ ├── modeling_tf_blenderbot_small.py │ │ ├── tokenization_blenderbot_small.py │ │ └── tokenization_blenderbot_small_fast.py │ ├── bort │ │ └── convert_bort_original_gluonnlp_checkpoint_to_pytorch.py │ ├── camembert │ │ ├── __init__.py │ │ ├── configuration_camembert.py │ │ ├── modeling_camembert.py │ │ ├── modeling_tf_camembert.py │ │ ├── tokenization_camembert.py │ │ └── tokenization_camembert_fast.py │ ├── convbert │ │ ├── __init__.py │ │ ├── configuration_convbert.py │ │ ├── convert_convbert_original_tf1_checkpoint_to_pytorch_and_tf2.py │ │ ├── modeling_convbert.py │ │ ├── modeling_tf_convbert.py │ │ ├── tokenization_convbert.py │ │ └── tokenization_convbert_fast.py │ ├── ctrl │ │ ├── __init__.py │ │ ├── configuration_ctrl.py │ │ ├── modeling_ctrl.py │ │ ├── modeling_tf_ctrl.py │ │ └── tokenization_ctrl.py │ ├── deberta │ │ ├── __init__.py │ │ ├── configuration_deberta.py │ │ ├── modeling_deberta.py │ │ └── tokenization_deberta.py │ ├── deberta_v2 │ │ ├── __init__.py │ │ ├── configuration_deberta_v2.py │ │ ├── modeling_deberta_v2.py │ │ └── tokenization_deberta_v2.py │ ├── dialogpt │ │ ├── __init__.py │ │ └── convert_dialogpt_original_pytorch_checkpoint_to_pytorch.py │ ├── distilbert │ │ ├── __init__.py │ │ ├── configuration_distilbert.py │ │ ├── modeling_distilbert.py │ │ ├── modeling_tf_distilbert.py │ │ ├── tokenization_distilbert.py │ │ └── tokenization_distilbert_fast.py │ ├── dpr │ │ ├── __init__.py │ │ ├── configuration_dpr.py │ │ ├── convert_dpr_original_checkpoint_to_pytorch.py │ │ ├── modeling_dpr.py │ │ ├── modeling_tf_dpr.py │ │ ├── tokenization_dpr.py │ │ └── tokenization_dpr_fast.py │ ├── electra │ │ ├── __init__.py │ │ ├── configuration_electra.py │ │ ├── convert_electra_original_tf_checkpoint_to_pytorch.py │ │ ├── modeling_electra.py │ │ ├── modeling_tf_electra.py │ │ ├── tokenization_electra.py │ │ └── tokenization_electra_fast.py │ ├── encoder_decoder │ │ ├── __init__.py │ │ ├── configuration_encoder_decoder.py │ │ └── modeling_encoder_decoder.py │ ├── flaubert │ │ ├── __init__.py │ │ ├── configuration_flaubert.py │ │ ├── modeling_flaubert.py │ │ ├── modeling_tf_flaubert.py │ │ └── tokenization_flaubert.py │ ├── fsmt │ │ ├── __init__.py │ │ ├── configuration_fsmt.py │ │ ├── convert_fsmt_original_pytorch_checkpoint_to_pytorch.py │ │ ├── modeling_fsmt.py │ │ └── tokenization_fsmt.py │ ├── funnel │ │ ├── __init__.py │ │ ├── configuration_funnel.py │ │ ├── convert_funnel_original_tf_checkpoint_to_pytorch.py │ │ ├── modeling_funnel.py │ │ ├── modeling_tf_funnel.py │ │ ├── tokenization_funnel.py │ │ └── tokenization_funnel_fast.py │ ├── gpt2 │ │ ├── __init__.py │ │ ├── configuration_gpt2.py │ │ ├── convert_gpt2_original_tf_checkpoint_to_pytorch.py │ │ ├── modeling_gpt2.py │ │ ├── modeling_tf_gpt2.py │ │ ├── tokenization_gpt2.py │ │ └── tokenization_gpt2_fast.py │ ├── gpt_neo │ │ ├── __init__.py │ │ ├── configuration_gpt_neo.py │ │ ├── convert_gpt_neo_mesh_tf_to_pytorch.py │ │ └── 
modeling_gpt_neo.py │ ├── herbert │ │ ├── __init__.py │ │ ├── tokenization_herbert.py │ │ └── tokenization_herbert_fast.py │ ├── ibert │ │ ├── __init__.py │ │ ├── configuration_ibert.py │ │ ├── modeling_ibert.py │ │ └── quant_modules.py │ ├── layoutlm │ │ ├── __init__.py │ │ ├── configuration_layoutlm.py │ │ ├── modeling_layoutlm.py │ │ ├── modeling_tf_layoutlm.py │ │ ├── tokenization_layoutlm.py │ │ └── tokenization_layoutlm_fast.py │ ├── led │ │ ├── __init__.py │ │ ├── configuration_led.py │ │ ├── modeling_led.py │ │ ├── modeling_tf_led.py │ │ ├── tokenization_led.py │ │ └── tokenization_led_fast.py │ ├── longformer │ │ ├── __init__.py │ │ ├── configuration_longformer.py │ │ ├── convert_longformer_original_pytorch_lightning_to_pytorch.py │ │ ├── modeling_longformer.py │ │ ├── modeling_tf_longformer.py │ │ ├── tokenization_longformer.py │ │ └── tokenization_longformer_fast.py │ ├── lxmert │ │ ├── __init__.py │ │ ├── configuration_lxmert.py │ │ ├── convert_lxmert_original_tf_checkpoint_to_pytorch.py │ │ ├── modeling_lxmert.py │ │ ├── modeling_tf_lxmert.py │ │ ├── tokenization_lxmert.py │ │ └── tokenization_lxmert_fast.py │ ├── m2m_100 │ │ ├── __init__.py │ │ ├── configuration_m2m_100.py │ │ ├── convert_m2m100_original_checkpoint_to_pytorch.py │ │ ├── modeling_m2m_100.py │ │ └── tokenization_m2m_100.py │ ├── marian │ │ ├── __init__.py │ │ ├── configuration_marian.py │ │ ├── convert_marian_tatoeba_to_pytorch.py │ │ ├── convert_marian_to_pytorch.py │ │ ├── modeling_marian.py │ │ ├── modeling_tf_marian.py │ │ └── tokenization_marian.py │ ├── mbart │ │ ├── __init__.py │ │ ├── configuration_mbart.py │ │ ├── convert_mbart_original_checkpoint_to_pytorch.py │ │ ├── modeling_mbart.py │ │ ├── modeling_tf_mbart.py │ │ ├── tokenization_mbart.py │ │ ├── tokenization_mbart50.py │ │ ├── tokenization_mbart50_fast.py │ │ └── tokenization_mbart_fast.py │ ├── mmbt │ │ ├── __init__.py │ │ ├── configuration_mmbt.py │ │ └── modeling_mmbt.py │ ├── mobilebert │ │ ├── __init__.py │ │ ├── configuration_mobilebert.py │ │ ├── convert_mobilebert_original_tf_checkpoint_to_pytorch.py │ │ ├── modeling_mobilebert.py │ │ ├── modeling_tf_mobilebert.py │ │ ├── tokenization_mobilebert.py │ │ └── tokenization_mobilebert_fast.py │ ├── mpnet │ │ ├── __init__.py │ │ ├── configuration_mpnet.py │ │ ├── modeling_mpnet.py │ │ ├── modeling_tf_mpnet.py │ │ ├── tokenization_mpnet.py │ │ └── tokenization_mpnet_fast.py │ ├── mt5 │ │ ├── __init__.py │ │ ├── configuration_mt5.py │ │ ├── modeling_mt5.py │ │ └── modeling_tf_mt5.py │ ├── openai │ │ ├── __init__.py │ │ ├── configuration_openai.py │ │ ├── convert_openai_original_tf_checkpoint_to_pytorch.py │ │ ├── modeling_openai.py │ │ ├── modeling_tf_openai.py │ │ ├── tokenization_openai.py │ │ └── tokenization_openai_fast.py │ ├── pegasus │ │ ├── __init__.py │ │ ├── configuration_pegasus.py │ │ ├── convert_pegasus_tf_to_pytorch.py │ │ ├── modeling_pegasus.py │ │ ├── modeling_tf_pegasus.py │ │ ├── tokenization_pegasus.py │ │ └── tokenization_pegasus_fast.py │ ├── phobert │ │ ├── __init__.py │ │ └── tokenization_phobert.py │ ├── prophetnet │ │ ├── __init__.py │ │ ├── configuration_prophetnet.py │ │ ├── convert_prophetnet_original_pytorch_checkpoint_to_pytorch.py │ │ ├── modeling_prophetnet.py │ │ └── tokenization_prophetnet.py │ ├── rag │ │ ├── __init__.py │ │ ├── configuration_rag.py │ │ ├── modeling_rag.py │ │ ├── modeling_tf_rag.py │ │ ├── retrieval_rag.py │ │ └── tokenization_rag.py │ ├── reformer │ │ ├── __init__.py │ │ ├── configuration_reformer.py │ │ ├── 
convert_reformer_trax_checkpoint_to_pytorch.py │ │ ├── modeling_reformer.py │ │ ├── tokenization_reformer.py │ │ └── tokenization_reformer_fast.py │ ├── retribert │ │ ├── __init__.py │ │ ├── configuration_retribert.py │ │ ├── modeling_retribert.py │ │ ├── tokenization_retribert.py │ │ └── tokenization_retribert_fast.py │ ├── roberta │ │ ├── __init__.py │ │ ├── configuration_roberta.py │ │ ├── convert_roberta_original_pytorch_checkpoint_to_pytorch.py │ │ ├── modeling_flax_roberta.py │ │ ├── modeling_roberta.py │ │ ├── modeling_tf_roberta.py │ │ ├── tokenization_roberta.py │ │ └── tokenization_roberta_fast.py │ ├── speech_to_text │ │ ├── __init__.py │ │ ├── configuration_speech_to_text.py │ │ ├── convert_s2t_fairseq_to_tfms.py │ │ ├── feature_extraction_speech_to_text.py │ │ ├── modeling_speech_to_text.py │ │ ├── processing_speech_to_text.py │ │ └── tokenization_speech_to_text.py │ ├── squeezebert │ │ ├── __init__.py │ │ ├── configuration_squeezebert.py │ │ ├── modeling_squeezebert.py │ │ ├── tokenization_squeezebert.py │ │ └── tokenization_squeezebert_fast.py │ ├── t5 │ │ ├── __init__.py │ │ ├── configuration_t5.py │ │ ├── convert_t5_original_tf_checkpoint_to_pytorch.py │ │ ├── modeling_t5.py │ │ ├── modeling_tf_t5.py │ │ ├── tokenization_t5.py │ │ └── tokenization_t5_fast.py │ ├── tapas │ │ ├── __init__.py │ │ ├── configuration_tapas.py │ │ ├── convert_tapas_original_tf_checkpoint_to_pytorch.py │ │ ├── modeling_tapas.py │ │ └── tokenization_tapas.py │ ├── transfo_xl │ │ ├── __init__.py │ │ ├── configuration_transfo_xl.py │ │ ├── convert_transfo_xl_original_tf_checkpoint_to_pytorch.py │ │ ├── modeling_tf_transfo_xl.py │ │ ├── modeling_tf_transfo_xl_utilities.py │ │ ├── modeling_transfo_xl.py │ │ ├── modeling_transfo_xl_utilities.py │ │ └── tokenization_transfo_xl.py │ ├── vit │ │ ├── __init__.py │ │ ├── configuration_vit.py │ │ ├── convert_vit_timm_to_pytorch.py │ │ ├── feature_extraction_vit.py │ │ └── modeling_vit.py │ ├── wav2vec2 │ │ ├── __init__.py │ │ ├── configuration_wav2vec2.py │ │ ├── convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py │ │ ├── feature_extraction_wav2vec2.py │ │ ├── modeling_wav2vec2.py │ │ ├── processing_wav2vec2.py │ │ └── tokenization_wav2vec2.py │ ├── xlm │ │ ├── __init__.py │ │ ├── configuration_xlm.py │ │ ├── convert_xlm_original_pytorch_checkpoint_to_pytorch.py │ │ ├── modeling_tf_xlm.py │ │ ├── modeling_xlm.py │ │ └── tokenization_xlm.py │ ├── xlm_prophetnet │ │ ├── __init__.py │ │ ├── configuration_xlm_prophetnet.py │ │ ├── modeling_xlm_prophetnet.py │ │ └── tokenization_xlm_prophetnet.py │ ├── xlm_roberta │ │ ├── __init__.py │ │ ├── configuration_xlm_roberta.py │ │ ├── modeling_tf_xlm_roberta.py │ │ ├── modeling_xlm_roberta.py │ │ ├── tokenization_xlm_roberta.py │ │ └── tokenization_xlm_roberta_fast.py │ └── xlnet │ │ ├── __init__.py │ │ ├── configuration_xlnet.py │ │ ├── convert_xlnet_original_tf_checkpoint_to_pytorch.py │ │ ├── modeling_tf_xlnet.py │ │ ├── modeling_xlnet.py │ │ ├── tokenization_xlnet.py │ │ └── tokenization_xlnet_fast.py ├── optimization.py ├── optimization_tf.py ├── pipelines │ ├── __init__.py │ ├── base.py │ ├── conversational.py │ ├── feature_extraction.py │ ├── fill_mask.py │ ├── question_answering.py │ ├── table_question_answering.py │ ├── text2text_generation.py │ ├── text_classification.py │ ├── text_generation.py │ ├── token_classification.py │ └── zero_shot_classification.py ├── sagemaker │ ├── __init__.py │ ├── trainer_sm.py │ └── training_args_sm.py ├── testing_utils.py ├── tokenization_utils.py ├── 
tokenization_utils_base.py ├── tokenization_utils_fast.py ├── trainer.py ├── trainer_callback.py ├── trainer_pt_utils.py ├── trainer_seq2seq.py ├── trainer_tf.py ├── trainer_utils.py ├── training_args.py ├── training_args_seq2seq.py ├── training_args_tf.py └── utils │ ├── __init__.py │ ├── dummy_flax_objects.py │ ├── dummy_pt_objects.py │ ├── dummy_sentencepiece_objects.py │ ├── dummy_tf_objects.py │ ├── dummy_tokenizers_objects.py │ ├── dummy_vision_objects.py │ ├── hp_naming.py │ ├── imagenet_classes.py │ ├── logging.py │ ├── model_parallel_utils.py │ ├── modeling_auto_mapping.py │ ├── notebook.py │ ├── sentencepiece_model_pb2.py │ └── versions.py ├── parallel.py ├── preprocess.py ├── regularizers.py ├── regularizers_myadaptor.py ├── requirements.txt ├── scheduler.py ├── settings.py ├── settings_myadaptor.py ├── test.py ├── test.sh ├── test_myadaptor.py ├── test_myadaptor.sh ├── train.py ├── train.sh ├── train_myadaptor.py ├── train_myadaptor.sh ├── utils.py ├── utils_myadaptor.py └── v2.gif /LAMOL.sh: -------------------------------------------------------------------------------- 1 | A=e2enlg 2 | B=rnnlg.rest 3 | C=rnnlg.hotel 4 | D=rnnlg.tv 5 | E=rnnlg.laptop 6 | F=woz.en 7 | G=wikisql 8 | H=cnn_dailymail 9 | 10 | EXP=CL_GEN 11 | 12 | # Example for Finetune baseline 13 | 14 | SEED=1 15 | bash train.sh --seq_train_type finetune --model_name gpt2 --add_task_tokens --seed $SEED --gen_lm_sample_percentage 0.2 --tasks $A $G $C $F $B > 6ft-log.train.NLG.$SEED 2>&1 16 | sleep 30 17 | bash test.sh --task_test 5 --seq_train_type finetune --model_name gpt2 --add_task_tokens --seed $SEED --gen_lm_sample_percentage 0.2 --tasks $A $G $C $F $B > 6ft-log.test.NLG.$SEED 2>&1 18 | sleep 30 19 | 20 | SEED=2 21 | bash train.sh --seq_train_type finetune --model_name gpt2 --add_task_tokens --seed $SEED --gen_lm_sample_percentage 0.2 --tasks $A $G $C $F $B > 6ft-log.train.NLG.$SEED 2>&1 22 | sleep 30 23 | bash test.sh --task_test 5 --seq_train_type finetune --model_name gpt2 --add_task_tokens --seed $SEED --gen_lm_sample_percentage 0.2 --tasks $A $G $C $F $B > 6ft-log.test.NLG.$SEED 2>&1 24 | 25 | 26 | # Example for (online) EWC baseline, use '--reg_lambda' to tune the regularization coefficient; 1e6 is selected from {1e4, 1e5, 1e6, 1e7} 27 | 28 | SEED=1 29 | bash train.sh --n_train_epochs 9 --reg_lambda 1e6 --seq_train_type ewc --model_name gpt2 --add_task_tokens --seed $SEED --tasks $E $D $C $B $A > 2ewc-log.train.NLG.$SEED 2>&1 30 | sleep 30 31 | bash test.sh --task_test 5 --seq_train_type ewc --model_name gpt2 --add_task_tokens --seed $SEED --tasks $E $D $C $B $A > 2ewc-log.test.ewc.NLG.$SEED 2>&1 32 | sleep 30 33 | 34 | SEED=2 35 | bash train.sh --n_train_epochs 9 --reg_lambda 1e6 --seq_train_type ewc --model_name gpt2 --add_task_tokens --seed $SEED --tasks $E $D $C $B $A > 2ewc-log.train.ewc.NLG.$SEED 2>&1 36 | sleep 30 37 | bash test.sh --task_test 5 --seq_train_type ewc --model_name gpt2 --add_task_tokens --seed $SEED --tasks $E $D $C $B $A > 2ewc-log.test.ewc.NLG.$SEED 2>&1 38 | 39 | 40 | # Example for LAMOL baseline, use '--lamaml' to increase replay frequency 41 | 42 | SEED=1 43 | bash train.sh --lamaml --seq_train_type lll --model_name gpt2 --add_task_tokens --seed $SEED --gen_lm_sample_percentage 0.2 --tasks $E $D $C $B $A > 2lamol-log.train.lamol.NLG.$SEED 2>&1 44 | sleep 30 45 | bash test.sh --task_test 5 --seq_train_type lll --model_name gpt2 --add_task_tokens --seed $SEED --gen_lm_sample_percentage 0.2 --tasks $E $D $C $B $A > 2lamol-log.test.lamol.NLG.$SEED 2>&1 46 | sleep 30 47 | 48 | SEED=2 49 | bash 
train.sh --lamaml --seq_train_type lll --model_name gpt2 --add_task_tokens --seed $SEED --gen_lm_sample_percentage 0.2 --tasks $E $D $C $B $A > 2lamol-log.train.lamol.NLG.$SEED 2>&1 50 | sleep 30 51 | bash test.sh --task_test 5 --seq_train_type lll --model_name gpt2 --add_task_tokens --seed $SEED --gen_lm_sample_percentage 0.2 --tasks $E $D $C $B $A > 2lamol-log.test.lamol.NLG.$SEED 2>&1 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 GT-SALT 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /data_attrs.json: -------------------------------------------------------------------------------- 1 | {"squad1": {"train": {"data_size": 87599, "max_a_len": 77}, "eval": {"data_size": 34726, "max_a_len": 70}, "test": {"data_size": 34726, "max_a_len": 70}}, "squad2": {"train": {"data_size": 130319, "max_a_len": 77}, "eval": {"data_size": 26247, "max_a_len": 41}, "test": {"data_size": 26247, "max_a_len": 41}}, "iwslt.en.de": {"train": {"data_size": 196884, "max_a_len": 1396}, "eval": {"data_size": 993, "max_a_len": 235}, "test": {"data_size": 1305, "max_a_len": 187}}, "cnn_dailymail": {"train": {"data_size": 6604, "max_a_len": 2688}, "eval": {"data_size": 2250, "max_a_len": 2123}, "test": {"data_size": 2250, "max_a_len": 870}}, "multinli.in.out": {"train": {"data_size": 392702, "max_a_len": 3}, "eval": {"data_size": 20000, "max_a_len": 3}, "test": {"data_size": 19643, "max_a_len": 1}}, "sst": {"train": {"data_size": 6920, "max_a_len": 1}, "eval": {"data_size": 872, "max_a_len": 1}, "test": {"data_size": 1821, "max_a_len": 1}}, "srl": {"train": {"data_size": 6414, "max_a_len": 65}, "eval": {"data_size": 2183, "max_a_len": 64}, "test": {"data_size": 2201, "max_a_len": 54}}, "zre": {"train": {"data_size": 840000, "max_a_len": 123}, "eval": {"data_size": 600, "max_a_len": 13}, "test": {"data_size": 12000, "max_a_len": 24}}, "woz.en": {"train": {"data_size": 2536, "max_a_len": 17}, "eval": {"data_size": 830, "max_a_len": 17}, "test": {"data_size": 1646, "max_a_len": 16}}, "wikisql": {"train": {"data_size": 6525, "max_a_len": 157}, "eval": {"data_size": 8421, "max_a_len": 82}, "test": {"data_size": 15878, "max_a_len": 85}}, "schema": {"train": {"data_size": 80, "max_a_len": 4}, "eval": {"data_size": 82, "max_a_len": 3}, 
"test": {"data_size": 100, "max_a_len": 4}}, 2 | "ag": {"train": {"data_size": 7600, "max_a_len": 4}, "eval": {"data_size": 7600, "max_a_len": 4}, "test": {"data_size": 7600, "max_a_len": 4}}, 3 | "dbpedia": {"train": {"data_size": 115000, "max_a_len": 6}, "eval": {"data_size": 7600, "max_a_len": 6}, "test": {"data_size": 7600, "max_a_len": 6}}, 4 | "yahoo": {"train": {"data_size": 115000, "max_a_len": 4}, "eval": {"data_size": 7600, "max_a_len": 4}, "test": {"data_size": 7600, "max_a_len": 4}}, 5 | "amazon": {"train": {"data_size": 115000, "max_a_len": 2}, "eval": {"data_size": 7600, "max_a_len": 2}, "test": {"data_size": 7600, "max_a_len": 2}}, 6 | "yelp": {"train": {"data_size": 115000, "max_a_len": 2}, "eval": {"data_size": 7600, "max_a_len": 2}, "test": {"data_size": 7600, "max_a_len": 2}}, 7 | "e2enlg": {"train": {"data_size": 6000, "max_a_len": 100}, "eval": {"data_size": 2000, "max_a_len": 100}, "test": {"data_size": 2000, "max_a_len": 100}}, 8 | "rnnlg.tv": {"train": {"data_size": 8442, "max_a_len": 100}, "eval": {"data_size": 1407, "max_a_len": 100}, "test": {"data_size": 1407, "max_a_len": 100}}, 9 | "rnnlg.hotel": {"train": {"data_size": 6446, "max_a_len": 100}, "eval": {"data_size": 1075, "max_a_len": 100}, "test": {"data_size": 1075, "max_a_len": 100}}, 10 | "rnnlg.rest": {"train": {"data_size": 6228, "max_a_len": 100}, "eval": {"data_size": 1039, "max_a_len": 100}, "test": {"data_size": 1039, "max_a_len": 100}}, 11 | "rnnlg.laptop": {"train": {"data_size": 7944, "max_a_len": 100}, "eval": {"data_size": 2649, "max_a_len": 100}, "test": {"data_size": 2649, "max_a_len": 100}} 12 | } 13 | -------------------------------------------------------------------------------- /environment: -------------------------------------------------------------------------------- 1 | export DATA_DIR='data' 2 | export MODEL_ROOT_DIR='YOUR_MODEL_DIR' 3 | -------------------------------------------------------------------------------- /mytransformers/activations_tf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | 17 | import tensorflow as tf 18 | from packaging import version 19 | 20 | 21 | def _gelu(x): 22 | """ 23 | Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when 24 | initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 25 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) Also see 26 | https://arxiv.org/abs/1606.08415 27 | """ 28 | x = tf.convert_to_tensor(x) 29 | cdf = 0.5 * (1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype))) 30 | 31 | return x * cdf 32 | 33 | 34 | def _gelu_new(x): 35 | """ 36 | Gaussian Error Linear Unit. This is a smoother version of the GELU. 
Original paper: https://arxiv.org/abs/1606.08415 37 | 38 | Args: 39 | x: float Tensor to perform activation 40 | 41 | Returns: 42 | `x` with the GELU activation applied. 43 | """ 44 | x = tf.convert_to_tensor(x) 45 | pi = tf.cast(math.pi, x.dtype) 46 | coeff = tf.cast(0.044715, x.dtype) 47 | cdf = 0.5 * (1.0 + tf.tanh(tf.sqrt(2.0 / pi) * (x + coeff * tf.pow(x, 3)))) 48 | 49 | return x * cdf 50 | 51 | 52 | def mish(x): 53 | x = tf.convert_to_tensor(x) 54 | 55 | return x * tf.tanh(tf.math.softplus(x)) 56 | 57 | 58 | def gelu_fast(x): 59 | x = tf.convert_to_tensor(x) 60 | coeff1 = tf.cast(0.7978845608, x.dtype) 61 | coeff2 = tf.cast(0.044715, x.dtype) 62 | 63 | return 0.5 * x * (1.0 + tf.tanh(x * coeff2 * (1.0 + coeff1 * x * x))) 64 | 65 | 66 | if version.parse(tf.version.VERSION) >= version.parse("2.4"): 67 | 68 | def approximate_gelu_wrap(x): 69 | return tf.keras.activations.gelu(x, approximate=True) 70 | 71 | gelu = tf.keras.activations.gelu 72 | gelu_new = approximate_gelu_wrap 73 | else: 74 | gelu = _gelu 75 | gelu_new = _gelu_new 76 | 77 | 78 | ACT2FN = { 79 | "gelu": gelu, 80 | "relu": tf.keras.activations.relu, 81 | "swish": tf.keras.activations.swish, 82 | "silu": tf.keras.activations.swish, 83 | "gelu_new": gelu_new, 84 | "mish": mish, 85 | "tanh": tf.keras.activations.tanh, 86 | "gelu_fast": gelu_fast, 87 | } 88 | 89 | 90 | def get_tf_activation(activation_string): 91 | if activation_string in ACT2FN: 92 | return ACT2FN[activation_string] 93 | else: 94 | raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}") 95 | -------------------------------------------------------------------------------- /mytransformers/adapters/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SALT-NLP/Adaptive-Compositional-Modules/357aa2d6d1cd97ea03aeaddbd5372a1aeecbbe4c/mytransformers/adapters/__init__.py -------------------------------------------------------------------------------- /mytransformers/adapters/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SALT-NLP/Adaptive-Compositional-Modules/357aa2d6d1cd97ea03aeaddbd5372a1aeecbbe4c/mytransformers/adapters/models/__init__.py -------------------------------------------------------------------------------- /mytransformers/adapters/training.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | 4 | 5 | @dataclass 6 | class AdapterArguments: 7 | """ 8 | The subset of arguments related to adapter training. 9 | """ 10 | 11 | train_adapter: bool = field(default=False, metadata={"help": "Train an adapter instead of the full model."}) 12 | load_adapter: Optional[str] = field( 13 | default="", metadata={"help": "Pre-trained adapter module to be loaded from Hub."} 14 | ) 15 | adapter_config: Optional[str] = field( 16 | default="pfeiffer", metadata={"help": "Adapter configuration. Either an identifier or a path to a file."} 17 | ) 18 | adapter_non_linearity: Optional[str] = field( 19 | default=None, metadata={"help": "Override the non-linearity of the adapter configuration."} 20 | ) 21 | adapter_reduction_factor: Optional[int] = field( 22 | default=None, metadata={"help": "Override the reduction factor of the adapter configuration."} 23 | ) 24 | language: Optional[str] = field(default=None, metadata={"help": "The training language, e.g. 
'en' for English."}) 25 | 26 | 27 | @dataclass 28 | class MultiLingAdapterArguments(AdapterArguments): 29 | """ 30 | Arguments related to adapter training, extended by arguments for multilingual setups. 31 | """ 32 | 33 | load_lang_adapter: Optional[str] = field( 34 | default=None, metadata={"help": "Pre-trained language adapter module to be loaded from Hub."} 35 | ) 36 | lang_adapter_config: Optional[str] = field( 37 | default=None, metadata={"help": "Language adapter configuration. Either an identifier or a path to a file."} 38 | ) 39 | lang_adapter_non_linearity: Optional[str] = field( 40 | default=None, metadata={"help": "Override the non-linearity of the language adapter configuration."} 41 | ) 42 | lang_adapter_reduction_factor: Optional[int] = field( 43 | default=None, metadata={"help": "Override the reduction factor of the language adapter configuration."} 44 | ) 45 | -------------------------------------------------------------------------------- /mytransformers/benchmark/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SALT-NLP/Adaptive-Compositional-Modules/357aa2d6d1cd97ea03aeaddbd5372a1aeecbbe4c/mytransformers/benchmark/__init__.py -------------------------------------------------------------------------------- /mytransformers/commands/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from abc import ABC, abstractmethod 16 | from argparse import ArgumentParser 17 | 18 | 19 | class BaseTransformersCLICommand(ABC): 20 | @staticmethod 21 | @abstractmethod 22 | def register_subcommand(parser: ArgumentParser): 23 | raise NotImplementedError() 24 | 25 | @abstractmethod 26 | def run(self): 27 | raise NotImplementedError() 28 | -------------------------------------------------------------------------------- /mytransformers/commands/download.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from argparse import ArgumentParser 16 | 17 | from . 
import BaseTransformersCLICommand 18 | 19 | 20 | def download_command_factory(args): 21 | return DownloadCommand(args.model, args.cache_dir, args.force) 22 | 23 | 24 | class DownloadCommand(BaseTransformersCLICommand): 25 | @staticmethod 26 | def register_subcommand(parser: ArgumentParser): 27 | download_parser = parser.add_parser("download") 28 | download_parser.add_argument( 29 | "--cache-dir", type=str, default=None, help="Path to location to store the models" 30 | ) 31 | download_parser.add_argument( 32 | "--force", action="store_true", help="Force the model to be download even if already in cache-dir" 33 | ) 34 | download_parser.add_argument("model", type=str, help="Name of the model to download") 35 | download_parser.set_defaults(func=download_command_factory) 36 | 37 | def __init__(self, model: str, cache: str, force: bool): 38 | self._model = model 39 | self._cache = cache 40 | self._force = force 41 | 42 | def run(self): 43 | from ..models.auto import AutoModel, AutoTokenizer 44 | 45 | AutoModel.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force) 46 | AutoTokenizer.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force) 47 | -------------------------------------------------------------------------------- /mytransformers/commands/env.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import platform 16 | from argparse import ArgumentParser 17 | 18 | from .. import __version__ as version 19 | from ..file_utils import is_tf_available, is_torch_available 20 | from . 
import BaseTransformersCLICommand 21 | 22 | 23 | def info_command_factory(_): 24 | return EnvironmentCommand() 25 | 26 | 27 | class EnvironmentCommand(BaseTransformersCLICommand): 28 | @staticmethod 29 | def register_subcommand(parser: ArgumentParser): 30 | download_parser = parser.add_parser("env") 31 | download_parser.set_defaults(func=info_command_factory) 32 | 33 | def run(self): 34 | pt_version = "not installed" 35 | pt_cuda_available = "NA" 36 | if is_torch_available(): 37 | import torch 38 | 39 | pt_version = torch.__version__ 40 | pt_cuda_available = torch.cuda.is_available() 41 | 42 | tf_version = "not installed" 43 | tf_cuda_available = "NA" 44 | if is_tf_available(): 45 | import tensorflow as tf 46 | 47 | tf_version = tf.__version__ 48 | try: 49 | # deprecated in v2.1 50 | tf_cuda_available = tf.test.is_gpu_available() 51 | except AttributeError: 52 | # returns list of devices, convert to bool 53 | tf_cuda_available = bool(tf.config.list_physical_devices("GPU")) 54 | 55 | info = { 56 | "`transformers` version": version, 57 | "Platform": platform.platform(), 58 | "Python version": platform.python_version(), 59 | "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})", 60 | "Tensorflow version (GPU?)": f"{tf_version} ({tf_cuda_available})", 61 | "Using GPU in script?": "", 62 | "Using distributed or parallel set-up in script?": "", 63 | } 64 | 65 | print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n") 66 | print(self.format_dict(info)) 67 | 68 | return info 69 | 70 | @staticmethod 71 | def format_dict(d): 72 | return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n" 73 | -------------------------------------------------------------------------------- /mytransformers/commands/transformers_cli.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright 2020 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | from argparse import ArgumentParser 17 | 18 | from .add_new_model import AddNewModelCommand 19 | from .convert import ConvertCommand 20 | from .download import DownloadCommand 21 | from .env import EnvironmentCommand 22 | from .lfs import LfsCommands 23 | from .run import RunCommand 24 | from .serving import ServeCommand 25 | from .user import UserCommands 26 | 27 | 28 | def main(): 29 | parser = ArgumentParser("Transformers CLI tool", usage="transformers-cli []") 30 | commands_parser = parser.add_subparsers(help="transformers-cli command helpers") 31 | 32 | # Register commands 33 | ConvertCommand.register_subcommand(commands_parser) 34 | DownloadCommand.register_subcommand(commands_parser) 35 | EnvironmentCommand.register_subcommand(commands_parser) 36 | RunCommand.register_subcommand(commands_parser) 37 | ServeCommand.register_subcommand(commands_parser) 38 | UserCommands.register_subcommand(commands_parser) 39 | AddNewModelCommand.register_subcommand(commands_parser) 40 | LfsCommands.register_subcommand(commands_parser) 41 | 42 | # Let's go 43 | args = parser.parse_args() 44 | 45 | if not hasattr(args, "func"): 46 | parser.print_help() 47 | exit(1) 48 | 49 | # Run 50 | service = args.func(args) 51 | service.run() 52 | 53 | 54 | if __name__ == "__main__": 55 | main() 56 | -------------------------------------------------------------------------------- /mytransformers/convert_tf_hub_seq_to_seq_bert_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert Seq2Seq TF Hub checkpoint.""" 16 | 17 | 18 | import argparse 19 | 20 | from . 
import ( 21 | BertConfig, 22 | BertGenerationConfig, 23 | BertGenerationDecoder, 24 | BertGenerationEncoder, 25 | load_tf_weights_in_bert_generation, 26 | logging, 27 | ) 28 | 29 | 30 | logging.set_verbosity_info() 31 | 32 | 33 | def convert_tf_checkpoint_to_pytorch(tf_hub_path, pytorch_dump_path, is_encoder_named_decoder, vocab_size, is_encoder): 34 | # Initialise PyTorch model 35 | bert_config = BertConfig.from_pretrained( 36 | "bert-large-cased", 37 | vocab_size=vocab_size, 38 | max_position_embeddings=512, 39 | is_decoder=True, 40 | add_cross_attention=True, 41 | ) 42 | bert_config_dict = bert_config.to_dict() 43 | del bert_config_dict["type_vocab_size"] 44 | config = BertGenerationConfig(**bert_config_dict) 45 | if is_encoder: 46 | model = BertGenerationEncoder(config) 47 | else: 48 | model = BertGenerationDecoder(config) 49 | print(f"Building PyTorch model from configuration: {config}") 50 | 51 | # Load weights from tf checkpoint 52 | load_tf_weights_in_bert_generation( 53 | model, 54 | tf_hub_path, 55 | model_class="bert", 56 | is_encoder_named_decoder=is_encoder_named_decoder, 57 | is_encoder=is_encoder, 58 | ) 59 | 60 | # Save pytorch-model 61 | print(f"Save PyTorch model and config to {pytorch_dump_path}") 62 | model.save_pretrained(pytorch_dump_path) 63 | 64 | 65 | if __name__ == "__main__": 66 | parser = argparse.ArgumentParser() 67 | # Required parameters 68 | parser.add_argument( 69 | "--tf_hub_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." 70 | ) 71 | parser.add_argument( 72 | "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." 73 | ) 74 | parser.add_argument( 75 | "--is_encoder_named_decoder", 76 | action="store_true", 77 | help="If decoder has to be renamed to encoder in PyTorch model.", 78 | ) 79 | parser.add_argument("--is_encoder", action="store_true", help="If model is an encoder.") 80 | parser.add_argument("--vocab_size", default=50358, type=int, help="Vocab size of model") 81 | args = parser.parse_args() 82 | convert_tf_checkpoint_to_pytorch( 83 | args.tf_hub_path, 84 | args.pytorch_dump_path, 85 | args.is_encoder_named_decoder, 86 | args.vocab_size, 87 | is_encoder=args.is_encoder, 88 | ) 89 | -------------------------------------------------------------------------------- /mytransformers/data/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | 19 | from .metrics import glue_compute_metrics, xnli_compute_metrics 20 | from .processors import ( 21 | DataProcessor, 22 | InputExample, 23 | InputFeatures, 24 | SingleSentenceClassificationProcessor, 25 | SquadExample, 26 | SquadFeatures, 27 | SquadV1Processor, 28 | SquadV2Processor, 29 | glue_convert_examples_to_features, 30 | glue_output_modes, 31 | glue_processors, 32 | glue_tasks_num_labels, 33 | squad_convert_examples_to_features, 34 | xnli_output_modes, 35 | xnli_processors, 36 | xnli_tasks_num_labels, 37 | ) 38 | -------------------------------------------------------------------------------- /mytransformers/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | from .glue import GlueDataset, GlueDataTrainingArguments 20 | from .language_modeling import ( 21 | LineByLineTextDataset, 22 | LineByLineWithRefDataset, 23 | LineByLineWithSOPTextDataset, 24 | TextDataset, 25 | TextDatasetForNextSentencePrediction, 26 | ) 27 | from .squad import SquadDataset, SquadDataTrainingArguments 28 | -------------------------------------------------------------------------------- /mytransformers/data/processors/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | 19 | from .glue import glue_convert_examples_to_features, glue_output_modes, glue_processors, glue_tasks_num_labels 20 | from .squad import SquadExample, SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features 21 | from .utils import DataProcessor, InputExample, InputFeatures, SingleSentenceClassificationProcessor 22 | from .xnli import xnli_output_modes, xnli_processors, xnli_tasks_num_labels 23 | -------------------------------------------------------------------------------- /mytransformers/dependency_versions_check.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import sys 15 | 16 | from .dependency_versions_table import deps 17 | from .utils.versions import require_version_core 18 | 19 | 20 | # define which module versions we always want to check at run time 21 | # (usually the ones defined in `install_requires` in setup.py) 22 | # 23 | # order specific notes: 24 | # - tqdm must be checked before tokenizers 25 | 26 | pkgs_to_check_at_runtime = "python tqdm regex sacremoses requests packaging filelock numpy tokenizers".split() 27 | if sys.version_info < (3, 7): 28 | pkgs_to_check_at_runtime.append("dataclasses") 29 | if sys.version_info < (3, 8): 30 | pkgs_to_check_at_runtime.append("importlib_metadata") 31 | 32 | for pkg in pkgs_to_check_at_runtime: 33 | if pkg in deps: 34 | if pkg == "tokenizers": 35 | # must be loaded here, or else tqdm check may fail 36 | from .file_utils import is_tokenizers_available 37 | 38 | if not is_tokenizers_available(): 39 | continue # not required, check version only if installed 40 | 41 | require_version_core(deps[pkg]) 42 | else: 43 | raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py") 44 | -------------------------------------------------------------------------------- /mytransformers/dependency_versions_table.py: -------------------------------------------------------------------------------- 1 | # THIS FILE HAS BEEN AUTOGENERATED. To update: 2 | # 1. modify the `_deps` dict in setup.py 3 | # 2. 
run `make deps_table_update`` 4 | deps = { 5 | "black": "black>=20.8b1", 6 | "cookiecutter": "cookiecutter==1.7.2", 7 | "dataclasses": "dataclasses", 8 | "datasets": "datasets", 9 | "docutils": "docutils==0.16.0", 10 | "faiss-cpu": "faiss-cpu", 11 | "fastapi": "fastapi", 12 | "filelock": "filelock", 13 | "flake8": "flake8>=3.8.3", 14 | "flax": "flax>=0.3.2", 15 | "fugashi": "fugashi>=1.0", 16 | "importlib_metadata": "importlib_metadata", 17 | "ipadic": "ipadic>=1.0.0,<2.0", 18 | "isort": "isort>=5.5.4", 19 | "jax": "jax>=0.2.8", 20 | "jaxlib": "jaxlib>=0.1.59", 21 | "keras2onnx": "keras2onnx", 22 | "numpy": "numpy>=1.17", 23 | "onnxconverter-common": "onnxconverter-common", 24 | "onnxruntime-tools": "onnxruntime-tools>=1.4.2", 25 | "onnxruntime": "onnxruntime>=1.4.0", 26 | "packaging": "packaging", 27 | "parameterized": "parameterized", 28 | "Pillow": "Pillow", 29 | "protobuf": "protobuf", 30 | "psutil": "psutil", 31 | "pydantic": "pydantic", 32 | "pytest": "pytest", 33 | "pytest-subtests": "pytest-subtests", 34 | "pytest-sugar": "pytest-sugar", 35 | "pytest-xdist": "pytest-xdist", 36 | "python": "python>=3.6.0", 37 | "recommonmark": "recommonmark", 38 | "regex": "regex!=2019.12.17", 39 | "requests": "requests", 40 | "sacremoses": "sacremoses", 41 | "scikit-learn": "scikit-learn", 42 | "sentencepiece": "sentencepiece==0.1.91", 43 | "soundfile": "soundfile", 44 | "sphinx-copybutton": "sphinx-copybutton", 45 | "sphinx-markdown-tables": "sphinx-markdown-tables", 46 | "sphinx-rtd-theme": "sphinx-rtd-theme==0.4.3", 47 | "sphinx": "sphinx==3.2.1", 48 | "starlette": "starlette", 49 | "tensorflow-cpu": "tensorflow-cpu>=2.3", 50 | "tensorflow": "tensorflow>=2.3", 51 | "timeout-decorator": "timeout-decorator", 52 | "tokenizers": "tokenizers>=0.10.1,<0.11", 53 | "torch": "torch>=1.0", 54 | "torchaudio": "torchaudio", 55 | "tqdm": "tqdm>=4.27", 56 | "unidic": "unidic>=1.0.2", 57 | "unidic_lite": "unidic_lite>=1.0.7", 58 | "uvicorn": "uvicorn", 59 | "sagemaker": "sagemaker>=2.31.0", 60 | } 61 | -------------------------------------------------------------------------------- /mytransformers/models/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | from . 
import ( 20 | albert, 21 | auto, 22 | bart, 23 | barthez, 24 | bert, 25 | bert_generation, 26 | bert_japanese, 27 | bertweet, 28 | big_bird, 29 | blenderbot, 30 | blenderbot_small, 31 | camembert, 32 | convbert, 33 | ctrl, 34 | deberta, 35 | dialogpt, 36 | distilbert, 37 | dpr, 38 | electra, 39 | encoder_decoder, 40 | flaubert, 41 | fsmt, 42 | funnel, 43 | gpt2, 44 | gpt_neo, 45 | herbert, 46 | layoutlm, 47 | led, 48 | longformer, 49 | lxmert, 50 | m2m_100, 51 | marian, 52 | mbart, 53 | mmbt, 54 | mobilebert, 55 | mpnet, 56 | mt5, 57 | openai, 58 | pegasus, 59 | phobert, 60 | prophetnet, 61 | rag, 62 | reformer, 63 | retribert, 64 | roberta, 65 | speech_to_text, 66 | squeezebert, 67 | t5, 68 | tapas, 69 | transfo_xl, 70 | vit, 71 | wav2vec2, 72 | xlm, 73 | xlm_roberta, 74 | xlnet, 75 | ) 76 | -------------------------------------------------------------------------------- /mytransformers/models/albert/convert_albert_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert ALBERT checkpoint.""" 16 | 17 | 18 | import argparse 19 | 20 | import torch 21 | 22 | from transformers import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert 23 | from transformers.utils import logging 24 | 25 | 26 | logging.set_verbosity_info() 27 | 28 | 29 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pytorch_dump_path): 30 | # Initialise PyTorch model 31 | config = AlbertConfig.from_json_file(albert_config_file) 32 | print(f"Building PyTorch model from configuration: {config}") 33 | model = AlbertForPreTraining(config) 34 | 35 | # Load weights from tf checkpoint 36 | load_tf_weights_in_albert(model, config, tf_checkpoint_path) 37 | 38 | # Save pytorch-model 39 | print(f"Save PyTorch model to {pytorch_dump_path}") 40 | torch.save(model.state_dict(), pytorch_dump_path) 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser() 45 | # Required parameters 46 | parser.add_argument( 47 | "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." 48 | ) 49 | parser.add_argument( 50 | "--albert_config_file", 51 | default=None, 52 | type=str, 53 | required=True, 54 | help="The config json file corresponding to the pre-trained ALBERT model. \n" 55 | "This specifies the model architecture.", 56 | ) 57 | parser.add_argument( 58 | "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." 
59 | ) 60 | args = parser.parse_args() 61 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.albert_config_file, args.pytorch_dump_path) 62 | -------------------------------------------------------------------------------- /mytransformers/models/bart/tokenization_bart.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from ...utils import logging 17 | from ..roberta.tokenization_roberta import RobertaTokenizer 18 | 19 | 20 | logger = logging.get_logger(__name__) 21 | 22 | 23 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt"} 24 | 25 | # See all BART models at https://huggingface.co/models?filter=bart 26 | PRETRAINED_VOCAB_FILES_MAP = { 27 | "vocab_file": { 28 | "facebook/bart-base": "https://huggingface.co/facebook/bart-base/resolve/main/vocab.json", 29 | "facebook/bart-large": "https://huggingface.co/facebook/bart-large/resolve/main/vocab.json", 30 | "facebook/bart-large-mnli": "https://huggingface.co/facebook/bart-large-mnli/resolve/main/vocab.json", 31 | "facebook/bart-large-cnn": "https://huggingface.co/facebook/bart-large-cnn/resolve/main/vocab.json", 32 | "facebook/bart-large-xsum": "https://huggingface.co/facebook/bart-large-xsum/resolve/main/vocab.json", 33 | "yjernite/bart_eli5": "https://huggingface.co/yjernite/bart_eli5/resolve/main/vocab.json", 34 | }, 35 | "merges_file": { 36 | "facebook/bart-base": "https://huggingface.co/facebook/bart-base/resolve/main/merges.txt", 37 | "facebook/bart-large": "https://huggingface.co/facebook/bart-large/resolve/main/merges.txt", 38 | "facebook/bart-large-mnli": "https://huggingface.co/facebook/bart-large-mnli/resolve/main/merges.txt", 39 | "facebook/bart-large-cnn": "https://huggingface.co/facebook/bart-large-cnn/resolve/main/merges.txt", 40 | "facebook/bart-large-xsum": "https://huggingface.co/facebook/bart-large-xsum/resolve/main/merges.txt", 41 | "yjernite/bart_eli5": "https://huggingface.co/yjernite/bart_eli5/resolve/main/merges.txt", 42 | }, 43 | } 44 | 45 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 46 | "facebook/bart-base": 1024, 47 | "facebook/bart-large": 1024, 48 | "facebook/bart-large-mnli": 1024, 49 | "facebook/bart-large-cnn": 1024, 50 | "facebook/bart-large-xsum": 1024, 51 | "yjernite/bart_eli5": 1024, 52 | } 53 | 54 | 55 | class BartTokenizer(RobertaTokenizer): 56 | r""" 57 | Construct a BART tokenizer. 58 | 59 | :class:`~transformers.BartTokenizer` is identical to :class:`~transformers.RobertaTokenizer`. Refer to superclass 60 | :class:`~transformers.RobertaTokenizer` for usage examples and documentation concerning the initialization 61 | parameters and other methods. 
62 | """ 63 | vocab_files_names = VOCAB_FILES_NAMES 64 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 65 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 66 | -------------------------------------------------------------------------------- /mytransformers/models/barthez/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | from typing import TYPE_CHECKING 20 | 21 | from ...file_utils import _BaseLazyModule, is_sentencepiece_available, is_tokenizers_available 22 | 23 | 24 | _import_structure = {} 25 | 26 | if is_sentencepiece_available(): 27 | _import_structure["tokenization_barthez"] = ["BarthezTokenizer"] 28 | 29 | if is_tokenizers_available(): 30 | _import_structure["tokenization_barthez_fast"] = ["BarthezTokenizerFast"] 31 | 32 | 33 | if TYPE_CHECKING: 34 | 35 | if is_sentencepiece_available(): 36 | from .tokenization_barthez import BarthezTokenizer 37 | 38 | if is_tokenizers_available(): 39 | from .tokenization_barthez_fast import BarthezTokenizerFast 40 | 41 | else: 42 | import importlib 43 | import os 44 | import sys 45 | 46 | class _LazyModule(_BaseLazyModule): 47 | """ 48 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 49 | """ 50 | 51 | __file__ = globals()["__file__"] 52 | __path__ = [os.path.dirname(__file__)] 53 | 54 | def _get_module(self, module_name: str): 55 | return importlib.import_module("." + module_name, self.__name__) 56 | 57 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 58 | -------------------------------------------------------------------------------- /mytransformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert BERT checkpoint.""" 16 | 17 | 18 | import argparse 19 | 20 | import torch 21 | 22 | from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert 23 | from transformers.utils import logging 24 | 25 | 26 | logging.set_verbosity_info() 27 | 28 | 29 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 30 | # Initialise PyTorch model 31 | config = BertConfig.from_json_file(bert_config_file) 32 | print(f"Building PyTorch model from configuration: {config}") 33 | model = BertForPreTraining(config) 34 | 35 | # Load weights from tf checkpoint 36 | load_tf_weights_in_bert(model, config, tf_checkpoint_path) 37 | 38 | # Save pytorch-model 39 | print(f"Save PyTorch model to {pytorch_dump_path}") 40 | torch.save(model.state_dict(), pytorch_dump_path) 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser() 45 | # Required parameters 46 | parser.add_argument( 47 | "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." 48 | ) 49 | parser.add_argument( 50 | "--bert_config_file", 51 | default=None, 52 | type=str, 53 | required=True, 54 | help="The config json file corresponding to the pre-trained BERT model. \n" 55 | "This specifies the model architecture.", 56 | ) 57 | parser.add_argument( 58 | "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." 59 | ) 60 | args = parser.parse_args() 61 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path) 62 | -------------------------------------------------------------------------------- /mytransformers/models/bert_generation/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | 19 | from typing import TYPE_CHECKING 20 | 21 | from ...file_utils import _BaseLazyModule, is_sentencepiece_available, is_torch_available 22 | 23 | 24 | _import_structure = { 25 | "configuration_bert_generation": ["BertGenerationConfig"], 26 | } 27 | 28 | if is_sentencepiece_available(): 29 | _import_structure["tokenization_bert_generation"] = ["BertGenerationTokenizer"] 30 | 31 | if is_torch_available(): 32 | _import_structure["modeling_bert_generation"] = [ 33 | "BertGenerationDecoder", 34 | "BertGenerationEncoder", 35 | "load_tf_weights_in_bert_generation", 36 | ] 37 | 38 | 39 | if TYPE_CHECKING: 40 | from .configuration_bert_generation import BertGenerationConfig 41 | 42 | if is_sentencepiece_available(): 43 | from .tokenization_bert_generation import BertGenerationTokenizer 44 | 45 | if is_torch_available(): 46 | from .modeling_bert_generation import ( 47 | BertGenerationDecoder, 48 | BertGenerationEncoder, 49 | load_tf_weights_in_bert_generation, 50 | ) 51 | 52 | else: 53 | import importlib 54 | import os 55 | import sys 56 | 57 | class _LazyModule(_BaseLazyModule): 58 | """ 59 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 60 | """ 61 | 62 | __file__ = globals()["__file__"] 63 | __path__ = [os.path.dirname(__file__)] 64 | 65 | def _get_module(self, module_name: str): 66 | return importlib.import_module("." + module_name, self.__name__) 67 | 68 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 69 | -------------------------------------------------------------------------------- /mytransformers/models/bert_japanese/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | from typing import TYPE_CHECKING 20 | 21 | from ...file_utils import _BaseLazyModule 22 | 23 | 24 | _import_structure = { 25 | "tokenization_bert_japanese": ["BertJapaneseTokenizer", "CharacterTokenizer", "MecabTokenizer"], 26 | } 27 | 28 | 29 | if TYPE_CHECKING: 30 | from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer 31 | 32 | else: 33 | import importlib 34 | import os 35 | import sys 36 | 37 | class _LazyModule(_BaseLazyModule): 38 | """ 39 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 40 | """ 41 | 42 | __file__ = globals()["__file__"] 43 | __path__ = [os.path.dirname(__file__)] 44 | 45 | def _get_module(self, module_name: str): 46 | return importlib.import_module("." 
+ module_name, self.__name__) 47 | 48 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 49 | -------------------------------------------------------------------------------- /mytransformers/models/bertweet/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | from typing import TYPE_CHECKING 20 | 21 | from ...file_utils import _BaseLazyModule 22 | 23 | 24 | _import_structure = { 25 | "tokenization_bertweet": ["BertweetTokenizer"], 26 | } 27 | 28 | 29 | if TYPE_CHECKING: 30 | from .tokenization_bertweet import BertweetTokenizer 31 | 32 | else: 33 | import importlib 34 | import os 35 | import sys 36 | 37 | class _LazyModule(_BaseLazyModule): 38 | """ 39 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 40 | """ 41 | 42 | __file__ = globals()["__file__"] 43 | __path__ = [os.path.dirname(__file__)] 44 | 45 | def _get_module(self, module_name: str): 46 | return importlib.import_module("." + module_name, self.__name__) 47 | 48 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 49 | -------------------------------------------------------------------------------- /mytransformers/models/big_bird/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2021 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | from typing import TYPE_CHECKING 19 | 20 | from ...file_utils import _BaseLazyModule, is_torch_available 21 | 22 | 23 | _import_structure = { 24 | "configuration_big_bird": ["BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP", "BigBirdConfig"], 25 | "tokenization_big_bird": ["BigBirdTokenizer"], 26 | } 27 | 28 | if is_torch_available(): 29 | _import_structure["modeling_big_bird"] = [ 30 | "BIG_BIRD_PRETRAINED_MODEL_ARCHIVE_LIST", 31 | "BigBirdForCausalLM", 32 | "BigBirdForMaskedLM", 33 | "BigBirdForMultipleChoice", 34 | "BigBirdForPreTraining", 35 | "BigBirdForQuestionAnswering", 36 | "BigBirdForSequenceClassification", 37 | "BigBirdForTokenClassification", 38 | "BigBirdLayer", 39 | "BigBirdModel", 40 | "BigBirdPreTrainedModel", 41 | "load_tf_weights_in_big_bird", 42 | ] 43 | 44 | 45 | if TYPE_CHECKING: 46 | from .configuration_big_bird import BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdConfig 47 | from .tokenization_big_bird import BigBirdTokenizer 48 | 49 | if is_torch_available(): 50 | from .modeling_big_bird import ( 51 | BIG_BIRD_PRETRAINED_MODEL_ARCHIVE_LIST, 52 | BigBirdForCausalLM, 53 | BigBirdForMaskedLM, 54 | BigBirdForMultipleChoice, 55 | BigBirdForPreTraining, 56 | BigBirdForQuestionAnswering, 57 | BigBirdForSequenceClassification, 58 | BigBirdForTokenClassification, 59 | BigBirdLayer, 60 | BigBirdModel, 61 | BigBirdPreTrainedModel, 62 | load_tf_weights_in_big_bird, 63 | ) 64 | 65 | 66 | else: 67 | import importlib 68 | import os 69 | import sys 70 | 71 | class _LazyModule(_BaseLazyModule): 72 | """ 73 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 74 | """ 75 | 76 | __file__ = globals()["__file__"] 77 | __path__ = [os.path.dirname(__file__)] 78 | 79 | def _get_module(self, module_name: str): 80 | return importlib.import_module("." + module_name, self.__name__) 81 | 82 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 83 | -------------------------------------------------------------------------------- /mytransformers/models/big_bird/convert_bigbird_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
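# Illustrative invocation of this conversion script (all paths are placeholders for a locally
# downloaded original TF BigBird checkpoint and its config; the flags match the argparse setup below):
#
#     python convert_bigbird_original_tf_checkpoint_to_pytorch.py \
#         --tf_checkpoint_path /path/to/bigbird/model.ckpt \
#         --big_bird_config_file /path/to/bigbird/config.json \
#         --pytorch_dump_path ./big_bird_pytorch \
#         --is_trivia_qa   # only when the checkpoint carries a TriviaQA head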
15 | """Convert BigBird checkpoint.""" 16 | 17 | 18 | import argparse 19 | 20 | from transformers import BigBirdConfig, BigBirdForPreTraining, BigBirdForQuestionAnswering, load_tf_weights_in_big_bird 21 | from transformers.utils import logging 22 | 23 | 24 | logging.set_verbosity_info() 25 | 26 | 27 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, big_bird_config_file, pytorch_dump_path, is_trivia_qa): 28 | # Initialise PyTorch model 29 | config = BigBirdConfig.from_json_file(big_bird_config_file) 30 | print(f"Building PyTorch model from configuration: {config}") 31 | 32 | if is_trivia_qa: 33 | model = BigBirdForQuestionAnswering(config) 34 | else: 35 | model = BigBirdForPreTraining(config) 36 | 37 | # Load weights from tf checkpoint 38 | load_tf_weights_in_big_bird(model, tf_checkpoint_path, is_trivia_qa=is_trivia_qa) 39 | 40 | # Save pytorch-model 41 | print(f"Save PyTorch model to {pytorch_dump_path}") 42 | model.save_pretrained(pytorch_dump_path) 43 | 44 | 45 | if __name__ == "__main__": 46 | 47 | parser = argparse.ArgumentParser() 48 | # Required parameters 49 | parser.add_argument( 50 | "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." 51 | ) 52 | parser.add_argument( 53 | "--big_bird_config_file", 54 | default=None, 55 | type=str, 56 | required=True, 57 | help="The config json file corresponding to the pre-trained BERT model. \n" 58 | "This specifies the model architecture.", 59 | ) 60 | parser.add_argument( 61 | "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." 62 | ) 63 | parser.add_argument( 64 | "--is_trivia_qa", action="store_true", help="Whether to convert a model with a trivia_qa head." 65 | ) 66 | args = parser.parse_args() 67 | convert_tf_checkpoint_to_pytorch( 68 | args.tf_checkpoint_path, args.big_bird_config_file, args.pytorch_dump_path, args.is_trivia_qa 69 | ) 70 | -------------------------------------------------------------------------------- /mytransformers/models/blenderbot/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | 19 | from typing import TYPE_CHECKING 20 | 21 | from ...file_utils import _BaseLazyModule, is_tf_available, is_torch_available 22 | 23 | 24 | _import_structure = { 25 | "configuration_blenderbot": ["BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BlenderbotConfig"], 26 | "tokenization_blenderbot": ["BlenderbotTokenizer"], 27 | } 28 | 29 | if is_torch_available(): 30 | _import_structure["modeling_blenderbot"] = [ 31 | "BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST", 32 | "BlenderbotForCausalLM", 33 | "BlenderbotForConditionalGeneration", 34 | "BlenderbotModel", 35 | "BlenderbotPreTrainedModel", 36 | ] 37 | 38 | 39 | if is_tf_available(): 40 | _import_structure["modeling_tf_blenderbot"] = ["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel"] 41 | 42 | 43 | if TYPE_CHECKING: 44 | from .configuration_blenderbot import BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotConfig 45 | from .tokenization_blenderbot import BlenderbotTokenizer 46 | 47 | if is_torch_available(): 48 | from .modeling_blenderbot import ( 49 | BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST, 50 | BlenderbotForCausalLM, 51 | BlenderbotForConditionalGeneration, 52 | BlenderbotModel, 53 | BlenderbotPreTrainedModel, 54 | ) 55 | 56 | if is_tf_available(): 57 | from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel 58 | 59 | else: 60 | import importlib 61 | import os 62 | import sys 63 | 64 | class _LazyModule(_BaseLazyModule): 65 | """ 66 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 67 | """ 68 | 69 | __file__ = globals()["__file__"] 70 | __path__ = [os.path.dirname(__file__)] 71 | 72 | def _get_module(self, module_name: str): 73 | return importlib.import_module("." + module_name, self.__name__) 74 | 75 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 76 | -------------------------------------------------------------------------------- /mytransformers/models/blenderbot_small/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | from typing import TYPE_CHECKING 19 | 20 | from ...file_utils import _BaseLazyModule, is_tf_available, is_torch_available 21 | 22 | 23 | _import_structure = { 24 | "configuration_blenderbot_small": ["BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP", "BlenderbotSmallConfig"], 25 | "tokenization_blenderbot_small": ["BlenderbotSmallTokenizer"], 26 | } 27 | 28 | if is_torch_available(): 29 | _import_structure["modeling_blenderbot_small"] = [ 30 | "BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST", 31 | "BlenderbotSmallForCausalLM", 32 | "BlenderbotSmallForConditionalGeneration", 33 | "BlenderbotSmallModel", 34 | "BlenderbotSmallPreTrainedModel", 35 | ] 36 | 37 | if is_tf_available(): 38 | _import_structure["modeling_tf_blenderbot_small"] = [ 39 | "TFBlenderbotSmallForConditionalGeneration", 40 | "TFBlenderbotSmallModel", 41 | ] 42 | 43 | if TYPE_CHECKING: 44 | from .configuration_blenderbot_small import BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotSmallConfig 45 | from .tokenization_blenderbot_small import BlenderbotSmallTokenizer 46 | 47 | if is_torch_available(): 48 | from .modeling_blenderbot_small import ( 49 | BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST, 50 | BlenderbotSmallForCausalLM, 51 | BlenderbotSmallForConditionalGeneration, 52 | BlenderbotSmallModel, 53 | BlenderbotSmallPreTrainedModel, 54 | ) 55 | 56 | if is_tf_available(): 57 | from .modeling_tf_blenderbot_small import TFBlenderbotSmallForConditionalGeneration, TFBlenderbotSmallModel 58 | 59 | else: 60 | import importlib 61 | import os 62 | import sys 63 | 64 | class _LazyModule(_BaseLazyModule): 65 | """ 66 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 67 | """ 68 | 69 | __file__ = globals()["__file__"] 70 | __path__ = [os.path.dirname(__file__)] 71 | 72 | def _get_module(self, module_name: str): 73 | return importlib.import_module("." + module_name, self.__name__) 74 | 75 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 76 | -------------------------------------------------------------------------------- /mytransformers/models/camembert/configuration_camembert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
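# CamembertConfig (defined below) only overrides `model_type`; all other behaviour is inherited
# from RobertaConfig. A hedged usage sketch (the model id is the public `camembert-base` entry
# from the archive map below; requires network access or a cached config):
#
#     from mytransformers.models.camembert.configuration_camembert import CamembertConfig
#     config = CamembertConfig.from_pretrained("camembert-base")
#     assert config.model_type == "camembert"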
16 | """ CamemBERT configuration """ 17 | 18 | from ...utils import logging 19 | from ..roberta.configuration_roberta import RobertaConfig 20 | 21 | 22 | logger = logging.get_logger(__name__) 23 | 24 | CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 25 | "camembert-base": "https://huggingface.co/camembert-base/resolve/main/config.json", 26 | "umberto-commoncrawl-cased-v1": "https://huggingface.co/Musixmatch/umberto-commoncrawl-cased-v1/resolve/main/config.json", 27 | "umberto-wikipedia-uncased-v1": "https://huggingface.co/Musixmatch/umberto-wikipedia-uncased-v1/resolve/main/config.json", 28 | } 29 | 30 | 31 | class CamembertConfig(RobertaConfig): 32 | """ 33 | This class overrides :class:`~transformers.RobertaConfig`. Please check the superclass for the appropriate 34 | documentation alongside usage examples. 35 | """ 36 | 37 | model_type = "camembert" 38 | -------------------------------------------------------------------------------- /mytransformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch_and_tf2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert ConvBERT checkpoint.""" 16 | 17 | import argparse 18 | 19 | from transformers import ConvBertConfig, ConvBertModel, TFConvBertModel, load_tf_weights_in_convbert 20 | from transformers.utils import logging 21 | 22 | 23 | logging.set_verbosity_info() 24 | 25 | 26 | def convert_orig_tf1_checkpoint_to_pytorch(tf_checkpoint_path, convbert_config_file, pytorch_dump_path): 27 | conf = ConvBertConfig.from_json_file(convbert_config_file) 28 | model = ConvBertModel(conf) 29 | 30 | model = load_tf_weights_in_convbert(model, conf, tf_checkpoint_path) 31 | model.save_pretrained(pytorch_dump_path) 32 | 33 | tf_model = TFConvBertModel.from_pretrained(pytorch_dump_path, from_pt=True) 34 | tf_model.save_pretrained(pytorch_dump_path) 35 | 36 | 37 | if __name__ == "__main__": 38 | parser = argparse.ArgumentParser() 39 | # Required parameters 40 | parser.add_argument( 41 | "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." 42 | ) 43 | parser.add_argument( 44 | "--convbert_config_file", 45 | default=None, 46 | type=str, 47 | required=True, 48 | help="The config json file corresponding to the pre-trained ConvBERT model. \n" 49 | "This specifies the model architecture.", 50 | ) 51 | parser.add_argument( 52 | "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." 
53 | ) 54 | args = parser.parse_args() 55 | convert_orig_tf1_checkpoint_to_pytorch(args.tf_checkpoint_path, args.convbert_config_file, args.pytorch_dump_path) 56 | -------------------------------------------------------------------------------- /mytransformers/models/convbert/tokenization_convbert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for ConvBERT.""" 16 | from ...utils import logging 17 | from ..bert.tokenization_bert import BertTokenizer 18 | 19 | 20 | logger = logging.get_logger(__name__) 21 | 22 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} 23 | 24 | PRETRAINED_VOCAB_FILES_MAP = { 25 | "vocab_file": { 26 | "YituTech/conv-bert-base": "https://huggingface.co/YituTech/conv-bert-base/resolve/main/vocab.txt", 27 | "YituTech/conv-bert-medium-small": "https://huggingface.co/YituTech/conv-bert-medium-small/resolve/main/vocab.txt", 28 | "YituTech/conv-bert-small": "https://huggingface.co/YituTech/conv-bert-small/resolve/main/vocab.txt", 29 | } 30 | } 31 | 32 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 33 | "YituTech/conv-bert-base": 512, 34 | "YituTech/conv-bert-medium-small": 512, 35 | "YituTech/conv-bert-small": 512, 36 | } 37 | 38 | 39 | PRETRAINED_INIT_CONFIGURATION = { 40 | "YituTech/conv-bert-base": {"do_lower_case": True}, 41 | "YituTech/conv-bert-medium-small": {"do_lower_case": True}, 42 | "YituTech/conv-bert-small": {"do_lower_case": True}, 43 | } 44 | 45 | 46 | class ConvBertTokenizer(BertTokenizer): 47 | r""" 48 | Construct a ConvBERT tokenizer. :class:`~transformers.ConvBertTokenizer` is identical to 49 | :class:`~transformers.BertTokenizer` and runs end-to-end tokenization: punctuation splitting and wordpiece. Refer 50 | to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning parameters. 51 | """ 52 | 53 | vocab_files_names = VOCAB_FILES_NAMES 54 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 55 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 56 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 57 | -------------------------------------------------------------------------------- /mytransformers/models/convbert/tokenization_convbert_fast.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for ConvBERT.""" 16 | from ...utils import logging 17 | from ..bert.tokenization_bert_fast import BertTokenizerFast 18 | from .tokenization_convbert import ConvBertTokenizer 19 | 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} 24 | 25 | PRETRAINED_VOCAB_FILES_MAP = { 26 | "vocab_file": { 27 | "YituTech/conv-bert-base": "https://huggingface.co/YituTech/conv-bert-base/resolve/main/vocab.txt", 28 | "YituTech/conv-bert-medium-small": "https://huggingface.co/YituTech/conv-bert-medium-small/resolve/main/vocab.txt", 29 | "YituTech/conv-bert-small": "https://huggingface.co/YituTech/conv-bert-small/resolve/main/vocab.txt", 30 | } 31 | } 32 | 33 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 34 | "YituTech/conv-bert-base": 512, 35 | "YituTech/conv-bert-medium-small": 512, 36 | "YituTech/conv-bert-small": 512, 37 | } 38 | 39 | 40 | PRETRAINED_INIT_CONFIGURATION = { 41 | "YituTech/conv-bert-base": {"do_lower_case": True}, 42 | "YituTech/conv-bert-medium-small": {"do_lower_case": True}, 43 | "YituTech/conv-bert-small": {"do_lower_case": True}, 44 | } 45 | 46 | 47 | class ConvBertTokenizerFast(BertTokenizerFast): 48 | r""" 49 | Construct a "fast" ConvBERT tokenizer (backed by HuggingFace's `tokenizers` library). 50 | 51 | :class:`~transformers.ConvBertTokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs 52 | end-to-end tokenization: punctuation splitting and wordpiece. 53 | 54 | Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning 55 | parameters. 56 | """ 57 | vocab_files_names = VOCAB_FILES_NAMES 58 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 59 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 60 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 61 | slow_tokenizer_class = ConvBertTokenizer 62 | -------------------------------------------------------------------------------- /mytransformers/models/ctrl/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | 19 | from typing import TYPE_CHECKING 20 | 21 | from ...file_utils import _BaseLazyModule, is_tf_available, is_torch_available 22 | 23 | 24 | _import_structure = { 25 | "configuration_ctrl": ["CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CTRLConfig"], 26 | "tokenization_ctrl": ["CTRLTokenizer"], 27 | } 28 | 29 | if is_torch_available(): 30 | _import_structure["modeling_ctrl"] = [ 31 | "CTRL_PRETRAINED_MODEL_ARCHIVE_LIST", 32 | "CTRLForSequenceClassification", 33 | "CTRLLMHeadModel", 34 | "CTRLModel", 35 | "CTRLPreTrainedModel", 36 | ] 37 | 38 | if is_tf_available(): 39 | _import_structure["modeling_tf_ctrl"] = [ 40 | "TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST", 41 | "TFCTRLForSequenceClassification", 42 | "TFCTRLLMHeadModel", 43 | "TFCTRLModel", 44 | "TFCTRLPreTrainedModel", 45 | ] 46 | 47 | 48 | if TYPE_CHECKING: 49 | from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig 50 | from .tokenization_ctrl import CTRLTokenizer 51 | 52 | if is_torch_available(): 53 | from .modeling_ctrl import ( 54 | CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, 55 | CTRLForSequenceClassification, 56 | CTRLLMHeadModel, 57 | CTRLModel, 58 | CTRLPreTrainedModel, 59 | ) 60 | 61 | if is_tf_available(): 62 | from .modeling_tf_ctrl import ( 63 | TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, 64 | TFCTRLForSequenceClassification, 65 | TFCTRLLMHeadModel, 66 | TFCTRLModel, 67 | TFCTRLPreTrainedModel, 68 | ) 69 | 70 | else: 71 | import importlib 72 | import os 73 | import sys 74 | 75 | class _LazyModule(_BaseLazyModule): 76 | """ 77 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 78 | """ 79 | 80 | __file__ = globals()["__file__"] 81 | __path__ = [os.path.dirname(__file__)] 82 | 83 | def _get_module(self, module_name: str): 84 | return importlib.import_module("." + module_name, self.__name__) 85 | 86 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 87 | -------------------------------------------------------------------------------- /mytransformers/models/deberta/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | 19 | from typing import TYPE_CHECKING 20 | 21 | from ...file_utils import _BaseLazyModule, is_torch_available 22 | 23 | 24 | _import_structure = { 25 | "configuration_deberta": ["DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaConfig"], 26 | "tokenization_deberta": ["DebertaTokenizer"], 27 | } 28 | 29 | if is_torch_available(): 30 | _import_structure["modeling_deberta"] = [ 31 | "DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST", 32 | "DebertaForMaskedLM", 33 | "DebertaForQuestionAnswering", 34 | "DebertaForSequenceClassification", 35 | "DebertaForTokenClassification", 36 | "DebertaModel", 37 | "DebertaPreTrainedModel", 38 | ] 39 | 40 | 41 | if TYPE_CHECKING: 42 | from .configuration_deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig 43 | from .tokenization_deberta import DebertaTokenizer 44 | 45 | if is_torch_available(): 46 | from .modeling_deberta import ( 47 | DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, 48 | DebertaForMaskedLM, 49 | DebertaForQuestionAnswering, 50 | DebertaForSequenceClassification, 51 | DebertaForTokenClassification, 52 | DebertaModel, 53 | DebertaPreTrainedModel, 54 | ) 55 | 56 | else: 57 | import importlib 58 | import os 59 | import sys 60 | 61 | class _LazyModule(_BaseLazyModule): 62 | """ 63 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 64 | """ 65 | 66 | __file__ = globals()["__file__"] 67 | __path__ = [os.path.dirname(__file__)] 68 | 69 | def _get_module(self, module_name: str): 70 | return importlib.import_module("." + module_name, self.__name__) 71 | 72 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 73 | -------------------------------------------------------------------------------- /mytransformers/models/deberta_v2/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | 19 | from typing import TYPE_CHECKING 20 | 21 | from ...file_utils import _BaseLazyModule, is_torch_available 22 | 23 | 24 | _import_structure = { 25 | "configuration_deberta_v2": ["DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaV2Config"], 26 | "tokenization_deberta_v2": ["DebertaV2Tokenizer"], 27 | } 28 | 29 | if is_torch_available(): 30 | _import_structure["modeling_deberta_v2"] = [ 31 | "DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST", 32 | "DebertaV2ForMaskedLM", 33 | "DebertaV2ForQuestionAnswering", 34 | "DebertaV2ForSequenceClassification", 35 | "DebertaV2ForTokenClassification", 36 | "DebertaV2Model", 37 | "DebertaV2PreTrainedModel", 38 | ] 39 | 40 | 41 | if TYPE_CHECKING: 42 | from .configuration_deberta_v2 import DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaV2Config 43 | from .tokenization_deberta_v2 import DebertaV2Tokenizer 44 | 45 | if is_torch_available(): 46 | from .modeling_deberta_v2 import ( 47 | DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST, 48 | DebertaV2ForMaskedLM, 49 | DebertaV2ForQuestionAnswering, 50 | DebertaV2ForSequenceClassification, 51 | DebertaV2ForTokenClassification, 52 | DebertaV2Model, 53 | DebertaV2PreTrainedModel, 54 | ) 55 | 56 | else: 57 | import importlib 58 | import os 59 | import sys 60 | 61 | class _LazyModule(_BaseLazyModule): 62 | """ 63 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 64 | """ 65 | 66 | __file__ = globals()["__file__"] 67 | __path__ = [os.path.dirname(__file__)] 68 | 69 | def _get_module(self, module_name: str): 70 | return importlib.import_module("." + module_name, self.__name__) 71 | 72 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 73 | -------------------------------------------------------------------------------- /mytransformers/models/dialogpt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SALT-NLP/Adaptive-Compositional-Modules/357aa2d6d1cd97ea03aeaddbd5372a1aeecbbe4c/mytransformers/models/dialogpt/__init__.py -------------------------------------------------------------------------------- /mytransformers/models/dialogpt/convert_dialogpt_original_pytorch_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
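# What the script below does: DialoGPT's original fine-tuned checkpoints
# ({small,medium,large}_ft.pkl) store the LM head weight under "lm_head.decoder.weight";
# the script renames it to "lm_head.weight" and saves a standard pytorch_model.bin under
# ./DialoGPT-{small,medium,large}. Illustrative invocation (the directory is a placeholder
# holding the downloaded .pkl files):
#
#     python convert_dialogpt_original_pytorch_checkpoint_to_pytorch.py --dialogpt_path ./dialogpt_ckpts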
14 | 15 | import argparse 16 | import os 17 | 18 | import torch 19 | 20 | from transformers.file_utils import WEIGHTS_NAME 21 | 22 | 23 | DIALOGPT_MODELS = ["small", "medium", "large"] 24 | 25 | OLD_KEY = "lm_head.decoder.weight" 26 | NEW_KEY = "lm_head.weight" 27 | 28 | 29 | def convert_dialogpt_checkpoint(checkpoint_path: str, pytorch_dump_folder_path: str): 30 | d = torch.load(checkpoint_path) 31 | d[NEW_KEY] = d.pop(OLD_KEY) 32 | os.makedirs(pytorch_dump_folder_path, exist_ok=True) 33 | torch.save(d, os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)) 34 | 35 | 36 | if __name__ == "__main__": 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument("--dialogpt_path", default=".", type=str) 39 | args = parser.parse_args() 40 | for MODEL in DIALOGPT_MODELS: 41 | checkpoint_path = os.path.join(args.dialogpt_path, f"{MODEL}_ft.pkl") 42 | pytorch_dump_folder_path = f"./DialoGPT-{MODEL}" 43 | convert_dialogpt_checkpoint( 44 | checkpoint_path, 45 | pytorch_dump_folder_path, 46 | ) 47 | -------------------------------------------------------------------------------- /mytransformers/models/distilbert/tokenization_distilbert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
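# A hedged usage sketch for the tokenizer defined below (the checkpoint id is one of the public
# entries in PRETRAINED_VOCAB_FILES_MAP; requires network access or a locally cached vocab):
#
#     from mytransformers.models.distilbert.tokenization_distilbert import DistilBertTokenizer
#     tok = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
#     enc = tok("Hello world")   # returns "input_ids" and "attention_mask", per model_input_names below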
15 | """Tokenization classes for DistilBERT.""" 16 | 17 | from ...utils import logging 18 | from ..bert.tokenization_bert import BertTokenizer 19 | 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} 24 | 25 | PRETRAINED_VOCAB_FILES_MAP = { 26 | "vocab_file": { 27 | "distilbert-base-uncased": "https://huggingface.co/distilbert-base-uncased/resolve/main/vocab.txt", 28 | "distilbert-base-uncased-distilled-squad": "https://huggingface.co/distilbert-base-uncased-distilled-squad/resolve/main/vocab.txt", 29 | "distilbert-base-cased": "https://huggingface.co/distilbert-base-cased/resolve/main/vocab.txt", 30 | "distilbert-base-cased-distilled-squad": "https://huggingface.co/distilbert-base-cased-distilled-squad/resolve/main/vocab.txt", 31 | "distilbert-base-german-cased": "https://huggingface.co/distilbert-base-german-cased/resolve/main/vocab.txt", 32 | "distilbert-base-multilingual-cased": "https://huggingface.co/distilbert-base-multilingual-cased/resolve/main/vocab.txt", 33 | } 34 | } 35 | 36 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 37 | "distilbert-base-uncased": 512, 38 | "distilbert-base-uncased-distilled-squad": 512, 39 | "distilbert-base-cased": 512, 40 | "distilbert-base-cased-distilled-squad": 512, 41 | "distilbert-base-german-cased": 512, 42 | "distilbert-base-multilingual-cased": 512, 43 | } 44 | 45 | 46 | PRETRAINED_INIT_CONFIGURATION = { 47 | "distilbert-base-uncased": {"do_lower_case": True}, 48 | "distilbert-base-uncased-distilled-squad": {"do_lower_case": True}, 49 | "distilbert-base-cased": {"do_lower_case": False}, 50 | "distilbert-base-cased-distilled-squad": {"do_lower_case": False}, 51 | "distilbert-base-german-cased": {"do_lower_case": False}, 52 | "distilbert-base-multilingual-cased": {"do_lower_case": False}, 53 | } 54 | 55 | 56 | class DistilBertTokenizer(BertTokenizer): 57 | r""" 58 | Construct a DistilBERT tokenizer. 59 | 60 | :class:`~transformers.DistilBertTokenizer` is identical to :class:`~transformers.BertTokenizer` and runs end-to-end 61 | tokenization: punctuation splitting and wordpiece. 62 | 63 | Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning 64 | parameters. 65 | """ 66 | 67 | vocab_files_names = VOCAB_FILES_NAMES 68 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 69 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 70 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 71 | model_input_names = ["input_ids", "attention_mask"] 72 | -------------------------------------------------------------------------------- /mytransformers/models/electra/convert_electra_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert ELECTRA checkpoint.""" 16 | 17 | 18 | import argparse 19 | 20 | import torch 21 | 22 | from transformers import ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining, load_tf_weights_in_electra 23 | from transformers.utils import logging 24 | 25 | 26 | logging.set_verbosity_info() 27 | 28 | 29 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path, discriminator_or_generator): 30 | # Initialise PyTorch model 31 | config = ElectraConfig.from_json_file(config_file) 32 | print(f"Building PyTorch model from configuration: {config}") 33 | 34 | if discriminator_or_generator == "discriminator": 35 | model = ElectraForPreTraining(config) 36 | elif discriminator_or_generator == "generator": 37 | model = ElectraForMaskedLM(config) 38 | else: 39 | raise ValueError("The discriminator_or_generator argument should be either 'discriminator' or 'generator'") 40 | 41 | # Load weights from tf checkpoint 42 | load_tf_weights_in_electra( 43 | model, config, tf_checkpoint_path, discriminator_or_generator=discriminator_or_generator 44 | ) 45 | 46 | # Save pytorch-model 47 | print(f"Save PyTorch model to {pytorch_dump_path}") 48 | torch.save(model.state_dict(), pytorch_dump_path) 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | # Required parameters 54 | parser.add_argument( 55 | "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." 56 | ) 57 | parser.add_argument( 58 | "--config_file", 59 | default=None, 60 | type=str, 61 | required=True, 62 | help="The config json file corresponding to the pre-trained model. \n" 63 | "This specifies the model architecture.", 64 | ) 65 | parser.add_argument( 66 | "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." 67 | ) 68 | parser.add_argument( 69 | "--discriminator_or_generator", 70 | default=None, 71 | type=str, 72 | required=True, 73 | help="Whether to export the generator or the discriminator. Should be a string, either 'discriminator' or " 74 | "'generator'.", 75 | ) 76 | args = parser.parse_args() 77 | convert_tf_checkpoint_to_pytorch( 78 | args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path, args.discriminator_or_generator 79 | ) 80 | -------------------------------------------------------------------------------- /mytransformers/models/electra/tokenization_electra.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google AI Team, Stanford University and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | from ..bert.tokenization_bert import BertTokenizer 17 | 18 | 19 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} 20 | 21 | PRETRAINED_VOCAB_FILES_MAP = { 22 | "vocab_file": { 23 | "google/electra-small-generator": "https://huggingface.co/google/electra-small-generator/resolve/main/vocab.txt", 24 | "google/electra-base-generator": "https://huggingface.co/google/electra-base-generator/resolve/main/vocab.txt", 25 | "google/electra-large-generator": "https://huggingface.co/google/electra-large-generator/resolve/main/vocab.txt", 26 | "google/electra-small-discriminator": "https://huggingface.co/google/electra-small-discriminator/resolve/main/vocab.txt", 27 | "google/electra-base-discriminator": "https://huggingface.co/google/electra-base-discriminator/resolve/main/vocab.txt", 28 | "google/electra-large-discriminator": "https://huggingface.co/google/electra-large-discriminator/resolve/main/vocab.txt", 29 | } 30 | } 31 | 32 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 33 | "google/electra-small-generator": 512, 34 | "google/electra-base-generator": 512, 35 | "google/electra-large-generator": 512, 36 | "google/electra-small-discriminator": 512, 37 | "google/electra-base-discriminator": 512, 38 | "google/electra-large-discriminator": 512, 39 | } 40 | 41 | 42 | PRETRAINED_INIT_CONFIGURATION = { 43 | "google/electra-small-generator": {"do_lower_case": True}, 44 | "google/electra-base-generator": {"do_lower_case": True}, 45 | "google/electra-large-generator": {"do_lower_case": True}, 46 | "google/electra-small-discriminator": {"do_lower_case": True}, 47 | "google/electra-base-discriminator": {"do_lower_case": True}, 48 | "google/electra-large-discriminator": {"do_lower_case": True}, 49 | } 50 | 51 | 52 | class ElectraTokenizer(BertTokenizer): 53 | r""" 54 | Construct an ELECTRA tokenizer. 55 | 56 | :class:`~transformers.ElectraTokenizer` is identical to :class:`~transformers.BertTokenizer` and runs end-to-end 57 | tokenization: punctuation splitting and wordpiece. 58 | 59 | Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning 60 | parameters. 61 | """ 62 | 63 | vocab_files_names = VOCAB_FILES_NAMES 64 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 65 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 66 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 67 | -------------------------------------------------------------------------------- /mytransformers/models/encoder_decoder/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | 19 | from typing import TYPE_CHECKING 20 | 21 | from ...file_utils import _BaseLazyModule, is_torch_available 22 | 23 | 24 | _import_structure = { 25 | "configuration_encoder_decoder": ["EncoderDecoderConfig"], 26 | } 27 | 28 | if is_torch_available(): 29 | _import_structure["modeling_encoder_decoder"] = ["EncoderDecoderModel"] 30 | 31 | 32 | if TYPE_CHECKING: 33 | from .configuration_encoder_decoder import EncoderDecoderConfig 34 | 35 | if is_torch_available(): 36 | from .modeling_encoder_decoder import EncoderDecoderModel 37 | 38 | else: 39 | import importlib 40 | import os 41 | import sys 42 | 43 | class _LazyModule(_BaseLazyModule): 44 | """ 45 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 46 | """ 47 | 48 | __file__ = globals()["__file__"] 49 | __path__ = [os.path.dirname(__file__)] 50 | 51 | def _get_module(self, module_name: str): 52 | return importlib.import_module("." + module_name, self.__name__) 53 | 54 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 55 | -------------------------------------------------------------------------------- /mytransformers/models/fsmt/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | from typing import TYPE_CHECKING 20 | 21 | from ...file_utils import _BaseLazyModule, is_torch_available 22 | 23 | 24 | _import_structure = { 25 | "configuration_fsmt": ["FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP", "FSMTConfig"], 26 | "tokenization_fsmt": ["FSMTTokenizer"], 27 | } 28 | 29 | if is_torch_available(): 30 | _import_structure["modeling_fsmt"] = ["FSMTForConditionalGeneration", "FSMTModel", "PretrainedFSMTModel"] 31 | 32 | 33 | if TYPE_CHECKING: 34 | from .configuration_fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTConfig 35 | from .tokenization_fsmt import FSMTTokenizer 36 | 37 | if is_torch_available(): 38 | from .modeling_fsmt import FSMTForConditionalGeneration, FSMTModel, PretrainedFSMTModel 39 | 40 | else: 41 | import importlib 42 | import os 43 | import sys 44 | 45 | class _LazyModule(_BaseLazyModule): 46 | """ 47 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 48 | """ 49 | 50 | __file__ = globals()["__file__"] 51 | __path__ = [os.path.dirname(__file__)] 52 | 53 | def _get_module(self, module_name: str): 54 | return importlib.import_module("." 
+ module_name, self.__name__) 55 | 56 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 57 | -------------------------------------------------------------------------------- /mytransformers/models/funnel/convert_funnel_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert Funnel checkpoint.""" 16 | 17 | 18 | import argparse 19 | 20 | import torch 21 | 22 | from transformers import FunnelBaseModel, FunnelConfig, FunnelModel, load_tf_weights_in_funnel 23 | from transformers.utils import logging 24 | 25 | 26 | logging.set_verbosity_info() 27 | 28 | 29 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path, base_model): 30 | # Initialise PyTorch model 31 | config = FunnelConfig.from_json_file(config_file) 32 | print(f"Building PyTorch model from configuration: {config}") 33 | model = FunnelBaseModel(config) if base_model else FunnelModel(config) 34 | 35 | # Load weights from tf checkpoint 36 | load_tf_weights_in_funnel(model, config, tf_checkpoint_path) 37 | 38 | # Save pytorch-model 39 | print(f"Save PyTorch model to {pytorch_dump_path}") 40 | torch.save(model.state_dict(), pytorch_dump_path) 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser() 45 | # Required parameters 46 | parser.add_argument( 47 | "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." 48 | ) 49 | parser.add_argument( 50 | "--config_file", 51 | default=None, 52 | type=str, 53 | required=True, 54 | help="The config json file corresponding to the pre-trained model. \n" 55 | "This specifies the model architecture.", 56 | ) 57 | parser.add_argument( 58 | "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." 59 | ) 60 | parser.add_argument( 61 | "--base_model", action="store_true", help="Whether you want just the base model (no decoder) or not." 62 | ) 63 | args = parser.parse_args() 64 | convert_tf_checkpoint_to_pytorch( 65 | args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path, args.base_model 66 | ) 67 | -------------------------------------------------------------------------------- /mytransformers/models/gpt2/convert_gpt2_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | 18 | import argparse 19 | 20 | import torch 21 | 22 | from transformers import GPT2Config, GPT2Model, load_tf_weights_in_gpt2 23 | from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME 24 | from transformers.utils import logging 25 | 26 | 27 | logging.set_verbosity_info() 28 | 29 | 30 | def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path): 31 | # Construct model 32 | if gpt2_config_file == "": 33 | config = GPT2Config() 34 | else: 35 | config = GPT2Config.from_json_file(gpt2_config_file) 36 | model = GPT2Model(config) 37 | 38 | # Load weights from numpy 39 | load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path) 40 | 41 | # Save pytorch-model 42 | pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME 43 | pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME 44 | print(f"Save PyTorch model to {pytorch_weights_dump_path}") 45 | torch.save(model.state_dict(), pytorch_weights_dump_path) 46 | print(f"Save configuration file to {pytorch_config_dump_path}") 47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 48 | f.write(config.to_json_string()) 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | # Required parameters 54 | parser.add_argument( 55 | "--gpt2_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." 56 | ) 57 | parser.add_argument( 58 | "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model." 59 | ) 60 | parser.add_argument( 61 | "--gpt2_config_file", 62 | default="", 63 | type=str, 64 | help="An optional config json file corresponding to the pre-trained OpenAI model. \n" 65 | "This specifies the model architecture.", 66 | ) 67 | args = parser.parse_args() 68 | convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, args.gpt2_config_file, args.pytorch_dump_folder_path) 69 | -------------------------------------------------------------------------------- /mytransformers/models/gpt_neo/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2021 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
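# Usage sketch for the GPT-Neo classes this package exposes (the checkpoint name is an
# assumption; any GPT-Neo checkpoint with a matching config works the same way):
#
#   >>> from transformers import AutoTokenizer, GPTNeoForCausalLM
#   >>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
#   >>> model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")
#   >>> ids = tokenizer("GPT-Neo is", return_tensors="pt").input_ids
#   >>> text = tokenizer.decode(model.generate(ids, max_length=20)[0], skip_special_tokens=True)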
18 | from typing import TYPE_CHECKING 19 | 20 | from ...file_utils import _BaseLazyModule, is_torch_available 21 | 22 | 23 | _import_structure = { 24 | "configuration_gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig"], 25 | } 26 | 27 | if is_torch_available(): 28 | _import_structure["modeling_gpt_neo"] = [ 29 | "GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST", 30 | "GPTNeoForCausalLM", 31 | "GPTNeoModel", 32 | "GPTNeoPreTrainedModel", 33 | "load_tf_weights_in_gpt_neo", 34 | ] 35 | 36 | 37 | if TYPE_CHECKING: 38 | from .configuration_gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig 39 | 40 | if is_torch_available(): 41 | from .modeling_gpt_neo import ( 42 | GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST, 43 | GPTNeoForCausalLM, 44 | GPTNeoModel, 45 | GPTNeoPreTrainedModel, 46 | load_tf_weights_in_gpt_neo, 47 | ) 48 | 49 | 50 | else: 51 | import importlib 52 | import os 53 | import sys 54 | 55 | class _LazyModule(_BaseLazyModule): 56 | """ 57 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 58 | """ 59 | 60 | __file__ = globals()["__file__"] 61 | __path__ = [os.path.dirname(__file__)] 62 | 63 | def _get_module(self, module_name: str): 64 | return importlib.import_module("." + module_name, self.__name__) 65 | 66 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 67 | -------------------------------------------------------------------------------- /mytransformers/models/gpt_neo/convert_gpt_neo_mesh_tf_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Eleuther AI and HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
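# Usage sketch for the conversion script that follows (all paths are placeholders):
#
#   python convert_gpt_neo_mesh_tf_to_pytorch.py \
#       --tf_checkpoint_path /path/to/mesh_tf_checkpoint \
#       --config_file /path/to/mesh_tf_config.json \
#       --pytorch_dump_path /path/to/output_dir
#
# The config file is the original mesh-tensorflow JSON; the script maps its n_embd, n_layer,
# n_head, attention_types, n_ctx, res_dropout, embed_dropout and attn_dropout entries onto a
# GPTNeoConfig before loading the TF weights and saving the result with save_pretrained.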
15 | """Convert GPT Neo checkpoint.""" 16 | 17 | 18 | import argparse 19 | import json 20 | 21 | from transformers import GPTNeoConfig, GPTNeoForCausalLM, load_tf_weights_in_gpt_neo 22 | from transformers.utils import logging 23 | 24 | 25 | logging.set_verbosity_info() 26 | 27 | 28 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): 29 | # Initialise PyTorch model 30 | config_json = json.load(open(config_file, "r")) 31 | config = GPTNeoConfig( 32 | hidden_size=config_json["n_embd"], 33 | num_layers=config_json["n_layer"], 34 | num_heads=config_json["n_head"], 35 | attention_types=config_json["attention_types"], 36 | max_position_embeddings=config_json["n_ctx"], 37 | resid_dropout=config_json["res_dropout"], 38 | embed_dropout=config_json["embed_dropout"], 39 | attention_dropout=config_json["attn_dropout"], 40 | ) 41 | print(f"Building PyTorch model from configuration: {config}") 42 | model = GPTNeoForCausalLM(config) 43 | 44 | # Load weights from tf checkpoint 45 | load_tf_weights_in_gpt_neo(model, config, tf_checkpoint_path) 46 | 47 | # Save pytorch-model 48 | print(f"Save PyTorch model to {pytorch_dump_path}") 49 | model.save_pretrained(pytorch_dump_path) 50 | 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser() 54 | # Required parameters 55 | parser.add_argument( 56 | "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." 57 | ) 58 | parser.add_argument( 59 | "--config_file", 60 | default=None, 61 | type=str, 62 | required=True, 63 | help="The config json file corresponding to the pre-trained mesh-tf model. \n" 64 | "This specifies the model architecture.", 65 | ) 66 | parser.add_argument( 67 | "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." 68 | ) 69 | args = parser.parse_args() 70 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path) 71 | -------------------------------------------------------------------------------- /mytransformers/models/herbert/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | 19 | from typing import TYPE_CHECKING 20 | 21 | from ...file_utils import _BaseLazyModule, is_tokenizers_available 22 | 23 | 24 | _import_structure = { 25 | "tokenization_herbert": ["HerbertTokenizer"], 26 | } 27 | 28 | if is_tokenizers_available(): 29 | _import_structure["tokenization_herbert_fast"] = ["HerbertTokenizerFast"] 30 | 31 | 32 | if TYPE_CHECKING: 33 | from .tokenization_herbert import HerbertTokenizer 34 | 35 | if is_tokenizers_available(): 36 | from .tokenization_herbert_fast import HerbertTokenizerFast 37 | 38 | else: 39 | import importlib 40 | import os 41 | import sys 42 | 43 | class _LazyModule(_BaseLazyModule): 44 | """ 45 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 46 | """ 47 | 48 | __file__ = globals()["__file__"] 49 | __path__ = [os.path.dirname(__file__)] 50 | 51 | def _get_module(self, module_name: str): 52 | return importlib.import_module("." + module_name, self.__name__) 53 | 54 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 55 | -------------------------------------------------------------------------------- /mytransformers/models/herbert/tokenization_herbert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google AI Language Team Authors, Allegro.pl, Facebook Inc. and the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from ...utils import logging 17 | from ..bert.tokenization_bert import BasicTokenizer 18 | from ..xlm.tokenization_xlm import XLMTokenizer 19 | 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | VOCAB_FILES_NAMES = { 24 | "vocab_file": "vocab.json", 25 | "merges_file": "merges.txt", 26 | } 27 | 28 | PRETRAINED_VOCAB_FILES_MAP = { 29 | "vocab_file": { 30 | "allegro/herbert-base-cased": "https://huggingface.co/allegro/herbert-base-cased/resolve/main/vocab.json" 31 | }, 32 | "merges_file": { 33 | "allegro/herbert-base-cased": "https://huggingface.co/allegro/herbert-base-cased/resolve/main/merges.txt" 34 | }, 35 | } 36 | 37 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"allegro/herbert-base-cased": 514} 38 | PRETRAINED_INIT_CONFIGURATION = {} 39 | 40 | 41 | class HerbertTokenizer(XLMTokenizer): 42 | """ 43 | Construct a BPE tokenizer for HerBERT. 44 | 45 | Peculiarities: 46 | 47 | - uses BERT's pre-tokenizer: BasicTokenizer splits tokens on spaces, and also on punctuation. Each occurrence of a 48 | punctuation character will be treated separately. 49 | 50 | - Such pretokenized input is then BPE-subtokenized 51 | 52 | This tokenizer inherits from :class:`~transformers.XLMTokenizer` which contains most of the methods. Users should 53 | refer to the superclass for more information regarding methods.
54 | """ 55 | 56 | vocab_files_names = VOCAB_FILES_NAMES 57 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 58 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 59 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 60 | 61 | def __init__(self, **kwargs): 62 | 63 | kwargs["cls_token"] = "" 64 | kwargs["unk_token"] = "" 65 | kwargs["pad_token"] = "" 66 | kwargs["mask_token"] = "" 67 | kwargs["sep_token"] = "" 68 | kwargs["do_lowercase_and_remove_accent"] = False 69 | kwargs["additional_special_tokens"] = [] 70 | 71 | super().__init__(**kwargs) 72 | self.bert_pre_tokenizer = BasicTokenizer( 73 | do_lower_case=False, never_split=self.all_special_tokens, tokenize_chinese_chars=False, strip_accents=False 74 | ) 75 | 76 | def _tokenize(self, text): 77 | 78 | pre_tokens = self.bert_pre_tokenizer.tokenize(text) 79 | 80 | split_tokens = [] 81 | for token in pre_tokens: 82 | if token: 83 | split_tokens.extend([t for t in self.bpe(token).split(" ")]) 84 | 85 | return split_tokens 86 | -------------------------------------------------------------------------------- /mytransformers/models/ibert/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | from typing import TYPE_CHECKING 20 | 21 | from ...file_utils import _BaseLazyModule, is_torch_available 22 | 23 | 24 | _import_structure = { 25 | "configuration_ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig"], 26 | } 27 | 28 | if is_torch_available(): 29 | _import_structure["modeling_ibert"] = [ 30 | "IBERT_PRETRAINED_MODEL_ARCHIVE_LIST", 31 | "IBertForMaskedLM", 32 | "IBertForMultipleChoice", 33 | "IBertForQuestionAnswering", 34 | "IBertForSequenceClassification", 35 | "IBertForTokenClassification", 36 | "IBertModel", 37 | "IBertPreTrainedModel", 38 | ] 39 | 40 | if TYPE_CHECKING: 41 | from .configuration_ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig 42 | 43 | if is_torch_available(): 44 | from .modeling_ibert import ( 45 | IBERT_PRETRAINED_MODEL_ARCHIVE_LIST, 46 | IBertForMaskedLM, 47 | IBertForMultipleChoice, 48 | IBertForQuestionAnswering, 49 | IBertForSequenceClassification, 50 | IBertForTokenClassification, 51 | IBertModel, 52 | IBertPreTrainedModel, 53 | ) 54 | 55 | else: 56 | import importlib 57 | import os 58 | import sys 59 | 60 | class _LazyModule(_BaseLazyModule): 61 | """ 62 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 63 | """ 64 | 65 | __file__ = globals()["__file__"] 66 | __path__ = [os.path.dirname(__file__)] 67 | 68 | def _get_module(self, module_name: str): 69 | return importlib.import_module("." 
+ module_name, self.__name__) 70 | 71 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 72 | -------------------------------------------------------------------------------- /mytransformers/models/layoutlm/tokenization_layoutlm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Microsoft Research Asia LayoutLM Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Tokenization class for model LayoutLM.""" 16 | 17 | 18 | from ...utils import logging 19 | from ..bert.tokenization_bert import BertTokenizer 20 | 21 | 22 | logger = logging.get_logger(__name__) 23 | 24 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} 25 | 26 | PRETRAINED_VOCAB_FILES_MAP = { 27 | "vocab_file": { 28 | "microsoft/layoutlm-base-uncased": "https://huggingface.co/microsoft/layoutlm-base-uncased/resolve/main/vocab.txt", 29 | "microsoft/layoutlm-large-uncased": "https://huggingface.co/microsoft/layoutlm-large-uncased/resolve/main/vocab.txt", 30 | } 31 | } 32 | 33 | 34 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 35 | "microsoft/layoutlm-base-uncased": 512, 36 | "microsoft/layoutlm-large-uncased": 512, 37 | } 38 | 39 | 40 | PRETRAINED_INIT_CONFIGURATION = { 41 | "microsoft/layoutlm-base-uncased": {"do_lower_case": True}, 42 | "microsoft/layoutlm-large-uncased": {"do_lower_case": True}, 43 | } 44 | 45 | 46 | class LayoutLMTokenizer(BertTokenizer): 47 | r""" 48 | Constructs a LayoutLM tokenizer. 49 | 50 | :class:`~transformers.LayoutLMTokenizer is identical to :class:`~transformers.BertTokenizer` and runs end-to-end 51 | tokenization: punctuation splitting + wordpiece. 52 | 53 | Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning 54 | parameters. 55 | """ 56 | 57 | vocab_files_names = VOCAB_FILES_NAMES 58 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 59 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 60 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 61 | -------------------------------------------------------------------------------- /mytransformers/models/layoutlm/tokenization_layoutlm_fast.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Microsoft Research Asia LayoutLM Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
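# Usage sketch: both LayoutLM tokenizers (the slow one above, the fast one below) are plain BERT
# wordpiece tokenizers and handle only the text; word bounding boxes are passed to the model
# separately. The checkpoint name is taken from the maps in these files, the sentence is illustrative:
#
#   >>> from transformers import LayoutLMTokenizer
#   >>> tokenizer = LayoutLMTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
#   >>> encoding = tokenizer("Hello world")  # ordinary input_ids / token_type_ids / attention_mask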
15 | """ Tokenization class for model LayoutLM.""" 16 | 17 | 18 | from ...utils import logging 19 | from ..bert.tokenization_bert_fast import BertTokenizerFast 20 | from .tokenization_layoutlm import LayoutLMTokenizer 21 | 22 | 23 | logger = logging.get_logger(__name__) 24 | 25 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"} 26 | 27 | PRETRAINED_VOCAB_FILES_MAP = { 28 | "vocab_file": { 29 | "microsoft/layoutlm-base-uncased": "https://huggingface.co/microsoft/layoutlm-base-uncased/resolve/main/vocab.txt", 30 | "microsoft/layoutlm-large-uncased": "https://huggingface.co/microsoft/layoutlm-large-uncased/resolve/main/vocab.txt", 31 | }, 32 | "tokenizer_file": { 33 | "microsoft/layoutlm-base-uncased": "https://huggingface.co/microsoft/layoutlm-base-uncased/resolve/main/tokenizer.json", 34 | "microsoft/layoutlm-large-uncased": "https://huggingface.co/microsoft/layoutlm-large-uncased/resolve/main/tokenizer.json", 35 | }, 36 | } 37 | 38 | 39 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 40 | "microsoft/layoutlm-base-uncased": 512, 41 | "microsoft/layoutlm-large-uncased": 512, 42 | } 43 | 44 | 45 | PRETRAINED_INIT_CONFIGURATION = { 46 | "microsoft/layoutlm-base-uncased": {"do_lower_case": True}, 47 | "microsoft/layoutlm-large-uncased": {"do_lower_case": True}, 48 | } 49 | 50 | 51 | class LayoutLMTokenizerFast(BertTokenizerFast): 52 | r""" 53 | Constructs a "Fast" LayoutLMTokenizer. 54 | 55 | :class:`~transformers.LayoutLMTokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs 56 | end-to-end tokenization: punctuation splitting + wordpiece. 57 | 58 | Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning 59 | parameters. 60 | """ 61 | 62 | vocab_files_names = VOCAB_FILES_NAMES 63 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 64 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 65 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 66 | slow_tokenizer_class = LayoutLMTokenizer 67 | -------------------------------------------------------------------------------- /mytransformers/models/led/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2021 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | from typing import TYPE_CHECKING 19 | 20 | from ...file_utils import _BaseLazyModule, is_tf_available, is_tokenizers_available, is_torch_available 21 | 22 | 23 | _import_structure = { 24 | "configuration_led": ["LED_PRETRAINED_CONFIG_ARCHIVE_MAP", "LEDConfig"], 25 | "tokenization_led": ["LEDTokenizer"], 26 | } 27 | 28 | if is_tokenizers_available(): 29 | _import_structure["tokenization_led_fast"] = ["LEDTokenizerFast"] 30 | 31 | if is_torch_available(): 32 | _import_structure["modeling_led"] = [ 33 | "LED_PRETRAINED_MODEL_ARCHIVE_LIST", 34 | "LEDForConditionalGeneration", 35 | "LEDForQuestionAnswering", 36 | "LEDForSequenceClassification", 37 | "LEDModel", 38 | "LEDPreTrainedModel", 39 | ] 40 | 41 | 42 | if is_tf_available(): 43 | _import_structure["modeling_tf_led"] = ["TFLEDForConditionalGeneration", "TFLEDModel", "TFLEDPreTrainedModel"] 44 | 45 | 46 | if TYPE_CHECKING: 47 | from .configuration_led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig 48 | from .tokenization_led import LEDTokenizer 49 | 50 | if is_tokenizers_available(): 51 | from .tokenization_led_fast import LEDTokenizerFast 52 | 53 | if is_torch_available(): 54 | from .modeling_led import ( 55 | LED_PRETRAINED_MODEL_ARCHIVE_LIST, 56 | LEDForConditionalGeneration, 57 | LEDForQuestionAnswering, 58 | LEDForSequenceClassification, 59 | LEDModel, 60 | LEDPreTrainedModel, 61 | ) 62 | 63 | if is_tf_available(): 64 | from .modeling_tf_led import TFLEDForConditionalGeneration, TFLEDModel, TFLEDPreTrainedModel 65 | 66 | else: 67 | import importlib 68 | import os 69 | import sys 70 | 71 | class _LazyModule(_BaseLazyModule): 72 | """ 73 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 74 | """ 75 | 76 | __file__ = globals()["__file__"] 77 | __path__ = [os.path.dirname(__file__)] 78 | 79 | def _get_module(self, module_name: str): 80 | return importlib.import_module("." + module_name, self.__name__) 81 | 82 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 83 | -------------------------------------------------------------------------------- /mytransformers/models/led/tokenization_led.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Iz Beltagy, Matthew E. Peters, Arman Cohan and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
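# Usage sketch: the LED tokenizers in this file and the next reuse BART's tokenizer; what changes
# is the maximum model input size (16384 positions instead of BART's 1024). A minimal check,
# assuming the allenai/led-base-16384 checkpoint from the map below:
#
#   >>> from transformers import LEDTokenizer
#   >>> tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
#   >>> tokenizer.model_max_length  # 16384 for this checkpoint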
15 | """Tokenization classes for LED.""" 16 | from ...utils import logging 17 | from ..bart.tokenization_bart import BartTokenizer 18 | 19 | 20 | logger = logging.get_logger(__name__) 21 | 22 | PRETRAINED_VOCAB_FILES_MAP = { 23 | "vocab_file": { 24 | "allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/vocab.json", 25 | }, 26 | "merges_file": { 27 | "allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/merges.txt", 28 | }, 29 | "tokenizer_file": { 30 | "allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/tokenizer.json", 31 | }, 32 | } 33 | 34 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 35 | "allenai/led-base-16384": 16384, 36 | } 37 | 38 | 39 | class LEDTokenizer(BartTokenizer): 40 | """ 41 | Construct a LED tokenizer. 42 | 43 | :class:`~transformers.LEDTokenizer` is identical to :class:`~transformers.BartTokenizer` and runs end-to-end 44 | tokenization: punctuation splitting and wordpiece. 45 | 46 | Refer to superclass :class:`~transformers.BartTokenizer` for usage examples and documentation concerning 47 | parameters. 48 | """ 49 | 50 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 51 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 52 | -------------------------------------------------------------------------------- /mytransformers/models/led/tokenization_led_fast.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Iz Beltagy, Matthew E. Peters, Arman Cohan and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for LED.""" 16 | from ...utils import logging 17 | from ..bart.tokenization_bart_fast import BartTokenizerFast 18 | from .tokenization_led import LEDTokenizer 19 | 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | PRETRAINED_VOCAB_FILES_MAP = { 24 | "vocab_file": { 25 | "allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/vocab.json", 26 | }, 27 | "merges_file": { 28 | "allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/merges.txt", 29 | }, 30 | "tokenizer_file": { 31 | "allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/tokenizer.json", 32 | }, 33 | } 34 | 35 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 36 | "allenai/led-base-16384": 16384, 37 | } 38 | 39 | 40 | class LEDTokenizerFast(BartTokenizerFast): 41 | r""" 42 | Construct a "fast" LED tokenizer (backed by HuggingFace's `tokenizers` library). 43 | 44 | :class:`~transformers.LEDTokenizerFast` is identical to :class:`~transformers.BartTokenizerFast` and runs 45 | end-to-end tokenization: punctuation splitting and wordpiece. 46 | 47 | Refer to superclass :class:`~transformers.BartTokenizerFast` for usage examples and documentation concerning 48 | parameters. 
49 | """ 50 | 51 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 52 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 53 | slow_tokenizer_class = LEDTokenizer 54 | -------------------------------------------------------------------------------- /mytransformers/models/lxmert/convert_lxmert_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert LXMERT checkpoint.""" 16 | 17 | 18 | import argparse 19 | 20 | import torch 21 | 22 | from transformers import LxmertConfig, LxmertForPreTraining, load_tf_weights_in_lxmert 23 | from transformers.utils import logging 24 | 25 | 26 | logging.set_verbosity_info() 27 | 28 | 29 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): 30 | # Initialise PyTorch model 31 | config = LxmertConfig.from_json_file(config_file) 32 | print(f"Building PyTorch model from configuration: {config}") 33 | model = LxmertForPreTraining(config) 34 | 35 | # Load weights from tf checkpoint 36 | load_tf_weights_in_lxmert(model, config, tf_checkpoint_path) 37 | 38 | # Save pytorch-model 39 | print(f"Save PyTorch model to {pytorch_dump_path}") 40 | torch.save(model.state_dict(), pytorch_dump_path) 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser() 45 | # Required parameters 46 | parser.add_argument( 47 | "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." 48 | ) 49 | parser.add_argument( 50 | "--config_file", 51 | default=None, 52 | type=str, 53 | required=True, 54 | help="The config json file corresponding to the pre-trained model. \n" 55 | "This specifies the model architecture.", 56 | ) 57 | parser.add_argument( 58 | "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." 59 | ) 60 | args = parser.parse_args() 61 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path) 62 | -------------------------------------------------------------------------------- /mytransformers/models/lxmert/tokenization_lxmert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google AI Team, Stanford University and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from ..bert.tokenization_bert import BertTokenizer 17 | 18 | 19 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} 20 | 21 | PRETRAINED_VOCAB_FILES_MAP = { 22 | "vocab_file": { 23 | "unc-nlp/lxmert-base-uncased": "https://huggingface.co/unc-nlp/lxmert-base-uncased/resolve/main/vocab.txt", 24 | } 25 | } 26 | 27 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 28 | "unc-nlp/lxmert-base-uncased": 512, 29 | } 30 | 31 | PRETRAINED_INIT_CONFIGURATION = { 32 | "unc-nlp/lxmert-base-uncased": {"do_lower_case": True}, 33 | } 34 | 35 | 36 | class LxmertTokenizer(BertTokenizer): 37 | r""" 38 | Construct an LXMERT tokenizer. 39 | 40 | :class:`~transformers.LxmertTokenizer` is identical to :class:`~transformers.BertTokenizer` and runs end-to-end 41 | tokenization: punctuation splitting and wordpiece. 42 | 43 | Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning 44 | parameters. 45 | """ 46 | 47 | vocab_files_names = VOCAB_FILES_NAMES 48 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 49 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 50 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 51 | -------------------------------------------------------------------------------- /mytransformers/models/lxmert/tokenization_lxmert_fast.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google AI Team, Stanford University and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from ..bert.tokenization_bert_fast import BertTokenizerFast 17 | from .tokenization_lxmert import LxmertTokenizer 18 | 19 | 20 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"} 21 | 22 | PRETRAINED_VOCAB_FILES_MAP = { 23 | "vocab_file": { 24 | "unc-nlp/lxmert-base-uncased": "https://huggingface.co/unc-nlp/lxmert-base-uncased/resolve/main/vocab.txt", 25 | }, 26 | "tokenizer_file": { 27 | "unc-nlp/lxmert-base-uncased": "https://huggingface.co/unc-nlp/lxmert-base-uncased/resolve/main/tokenizer.json", 28 | }, 29 | } 30 | 31 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 32 | "unc-nlp/lxmert-base-uncased": 512, 33 | } 34 | 35 | PRETRAINED_INIT_CONFIGURATION = { 36 | "unc-nlp/lxmert-base-uncased": {"do_lower_case": True}, 37 | } 38 | 39 | 40 | class LxmertTokenizerFast(BertTokenizerFast): 41 | r""" 42 | Construct a "fast" LXMERT tokenizer (backed by HuggingFace's `tokenizers` library). 43 | 44 | :class:`~transformers.LxmertTokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs 45 | end-to-end tokenization: punctuation splitting and wordpiece. 46 | 47 | Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning 48 | parameters. 
49 | """ 50 | vocab_files_names = VOCAB_FILES_NAMES 51 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 52 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 53 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 54 | slow_tokenizer_class = LxmertTokenizer 55 | -------------------------------------------------------------------------------- /mytransformers/models/m2m_100/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | from typing import TYPE_CHECKING 19 | 20 | from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available 21 | 22 | 23 | _import_structure = { 24 | "configuration_m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"], 25 | "tokenization_m2m_100": ["M2M100Tokenizer"], 26 | } 27 | 28 | 29 | if is_torch_available(): 30 | _import_structure["modeling_m2m_100"] = [ 31 | "M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST", 32 | "M2M100ForConditionalGeneration", 33 | "M2M100Model", 34 | "M2M100PreTrainedModel", 35 | ] 36 | 37 | 38 | if TYPE_CHECKING: 39 | from .configuration_m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config 40 | from .tokenization_m2m_100 import M2M100Tokenizer 41 | 42 | if is_torch_available(): 43 | from .modeling_m2m_100 import ( 44 | M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST, 45 | M2M100ForConditionalGeneration, 46 | M2M100Model, 47 | M2M100PreTrainedModel, 48 | ) 49 | 50 | 51 | else: 52 | import importlib 53 | import os 54 | import sys 55 | 56 | class _LazyModule(_BaseLazyModule): 57 | """ 58 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 59 | """ 60 | 61 | __file__ = globals()["__file__"] 62 | __path__ = [os.path.dirname(__file__)] 63 | 64 | def _get_module(self, module_name: str): 65 | return importlib.import_module("." + module_name, self.__name__) 66 | 67 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 68 | -------------------------------------------------------------------------------- /mytransformers/models/marian/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 
9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | from typing import TYPE_CHECKING 19 | 20 | from ...file_utils import ( 21 | _BaseLazyModule, 22 | is_sentencepiece_available, 23 | is_tf_available, 24 | is_tokenizers_available, 25 | is_torch_available, 26 | ) 27 | 28 | 29 | _import_structure = { 30 | "configuration_marian": ["MARIAN_PRETRAINED_CONFIG_ARCHIVE_MAP", "MarianConfig"], 31 | } 32 | 33 | if is_sentencepiece_available(): 34 | _import_structure["tokenization_marian"] = ["MarianTokenizer"] 35 | 36 | if is_torch_available(): 37 | _import_structure["modeling_marian"] = [ 38 | "MARIAN_PRETRAINED_MODEL_ARCHIVE_LIST", 39 | "MarianForCausalLM", 40 | "MarianModel", 41 | "MarianMTModel", 42 | "MarianPreTrainedModel", 43 | ] 44 | 45 | if is_tf_available(): 46 | _import_structure["modeling_tf_marian"] = ["TFMarianModel", "TFMarianMTModel"] 47 | 48 | 49 | if TYPE_CHECKING: 50 | from .configuration_marian import MARIAN_PRETRAINED_CONFIG_ARCHIVE_MAP, MarianConfig 51 | 52 | if is_sentencepiece_available(): 53 | from .tokenization_marian import MarianTokenizer 54 | 55 | if is_torch_available(): 56 | from .modeling_marian import ( 57 | MARIAN_PRETRAINED_MODEL_ARCHIVE_LIST, 58 | MarianForCausalLM, 59 | MarianModel, 60 | MarianMTModel, 61 | MarianPreTrainedModel, 62 | ) 63 | 64 | if is_tf_available(): 65 | from .modeling_tf_marian import TFMarianModel, TFMarianMTModel 66 | 67 | else: 68 | import importlib 69 | import os 70 | import sys 71 | 72 | class _LazyModule(_BaseLazyModule): 73 | """ 74 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 75 | """ 76 | 77 | __file__ = globals()["__file__"] 78 | __path__ = [os.path.dirname(__file__)] 79 | 80 | def _get_module(self, module_name: str): 81 | return importlib.import_module("." + module_name, self.__name__) 82 | 83 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 84 | -------------------------------------------------------------------------------- /mytransformers/models/mmbt/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
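# Usage sketch for MMBTConfig (defined in configuration_mmbt.py below): it copies the wrapped
# transformer's configuration wholesale and only adds num_labels and modal_hidden_size. The BERT
# checkpoint and label count here are assumptions:
#
#   >>> from transformers import BertConfig, MMBTConfig
#   >>> base = BertConfig.from_pretrained("bert-base-uncased")
#   >>> config = MMBTConfig(base, num_labels=2, modal_hidden_size=2048)
#   >>> config.hidden_size == base.hidden_size  # inherited via the copied __dict__
#   True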
18 | 19 | from typing import TYPE_CHECKING 20 | 21 | from ...file_utils import _BaseLazyModule, is_torch_available 22 | 23 | 24 | _import_structure = { 25 | "configuration_mmbt": ["MMBTConfig"], 26 | } 27 | 28 | if is_torch_available(): 29 | _import_structure["modeling_mmbt"] = ["MMBTForClassification", "MMBTModel", "ModalEmbeddings"] 30 | 31 | 32 | if TYPE_CHECKING: 33 | from .configuration_mmbt import MMBTConfig 34 | 35 | if is_torch_available(): 36 | from .modeling_mmbt import MMBTForClassification, MMBTModel, ModalEmbeddings 37 | 38 | else: 39 | import importlib 40 | import os 41 | import sys 42 | 43 | class _LazyModule(_BaseLazyModule): 44 | """ 45 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 46 | """ 47 | 48 | __file__ = globals()["__file__"] 49 | __path__ = [os.path.dirname(__file__)] 50 | 51 | def _get_module(self, module_name: str): 52 | return importlib.import_module("." + module_name, self.__name__) 53 | 54 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 55 | -------------------------------------------------------------------------------- /mytransformers/models/mmbt/configuration_mmbt.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # Copyright (c) HuggingFace Inc. team. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ MMBT configuration """ 17 | 18 | from ...utils import logging 19 | 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | 24 | class MMBTConfig(object): 25 | """ 26 | This is the configuration class to store the configuration of a :class:`~transformers.MMBTModel`. It is used to 27 | instantiate a MMBT model according to the specified arguments, defining the model architecture. 28 | 29 | Args: 30 | config (:class:`~transformers.PreTrainedConfig`): 31 | Config of the underlying Transformer models. Its values are copied over to use a single config. 32 | num_labels (:obj:`int`, `optional`): 33 | Size of final Linear layer for classification. 34 | modal_hidden_size (:obj:`int`, `optional`, defaults to 2048): 35 | Embedding dimension of the non-text modality encoder. 36 | """ 37 | 38 | def __init__(self, config, num_labels=None, modal_hidden_size=2048): 39 | self.__dict__ = config.__dict__ 40 | self.modal_hidden_size = modal_hidden_size 41 | if num_labels: 42 | self.num_labels = num_labels 43 | -------------------------------------------------------------------------------- /mytransformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import argparse 16 | 17 | import torch 18 | 19 | from transformers import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert 20 | from transformers.utils import logging 21 | 22 | 23 | logging.set_verbosity_info() 24 | 25 | 26 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, mobilebert_config_file, pytorch_dump_path): 27 | # Initialise PyTorch model 28 | config = MobileBertConfig.from_json_file(mobilebert_config_file) 29 | print(f"Building PyTorch model from configuration: {config}") 30 | model = MobileBertForPreTraining(config) 31 | # Load weights from tf checkpoint 32 | model = load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path) 33 | # Save pytorch-model 34 | print(f"Save PyTorch model to {pytorch_dump_path}") 35 | torch.save(model.state_dict(), pytorch_dump_path) 36 | 37 | 38 | if __name__ == "__main__": 39 | parser = argparse.ArgumentParser() 40 | # Required parameters 41 | parser.add_argument( 42 | "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." 43 | ) 44 | parser.add_argument( 45 | "--mobilebert_config_file", 46 | default=None, 47 | type=str, 48 | required=True, 49 | help="The config json file corresponding to the pre-trained MobileBERT model. \n" 50 | "This specifies the model architecture.", 51 | ) 52 | parser.add_argument( 53 | "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." 54 | ) 55 | args = parser.parse_args() 56 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.mobilebert_config_file, args.pytorch_dump_path) 57 | -------------------------------------------------------------------------------- /mytransformers/models/mobilebert/tokenization_mobilebert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # Copyright 2020 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
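# Usage sketch for the MobileBERT tokenizers in this file and the fast variant that follows: both
# are stock BERT wordpiece tokenizers pointed at the google/mobilebert-uncased vocab from the maps
# below (the example sentence is illustrative):
#
#   >>> from transformers import MobileBertTokenizer
#   >>> tokenizer = MobileBertTokenizer.from_pretrained("google/mobilebert-uncased")
#   >>> tokens = tokenizer.tokenize("MobileBERT is a compact BERT.")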
16 | """Tokenization classes for MobileBERT.""" 17 | 18 | from ...utils import logging 19 | from ..bert.tokenization_bert import BertTokenizer 20 | 21 | 22 | logger = logging.get_logger(__name__) 23 | 24 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} 25 | 26 | PRETRAINED_VOCAB_FILES_MAP = { 27 | "vocab_file": {"mobilebert-uncased": "https://huggingface.co/google/mobilebert-uncased/resolve/main/vocab.txt"} 28 | } 29 | 30 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"mobilebert-uncased": 512} 31 | 32 | 33 | PRETRAINED_INIT_CONFIGURATION = {} 34 | 35 | 36 | class MobileBertTokenizer(BertTokenizer): 37 | r""" 38 | Construct a MobileBERT tokenizer. 39 | 40 | :class:`~transformers.MobileBertTokenizer is identical to :class:`~transformers.BertTokenizer` and runs end-to-end 41 | tokenization: punctuation splitting and wordpiece. 42 | 43 | Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning 44 | parameters. 45 | """ 46 | 47 | vocab_files_names = VOCAB_FILES_NAMES 48 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 49 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 50 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 51 | -------------------------------------------------------------------------------- /mytransformers/models/mobilebert/tokenization_mobilebert_fast.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # Copyright 2020 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """Tokenization classes for MobileBERT.""" 17 | 18 | from ...utils import logging 19 | from ..bert.tokenization_bert_fast import BertTokenizerFast 20 | from .tokenization_mobilebert import MobileBertTokenizer 21 | 22 | 23 | logger = logging.get_logger(__name__) 24 | 25 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"} 26 | 27 | PRETRAINED_VOCAB_FILES_MAP = { 28 | "vocab_file": {"mobilebert-uncased": "https://huggingface.co/google/mobilebert-uncased/resolve/main/vocab.txt"}, 29 | "tokenizer_file": { 30 | "mobilebert-uncased": "https://huggingface.co/google/mobilebert-uncased/resolve/main/tokenizer.json" 31 | }, 32 | } 33 | 34 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"mobilebert-uncased": 512} 35 | 36 | 37 | PRETRAINED_INIT_CONFIGURATION = {} 38 | 39 | 40 | class MobileBertTokenizerFast(BertTokenizerFast): 41 | r""" 42 | Construct a "fast" MobileBERT tokenizer (backed by HuggingFace's `tokenizers` library). 43 | 44 | :class:`~transformers.MobileBertTokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs 45 | end-to-end tokenization: punctuation splitting and wordpiece. 46 | 47 | Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning 48 | parameters. 
49 | """ 50 | 51 | vocab_files_names = VOCAB_FILES_NAMES 52 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 53 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 54 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 55 | slow_tokenizer_class = MobileBertTokenizer 56 | -------------------------------------------------------------------------------- /mytransformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | 18 | import argparse 19 | 20 | import torch 21 | 22 | from transformers import OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt 23 | from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME 24 | from transformers.utils import logging 25 | 26 | 27 | logging.set_verbosity_info() 28 | 29 | 30 | def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): 31 | # Construct model 32 | if openai_config_file == "": 33 | config = OpenAIGPTConfig() 34 | else: 35 | config = OpenAIGPTConfig.from_json_file(openai_config_file) 36 | model = OpenAIGPTModel(config) 37 | 38 | # Load weights from numpy 39 | load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path) 40 | 41 | # Save pytorch-model 42 | pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME 43 | pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME 44 | print(f"Save PyTorch model to {pytorch_weights_dump_path}") 45 | torch.save(model.state_dict(), pytorch_weights_dump_path) 46 | print(f"Save configuration file to {pytorch_config_dump_path}") 47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 48 | f.write(config.to_json_string()) 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | # Required parameters 54 | parser.add_argument( 55 | "--openai_checkpoint_folder_path", 56 | default=None, 57 | type=str, 58 | required=True, 59 | help="Path to the TensorFlow checkpoint path.", 60 | ) 61 | parser.add_argument( 62 | "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model." 63 | ) 64 | parser.add_argument( 65 | "--openai_config_file", 66 | default="", 67 | type=str, 68 | help="An optional config json file corresponding to the pre-trained OpenAI model. 
\n" 69 | "This specifies the model architecture.", 70 | ) 71 | args = parser.parse_args() 72 | convert_openai_checkpoint_to_pytorch( 73 | args.openai_checkpoint_folder_path, args.openai_config_file, args.pytorch_dump_folder_path 74 | ) 75 | -------------------------------------------------------------------------------- /mytransformers/models/pegasus/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | from typing import TYPE_CHECKING 19 | 20 | from ...file_utils import ( 21 | _BaseLazyModule, 22 | is_sentencepiece_available, 23 | is_tf_available, 24 | is_tokenizers_available, 25 | is_torch_available, 26 | ) 27 | 28 | 29 | _import_structure = { 30 | "configuration_pegasus": ["PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "PegasusConfig"], 31 | } 32 | 33 | if is_sentencepiece_available(): 34 | _import_structure["tokenization_pegasus"] = ["PegasusTokenizer"] 35 | 36 | if is_tokenizers_available(): 37 | _import_structure["tokenization_pegasus_fast"] = ["PegasusTokenizerFast"] 38 | 39 | if is_torch_available(): 40 | _import_structure["modeling_pegasus"] = [ 41 | "PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST", 42 | "PegasusForCausalLM", 43 | "PegasusForConditionalGeneration", 44 | "PegasusModel", 45 | "PegasusPreTrainedModel", 46 | ] 47 | 48 | if is_tf_available(): 49 | _import_structure["modeling_tf_pegasus"] = ["TFPegasusForConditionalGeneration", "TFPegasusModel"] 50 | 51 | 52 | if TYPE_CHECKING: 53 | from .configuration_pegasus import PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP, PegasusConfig 54 | 55 | if is_sentencepiece_available(): 56 | from .tokenization_pegasus import PegasusTokenizer 57 | 58 | if is_tokenizers_available(): 59 | from .tokenization_pegasus_fast import PegasusTokenizerFast 60 | 61 | if is_torch_available(): 62 | from .modeling_pegasus import ( 63 | PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST, 64 | PegasusForCausalLM, 65 | PegasusForConditionalGeneration, 66 | PegasusModel, 67 | PegasusPreTrainedModel, 68 | ) 69 | 70 | if is_tf_available(): 71 | from .modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel 72 | 73 | else: 74 | import importlib 75 | import os 76 | import sys 77 | 78 | class _LazyModule(_BaseLazyModule): 79 | """ 80 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 81 | """ 82 | 83 | __file__ = globals()["__file__"] 84 | __path__ = [os.path.dirname(__file__)] 85 | 86 | def _get_module(self, module_name: str): 87 | return importlib.import_module("." 
+ module_name, self.__name__) 88 | 89 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 90 | -------------------------------------------------------------------------------- /mytransformers/models/phobert/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | from typing import TYPE_CHECKING 20 | 21 | from ...file_utils import _BaseLazyModule 22 | 23 | 24 | _import_structure = { 25 | "tokenization_phobert": ["PhobertTokenizer"], 26 | } 27 | 28 | 29 | if TYPE_CHECKING: 30 | from .tokenization_phobert import PhobertTokenizer 31 | 32 | else: 33 | import importlib 34 | import os 35 | import sys 36 | 37 | class _LazyModule(_BaseLazyModule): 38 | """ 39 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 40 | """ 41 | 42 | __file__ = globals()["__file__"] 43 | __path__ = [os.path.dirname(__file__)] 44 | 45 | def _get_module(self, module_name: str): 46 | return importlib.import_module("." + module_name, self.__name__) 47 | 48 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 49 | -------------------------------------------------------------------------------- /mytransformers/models/prophetnet/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
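# Every model package in this tree wires up its __init__ the same way: `_import_structure`
# maps each submodule to the public names it defines, and at import time the entry in
# sys.modules is replaced by a _LazyModule, so torch/tensorflow-heavy submodules are only
# imported when one of their names is first accessed. A simplified sketch of the same idea
# using a module-level __getattr__ (PEP 562); illustrative only, not the _BaseLazyModule
# machinery actually used in these files:
import importlib

_lazy_structure = {"tokenization_prophetnet": ["ProphetNetTokenizer"]}  # entry reused from this package

# Reverse map: public name -> submodule that defines it.
_name_to_module = {name: mod for mod, names in _lazy_structure.items() for name in names}

def __getattr__(name):
    # Only when the attribute is actually requested do we import the defining submodule.
    if name in _name_to_module:
        module = importlib.import_module("." + _name_to_module[name], __name__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")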
18 | 19 | from typing import TYPE_CHECKING 20 | 21 | from ...file_utils import _BaseLazyModule, is_torch_available 22 | 23 | 24 | _import_structure = { 25 | "configuration_prophetnet": ["PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ProphetNetConfig"], 26 | "tokenization_prophetnet": ["ProphetNetTokenizer"], 27 | } 28 | 29 | if is_torch_available(): 30 | _import_structure["modeling_prophetnet"] = [ 31 | "PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST", 32 | "ProphetNetDecoder", 33 | "ProphetNetEncoder", 34 | "ProphetNetForCausalLM", 35 | "ProphetNetForConditionalGeneration", 36 | "ProphetNetModel", 37 | "ProphetNetPreTrainedModel", 38 | ] 39 | 40 | 41 | if TYPE_CHECKING: 42 | from .configuration_prophetnet import PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ProphetNetConfig 43 | from .tokenization_prophetnet import ProphetNetTokenizer 44 | 45 | if is_torch_available(): 46 | from .modeling_prophetnet import ( 47 | PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST, 48 | ProphetNetDecoder, 49 | ProphetNetEncoder, 50 | ProphetNetForCausalLM, 51 | ProphetNetForConditionalGeneration, 52 | ProphetNetModel, 53 | ProphetNetPreTrainedModel, 54 | ) 55 | 56 | else: 57 | import importlib 58 | import os 59 | import sys 60 | 61 | class _LazyModule(_BaseLazyModule): 62 | """ 63 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 64 | """ 65 | 66 | __file__ = globals()["__file__"] 67 | __path__ = [os.path.dirname(__file__)] 68 | 69 | def _get_module(self, module_name: str): 70 | return importlib.import_module("." + module_name, self.__name__) 71 | 72 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 73 | -------------------------------------------------------------------------------- /mytransformers/models/rag/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
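# These __init__ files keep two parallel views of the same exports: the `if TYPE_CHECKING:`
# branch performs real, eager imports so IDEs and static type checkers can resolve the
# names, while the runtime `else:` branch installs the lazy module so nothing heavy loads
# until it is used. Downstream code can rely on the same trick to annotate types without
# paying the import cost. A minimal sketch (the top-level import path `mytransformers` is
# an assumption; RagRetriever is one of the names exported below):
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by the type checker only; never executed at runtime.
    from mytransformers.models.rag import RagRetriever

def describe(retriever: "RagRetriever") -> str:
    # The quoted annotation keeps the import lazy until a retriever is actually passed in.
    return type(retriever).__name__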
18 | 19 | from typing import TYPE_CHECKING 20 | 21 | from ...file_utils import _BaseLazyModule, is_tf_available, is_torch_available 22 | 23 | 24 | _import_structure = { 25 | "configuration_rag": ["RagConfig"], 26 | "retrieval_rag": ["RagRetriever"], 27 | "tokenization_rag": ["RagTokenizer"], 28 | } 29 | 30 | if is_torch_available(): 31 | _import_structure["modeling_rag"] = ["RagModel", "RagSequenceForGeneration", "RagTokenForGeneration"] 32 | 33 | if is_tf_available(): 34 | _import_structure["modeling_tf_rag"] = ["TFRagModel", "TFRagSequenceForGeneration", "TFRagTokenForGeneration"] 35 | 36 | 37 | if TYPE_CHECKING: 38 | from .configuration_rag import RagConfig 39 | from .retrieval_rag import RagRetriever 40 | from .tokenization_rag import RagTokenizer 41 | 42 | if is_torch_available(): 43 | from .modeling_rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration 44 | 45 | if is_tf_available(): 46 | from .modeling_tf_rag import TFRagModel, TFRagSequenceForGeneration, TFRagTokenForGeneration 47 | 48 | else: 49 | import importlib 50 | import os 51 | import sys 52 | 53 | class _LazyModule(_BaseLazyModule): 54 | """ 55 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 56 | """ 57 | 58 | __file__ = globals()["__file__"] 59 | __path__ = [os.path.dirname(__file__)] 60 | 61 | def _get_module(self, module_name: str): 62 | return importlib.import_module("." + module_name, self.__name__) 63 | 64 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 65 | -------------------------------------------------------------------------------- /mytransformers/models/reformer/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
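# Which names a package like this one exports depends on the optional backends that are
# installed: sentencepiece gates the slow tokenizer, tokenizers gates the fast tokenizer,
# and torch gates the modeling classes (see the guards below). Callers can use the same
# helpers from file_utils to fail with a clear message instead of an opaque import error.
# A hedged sketch; the top-level import path `mytransformers` is an assumption:
from mytransformers.file_utils import is_sentencepiece_available

def load_reformer_tokenizer_class():
    # Point the user at the missing extra instead of letting the lazy import fail later.
    if not is_sentencepiece_available():
        raise ImportError("ReformerTokenizer needs the `sentencepiece` package to be installed.")
    from mytransformers.models.reformer import ReformerTokenizer
    return ReformerTokenizer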
18 | 19 | from typing import TYPE_CHECKING 20 | 21 | from ...file_utils import _BaseLazyModule, is_sentencepiece_available, is_tokenizers_available, is_torch_available 22 | 23 | 24 | _import_structure = { 25 | "configuration_reformer": ["REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "ReformerConfig"], 26 | } 27 | 28 | if is_sentencepiece_available(): 29 | _import_structure["tokenization_reformer"] = ["ReformerTokenizer"] 30 | 31 | if is_tokenizers_available(): 32 | _import_structure["tokenization_reformer_fast"] = ["ReformerTokenizerFast"] 33 | 34 | if is_torch_available(): 35 | _import_structure["modeling_reformer"] = [ 36 | "REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", 37 | "ReformerAttention", 38 | "ReformerForMaskedLM", 39 | "ReformerForQuestionAnswering", 40 | "ReformerForSequenceClassification", 41 | "ReformerLayer", 42 | "ReformerModel", 43 | "ReformerModelWithLMHead", 44 | ] 45 | 46 | 47 | if TYPE_CHECKING: 48 | from .configuration_reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig 49 | 50 | if is_sentencepiece_available(): 51 | from .tokenization_reformer import ReformerTokenizer 52 | 53 | if is_tokenizers_available(): 54 | from .tokenization_reformer_fast import ReformerTokenizerFast 55 | 56 | if is_torch_available(): 57 | from .modeling_reformer import ( 58 | REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, 59 | ReformerAttention, 60 | ReformerForMaskedLM, 61 | ReformerForQuestionAnswering, 62 | ReformerForSequenceClassification, 63 | ReformerLayer, 64 | ReformerModel, 65 | ReformerModelWithLMHead, 66 | ) 67 | 68 | else: 69 | import importlib 70 | import os 71 | import sys 72 | 73 | class _LazyModule(_BaseLazyModule): 74 | """ 75 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 76 | """ 77 | 78 | __file__ = globals()["__file__"] 79 | __path__ = [os.path.dirname(__file__)] 80 | 81 | def _get_module(self, module_name: str): 82 | return importlib.import_module("." + module_name, self.__name__) 83 | 84 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 85 | -------------------------------------------------------------------------------- /mytransformers/models/retribert/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | 19 | from typing import TYPE_CHECKING 20 | 21 | from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available 22 | 23 | 24 | _import_structure = { 25 | "configuration_retribert": ["RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RetriBertConfig"], 26 | "tokenization_retribert": ["RetriBertTokenizer"], 27 | } 28 | 29 | if is_tokenizers_available(): 30 | _import_structure["tokenization_retribert_fast"] = ["RetriBertTokenizerFast"] 31 | 32 | if is_torch_available(): 33 | _import_structure["modeling_retribert"] = [ 34 | "RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST", 35 | "RetriBertModel", 36 | "RetriBertPreTrainedModel", 37 | ] 38 | 39 | 40 | if TYPE_CHECKING: 41 | from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig 42 | from .tokenization_retribert import RetriBertTokenizer 43 | 44 | if is_tokenizers_available(): 45 | from .tokenization_retribert_fast import RetriBertTokenizerFast 46 | 47 | if is_torch_available(): 48 | from .modeling_retribert import ( 49 | RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST, 50 | RetriBertModel, 51 | RetriBertPreTrainedModel, 52 | ) 53 | 54 | else: 55 | import importlib 56 | import os 57 | import sys 58 | 59 | class _LazyModule(_BaseLazyModule): 60 | """ 61 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 62 | """ 63 | 64 | __file__ = globals()["__file__"] 65 | __path__ = [os.path.dirname(__file__)] 66 | 67 | def _get_module(self, module_name: str): 68 | return importlib.import_module("." + module_name, self.__name__) 69 | 70 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 71 | -------------------------------------------------------------------------------- /mytransformers/models/retribert/tokenization_retribert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for RetriBERT.""" 16 | 17 | from ...utils import logging 18 | from ..bert.tokenization_bert import BertTokenizer 19 | 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} 24 | 25 | PRETRAINED_VOCAB_FILES_MAP = { 26 | "vocab_file": { 27 | "yjernite/retribert-base-uncased": "https://huggingface.co/yjernite/retribert-base-uncased/resolve/main/vocab.txt", 28 | } 29 | } 30 | 31 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 32 | "yjernite/retribert-base-uncased": 512, 33 | } 34 | 35 | 36 | PRETRAINED_INIT_CONFIGURATION = { 37 | "yjernite/retribert-base-uncased": {"do_lower_case": True}, 38 | } 39 | 40 | 41 | class RetriBertTokenizer(BertTokenizer): 42 | r""" 43 | Constructs a RetriBERT tokenizer. 44 | 45 | :class:`~transformers.RetroBertTokenizer` is identical to :class:`~transformers.BertTokenizer` and runs end-to-end 46 | tokenization: punctuation splitting and wordpiece. 
47 | 48 | Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning 49 | parameters. 50 | """ 51 | 52 | vocab_files_names = VOCAB_FILES_NAMES 53 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 54 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 55 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 56 | model_input_names = ["input_ids", "attention_mask"] 57 | -------------------------------------------------------------------------------- /mytransformers/models/retribert/tokenization_retribert_fast.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for RetriBERT.""" 16 | 17 | from ...utils import logging 18 | from ..bert.tokenization_bert_fast import BertTokenizerFast 19 | from .tokenization_retribert import RetriBertTokenizer 20 | 21 | 22 | logger = logging.get_logger(__name__) 23 | 24 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"} 25 | 26 | PRETRAINED_VOCAB_FILES_MAP = { 27 | "vocab_file": { 28 | "yjernite/retribert-base-uncased": "https://huggingface.co/yjernite/retribert-base-uncased/resolve/main/vocab.txt", 29 | }, 30 | "tokenizer_file": { 31 | "yjernite/retribert-base-uncased": "https://huggingface.co/yjernite/retribert-base-uncased/resolve/main/tokenizer.json", 32 | }, 33 | } 34 | 35 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 36 | "yjernite/retribert-base-uncased": 512, 37 | } 38 | 39 | 40 | PRETRAINED_INIT_CONFIGURATION = { 41 | "yjernite/retribert-base-uncased": {"do_lower_case": True}, 42 | } 43 | 44 | 45 | class RetriBertTokenizerFast(BertTokenizerFast): 46 | r""" 47 | Construct a "fast" RetriBERT tokenizer (backed by HuggingFace's `tokenizers` library). 48 | 49 | :class:`~transformers.RetriBertTokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs 50 | end-to-end tokenization: punctuation splitting and wordpiece. 51 | 52 | Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning 53 | parameters. 54 | """ 55 | 56 | vocab_files_names = VOCAB_FILES_NAMES 57 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 58 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 59 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 60 | slow_tokenizer_class = RetriBertTokenizer 61 | model_input_names = ["input_ids", "attention_mask"] 62 | -------------------------------------------------------------------------------- /mytransformers/models/roberta/configuration_roberta.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 
4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ RoBERTa configuration """ 17 | 18 | from ...utils import logging 19 | from ..bert.configuration_bert import BertConfig 20 | 21 | 22 | logger = logging.get_logger(__name__) 23 | 24 | ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = { 25 | "roberta-base": "https://huggingface.co/roberta-base/resolve/main/config.json", 26 | "roberta-large": "https://huggingface.co/roberta-large/resolve/main/config.json", 27 | "roberta-large-mnli": "https://huggingface.co/roberta-large-mnli/resolve/main/config.json", 28 | "distilroberta-base": "https://huggingface.co/distilroberta-base/resolve/main/config.json", 29 | "roberta-base-openai-detector": "https://huggingface.co/roberta-base-openai-detector/resolve/main/config.json", 30 | "roberta-large-openai-detector": "https://huggingface.co/roberta-large-openai-detector/resolve/main/config.json", 31 | } 32 | 33 | 34 | class RobertaConfig(BertConfig): 35 | r""" 36 | This is the configuration class to store the configuration of a :class:`~transformers.RobertaModel` or a 37 | :class:`~transformers.TFRobertaModel`. It is used to instantiate a RoBERTa model according to the specified 38 | arguments, defining the model architecture. 39 | 40 | 41 | Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model 42 | outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. 43 | 44 | The :class:`~transformers.RobertaConfig` class directly inherits :class:`~transformers.BertConfig`. It reuses the 45 | same defaults. Please check the parent class for more information. 46 | 47 | Examples:: 48 | 49 | >>> from transformers import RobertaConfig, RobertaModel 50 | 51 | >>> # Initializing a RoBERTa configuration 52 | >>> configuration = RobertaConfig() 53 | 54 | >>> # Initializing a model from the configuration 55 | >>> model = RobertaModel(configuration) 56 | 57 | >>> # Accessing the model configuration 58 | >>> configuration = model.config 59 | """ 60 | model_type = "roberta" 61 | 62 | def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2, **kwargs): 63 | """Constructs RobertaConfig.""" 64 | super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) 65 | -------------------------------------------------------------------------------- /mytransformers/models/speech_to_text/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2021 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 
9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | from typing import TYPE_CHECKING 19 | 20 | from ...file_utils import _BaseLazyModule, is_sentencepiece_available, is_torch_available 21 | 22 | 23 | _import_structure = { 24 | "configuration_speech_to_text": [ 25 | "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", 26 | "Speech2TextConfig", 27 | ], 28 | "feature_extraction_speech_to_text": ["Speech2TextFeatureExtractor"], 29 | } 30 | 31 | if is_sentencepiece_available(): 32 | _import_structure["processing_speech_to_text"] = ["Speech2TextProcessor"] 33 | _import_structure["tokenization_speech_to_text"] = ["Speech2TextTokenizer"] 34 | 35 | if is_torch_available(): 36 | _import_structure["modeling_speech_to_text"] = [ 37 | "SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST", 38 | "Speech2TextForConditionalGeneration", 39 | "Speech2TextModel", 40 | "Speech2TextPreTrainedModel", 41 | ] 42 | 43 | 44 | if TYPE_CHECKING: 45 | from .configuration_speech_to_text import SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, Speech2TextConfig 46 | from .feature_extraction_speech_to_text import Speech2TextFeatureExtractor 47 | 48 | if is_sentencepiece_available(): 49 | from .processing_speech_to_text import Speech2TextProcessor 50 | from .tokenization_speech_to_text import Speech2TextTokenizer 51 | 52 | if is_torch_available(): 53 | from .modeling_speech_to_text import ( 54 | SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST, 55 | Speech2TextForConditionalGeneration, 56 | Speech2TextModel, 57 | Speech2TextPreTrainedModel, 58 | ) 59 | 60 | else: 61 | import importlib 62 | import os 63 | import sys 64 | 65 | class _LazyModule(_BaseLazyModule): 66 | """ 67 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 68 | """ 69 | 70 | __file__ = globals()["__file__"] 71 | __path__ = [os.path.dirname(__file__)] 72 | 73 | def _get_module(self, module_name: str): 74 | return importlib.import_module("." + module_name, self.__name__) 75 | 76 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 77 | -------------------------------------------------------------------------------- /mytransformers/models/squeezebert/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | 19 | from typing import TYPE_CHECKING 20 | 21 | from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available 22 | 23 | 24 | _import_structure = { 25 | "configuration_squeezebert": ["SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "SqueezeBertConfig"], 26 | "tokenization_squeezebert": ["SqueezeBertTokenizer"], 27 | } 28 | 29 | if is_tokenizers_available(): 30 | _import_structure["tokenization_squeezebert_fast"] = ["SqueezeBertTokenizerFast"] 31 | 32 | if is_torch_available(): 33 | _import_structure["modeling_squeezebert"] = [ 34 | "SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST", 35 | "SqueezeBertForMaskedLM", 36 | "SqueezeBertForMultipleChoice", 37 | "SqueezeBertForQuestionAnswering", 38 | "SqueezeBertForSequenceClassification", 39 | "SqueezeBertForTokenClassification", 40 | "SqueezeBertModel", 41 | "SqueezeBertModule", 42 | "SqueezeBertPreTrainedModel", 43 | ] 44 | 45 | 46 | if TYPE_CHECKING: 47 | from .configuration_squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig 48 | from .tokenization_squeezebert import SqueezeBertTokenizer 49 | 50 | if is_tokenizers_available(): 51 | from .tokenization_squeezebert_fast import SqueezeBertTokenizerFast 52 | 53 | if is_torch_available(): 54 | from .modeling_squeezebert import ( 55 | SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST, 56 | SqueezeBertForMaskedLM, 57 | SqueezeBertForMultipleChoice, 58 | SqueezeBertForQuestionAnswering, 59 | SqueezeBertForSequenceClassification, 60 | SqueezeBertForTokenClassification, 61 | SqueezeBertModel, 62 | SqueezeBertModule, 63 | SqueezeBertPreTrainedModel, 64 | ) 65 | 66 | else: 67 | import importlib 68 | import os 69 | import sys 70 | 71 | class _LazyModule(_BaseLazyModule): 72 | """ 73 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 74 | """ 75 | 76 | __file__ = globals()["__file__"] 77 | __path__ = [os.path.dirname(__file__)] 78 | 79 | def _get_module(self, module_name: str): 80 | return importlib.import_module("." + module_name, self.__name__) 81 | 82 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 83 | -------------------------------------------------------------------------------- /mytransformers/models/squeezebert/tokenization_squeezebert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The SqueezeBert authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
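# The tokenizer defined below does not change the algorithm at all: it subclasses
# BertTokenizer and only re-points the class-level tables (vocab file names, checkpoint
# URLs, max input sizes, per-checkpoint init kwargs), so tokenization is plain BERT
# WordPiece. A hedged usage sketch; the checkpoint id comes from the
# PRETRAINED_VOCAB_FILES_MAP below, downloading it needs network access, and the
# `mytransformers` import path is an assumption:
from mytransformers.models.squeezebert import SqueezeBertTokenizer

tokenizer = SqueezeBertTokenizer.from_pretrained("squeezebert/squeezebert-uncased")
tokens = tokenizer.tokenize("SqueezeBERT keeps the BERT vocabulary.")  # WordPiece pieces
ids = tokenizer.encode("SqueezeBERT keeps the BERT vocabulary.")       # adds [CLS]/[SEP]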
15 | """Tokenization classes for SqueezeBERT.""" 16 | 17 | from ...utils import logging 18 | from ..bert.tokenization_bert import BertTokenizer 19 | 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} 24 | 25 | PRETRAINED_VOCAB_FILES_MAP = { 26 | "vocab_file": { 27 | "squeezebert/squeezebert-uncased": "https://huggingface.co/squeezebert/squeezebert-uncased/resolve/main/vocab.txt", 28 | "squeezebert/squeezebert-mnli": "https://huggingface.co/squeezebert/squeezebert-mnli/resolve/main/vocab.txt", 29 | "squeezebert/squeezebert-mnli-headless": "https://huggingface.co/squeezebert/squeezebert-mnli-headless/resolve/main/vocab.txt", 30 | } 31 | } 32 | 33 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 34 | "squeezebert/squeezebert-uncased": 512, 35 | "squeezebert/squeezebert-mnli": 512, 36 | "squeezebert/squeezebert-mnli-headless": 512, 37 | } 38 | 39 | 40 | PRETRAINED_INIT_CONFIGURATION = { 41 | "squeezebert/squeezebert-uncased": {"do_lower_case": True}, 42 | "squeezebert/squeezebert-mnli": {"do_lower_case": True}, 43 | "squeezebert/squeezebert-mnli-headless": {"do_lower_case": True}, 44 | } 45 | 46 | 47 | class SqueezeBertTokenizer(BertTokenizer): 48 | r""" 49 | Constructs a SqueezeBert tokenizer. 50 | 51 | :class:`~transformers.SqueezeBertTokenizer is identical to :class:`~transformers.BertTokenizer` and runs end-to-end 52 | tokenization: punctuation splitting + wordpiece. 53 | 54 | Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning 55 | parameters. 56 | """ 57 | 58 | vocab_files_names = VOCAB_FILES_NAMES 59 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 60 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 61 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 62 | -------------------------------------------------------------------------------- /mytransformers/models/squeezebert/tokenization_squeezebert_fast.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The SqueezeBert authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes for SqueezeBERT.""" 16 | 17 | from ...utils import logging 18 | from ..bert.tokenization_bert_fast import BertTokenizerFast 19 | from .tokenization_squeezebert import SqueezeBertTokenizer 20 | 21 | 22 | logger = logging.get_logger(__name__) 23 | 24 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"} 25 | 26 | PRETRAINED_VOCAB_FILES_MAP = { 27 | "vocab_file": { 28 | "squeezebert/squeezebert-uncased": "https://huggingface.co/squeezebert/squeezebert-uncased/resolve/main/vocab.txt", 29 | "squeezebert/squeezebert-mnli": "https://huggingface.co/squeezebert/squeezebert-mnli/resolve/main/vocab.txt", 30 | "squeezebert/squeezebert-mnli-headless": "https://huggingface.co/squeezebert/squeezebert-mnli-headless/resolve/main/vocab.txt", 31 | }, 32 | "tokenizer_file": { 33 | "squeezebert/squeezebert-uncased": "https://huggingface.co/squeezebert/squeezebert-uncased/resolve/main/tokenizer.json", 34 | "squeezebert/squeezebert-mnli": "https://huggingface.co/squeezebert/squeezebert-mnli/resolve/main/tokenizer.json", 35 | "squeezebert/squeezebert-mnli-headless": "https://huggingface.co/squeezebert/squeezebert-mnli-headless/resolve/main/tokenizer.json", 36 | }, 37 | } 38 | 39 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 40 | "squeezebert/squeezebert-uncased": 512, 41 | "squeezebert/squeezebert-mnli": 512, 42 | "squeezebert/squeezebert-mnli-headless": 512, 43 | } 44 | 45 | 46 | PRETRAINED_INIT_CONFIGURATION = { 47 | "squeezebert/squeezebert-uncased": {"do_lower_case": True}, 48 | "squeezebert/squeezebert-mnli": {"do_lower_case": True}, 49 | "squeezebert/squeezebert-mnli-headless": {"do_lower_case": True}, 50 | } 51 | 52 | 53 | class SqueezeBertTokenizerFast(BertTokenizerFast): 54 | r""" 55 | Constructs a "Fast" SqueezeBert tokenizer (backed by HuggingFace's `tokenizers` library). 56 | 57 | :class:`~transformers.SqueezeBertTokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs 58 | end-to-end tokenization: punctuation splitting + wordpiece. 59 | 60 | Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning 61 | parameters. 62 | """ 63 | 64 | vocab_files_names = VOCAB_FILES_NAMES 65 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 66 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 67 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 68 | slow_tokenizer_class = SqueezeBertTokenizer 69 | -------------------------------------------------------------------------------- /mytransformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The T5 authors and HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert T5 checkpoint.""" 16 | 17 | 18 | import argparse 19 | 20 | from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5 21 | from transformers.utils import logging 22 | 23 | 24 | logging.set_verbosity_info() 25 | 26 | 27 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): 28 | # Initialise PyTorch model 29 | config = T5Config.from_json_file(config_file) 30 | print(f"Building PyTorch model from configuration: {config}") 31 | model = T5ForConditionalGeneration(config) 32 | 33 | # Load weights from tf checkpoint 34 | load_tf_weights_in_t5(model, config, tf_checkpoint_path) 35 | 36 | # Save pytorch-model 37 | print(f"Save PyTorch model to {pytorch_dump_path}") 38 | model.save_pretrained(pytorch_dump_path) 39 | 40 | 41 | if __name__ == "__main__": 42 | parser = argparse.ArgumentParser() 43 | # Required parameters 44 | parser.add_argument( 45 | "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." 46 | ) 47 | parser.add_argument( 48 | "--config_file", 49 | default=None, 50 | type=str, 51 | required=True, 52 | help="The config json file corresponding to the pre-trained T5 model. \n" 53 | "This specifies the model architecture.", 54 | ) 55 | parser.add_argument( 56 | "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." 57 | ) 58 | args = parser.parse_args() 59 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path) 60 | -------------------------------------------------------------------------------- /mytransformers/models/tapas/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | 19 | from typing import TYPE_CHECKING 20 | 21 | from ...file_utils import _BaseLazyModule, is_torch_available 22 | 23 | 24 | _import_structure = { 25 | "configuration_tapas": ["TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP", "TapasConfig"], 26 | "tokenization_tapas": ["TapasTokenizer"], 27 | } 28 | 29 | if is_torch_available(): 30 | _import_structure["modeling_tapas"] = [ 31 | "TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST", 32 | "TapasForMaskedLM", 33 | "TapasForQuestionAnswering", 34 | "TapasForSequenceClassification", 35 | "TapasModel", 36 | ] 37 | 38 | 39 | if TYPE_CHECKING: 40 | from .configuration_tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig 41 | from .tokenization_tapas import TapasTokenizer 42 | 43 | if is_torch_available(): 44 | from .modeling_tapas import ( 45 | TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST, 46 | TapasForMaskedLM, 47 | TapasForQuestionAnswering, 48 | TapasForSequenceClassification, 49 | TapasModel, 50 | ) 51 | 52 | else: 53 | import importlib 54 | import os 55 | import sys 56 | 57 | class _LazyModule(_BaseLazyModule): 58 | """ 59 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 60 | """ 61 | 62 | __file__ = globals()["__file__"] 63 | __path__ = [os.path.dirname(__file__)] 64 | 65 | def _get_module(self, module_name: str): 66 | return importlib.import_module("." + module_name, self.__name__) 67 | 68 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 69 | -------------------------------------------------------------------------------- /mytransformers/models/vit/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2021 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
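# For ViT the optional dependency is vision (PIL-based image handling) rather than a
# tokenizer backend, so the feature extractor is gated on is_vision_available() below.
# A hedged end-to-end sketch; the checkpoint id is an assumption (not taken from this
# file), and torch, PIL and network access are required:
from PIL import Image

from mytransformers.models.vit import ViTFeatureExtractor, ViTForImageClassification

feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

image = Image.open("example.jpg")  # placeholder image path
inputs = feature_extractor(images=image, return_tensors="pt")  # -> pixel_values tensor
logits = model(**inputs).logits    # one score per image class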
18 | from typing import TYPE_CHECKING 19 | 20 | from ...file_utils import _BaseLazyModule, is_torch_available, is_vision_available 21 | 22 | 23 | _import_structure = { 24 | "configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"], 25 | } 26 | 27 | if is_vision_available(): 28 | _import_structure["feature_extraction_vit"] = ["ViTFeatureExtractor"] 29 | 30 | if is_torch_available(): 31 | _import_structure["modeling_vit"] = [ 32 | "VIT_PRETRAINED_MODEL_ARCHIVE_LIST", 33 | "ViTForImageClassification", 34 | "ViTModel", 35 | "ViTPreTrainedModel", 36 | ] 37 | 38 | 39 | if TYPE_CHECKING: 40 | from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig 41 | 42 | if is_vision_available(): 43 | from .feature_extraction_vit import ViTFeatureExtractor 44 | 45 | if is_torch_available(): 46 | from .modeling_vit import ( 47 | VIT_PRETRAINED_MODEL_ARCHIVE_LIST, 48 | ViTForImageClassification, 49 | ViTModel, 50 | ViTPreTrainedModel, 51 | ) 52 | 53 | 54 | else: 55 | import importlib 56 | import os 57 | import sys 58 | 59 | class _LazyModule(_BaseLazyModule): 60 | """ 61 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 62 | """ 63 | 64 | __file__ = globals()["__file__"] 65 | __path__ = [os.path.dirname(__file__)] 66 | 67 | def _get_module(self, module_name: str): 68 | return importlib.import_module("." + module_name, self.__name__) 69 | 70 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 71 | -------------------------------------------------------------------------------- /mytransformers/models/wav2vec2/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2021 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
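# Wav2Vec2 splits preprocessing across two objects: a feature extractor that turns raw
# waveforms into padded float `input_values`, and a CTC tokenizer that maps predicted ids
# back to text; Wav2Vec2Processor (exported below) wraps both behind a single interface.
# A hedged sketch; the checkpoint id is an assumption, and torch plus network access are
# required:
import torch

from mytransformers.models.wav2vec2 import Wav2Vec2ForCTC, Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

waveform = [0.0] * 16000  # one second of silence at 16 kHz, standing in for real audio
inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
predicted_ids = torch.argmax(logits, dim=-1)
print(processor.batch_decode(predicted_ids))  # CTC-collapsed transcription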
18 | from typing import TYPE_CHECKING 19 | 20 | from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available 21 | 22 | 23 | _import_structure = { 24 | "configuration_wav2vec2": ["WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Wav2Vec2Config"], 25 | "feature_extraction_wav2vec2": ["Wav2Vec2FeatureExtractor"], 26 | "processing_wav2vec2": ["Wav2Vec2Processor"], 27 | "tokenization_wav2vec2": ["Wav2Vec2CTCTokenizer", "Wav2Vec2Tokenizer"], 28 | } 29 | 30 | if is_torch_available(): 31 | _import_structure["modeling_wav2vec2"] = [ 32 | "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST", 33 | "Wav2Vec2ForCTC", 34 | "Wav2Vec2ForMaskedLM", 35 | "Wav2Vec2Model", 36 | "Wav2Vec2PreTrainedModel", 37 | ] 38 | 39 | 40 | if TYPE_CHECKING: 41 | from .configuration_wav2vec2 import WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config 42 | from .feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor 43 | from .processing_wav2vec2 import Wav2Vec2Processor 44 | from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2Tokenizer 45 | 46 | if is_torch_available(): 47 | from .modeling_wav2vec2 import ( 48 | WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST, 49 | Wav2Vec2ForCTC, 50 | Wav2Vec2ForMaskedLM, 51 | Wav2Vec2Model, 52 | Wav2Vec2PreTrainedModel, 53 | ) 54 | 55 | 56 | else: 57 | import importlib 58 | import os 59 | import sys 60 | 61 | class _LazyModule(_BaseLazyModule): 62 | """ 63 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 64 | """ 65 | 66 | __file__ = globals()["__file__"] 67 | __path__ = [os.path.dirname(__file__)] 68 | 69 | def _get_module(self, module_name: str): 70 | return importlib.import_module("." + module_name, self.__name__) 71 | 72 | sys.modules[__name__] = _LazyModule(__name__, _import_structure) 73 | -------------------------------------------------------------------------------- /mytransformers/models/xlm/convert_xlm_original_pytorch_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
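# Unlike the TensorFlow conversion scripts, the one below starts from the official XLM
# PyTorch dump: it prefixes the base-model weights with "transformer.", drops tensors from
# the saved params so the config becomes JSON-serializable, remaps the BPE vocabulary, and
# writes the weights, config.json and vocab file into the output folder. A hedged example
# invocation with placeholder paths:
#
#   python convert_xlm_original_pytorch_checkpoint_to_pytorch.py \
#       --xlm_checkpoint_path /path/to/xlm/checkpoint.pth \
#       --pytorch_dump_folder_path /path/to/pytorch_output_dir
#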
15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | 18 | import argparse 19 | import json 20 | 21 | import numpy 22 | import torch 23 | 24 | from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME 25 | from transformers.models.xlm.tokenization_xlm import VOCAB_FILES_NAMES 26 | from transformers.utils import logging 27 | 28 | 29 | logging.set_verbosity_info() 30 | 31 | 32 | def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path): 33 | # Load checkpoint 34 | chkpt = torch.load(xlm_checkpoint_path, map_location="cpu") 35 | 36 | state_dict = chkpt["model"] 37 | 38 | # We have the base model one level deeper than the original XLM repository 39 | two_levels_state_dict = {} 40 | for k, v in state_dict.items(): 41 | if "pred_layer" in k: 42 | two_levels_state_dict[k] = v 43 | else: 44 | two_levels_state_dict["transformer." + k] = v 45 | 46 | config = chkpt["params"] 47 | config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.FloatTensor, numpy.ndarray))) 48 | 49 | vocab = chkpt["dico_word2id"] 50 | vocab = dict((s + "" if s.find("@@") == -1 and i > 13 else s.replace("@@", ""), i) for s, i in vocab.items()) 51 | 52 | # Save pytorch-model 53 | pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME 54 | pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME 55 | pytorch_vocab_dump_path = pytorch_dump_folder_path + "/" + VOCAB_FILES_NAMES["vocab_file"] 56 | 57 | print(f"Save PyTorch model to {pytorch_weights_dump_path}") 58 | torch.save(two_levels_state_dict, pytorch_weights_dump_path) 59 | 60 | print(f"Save configuration file to {pytorch_config_dump_path}") 61 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 62 | f.write(json.dumps(config, indent=2) + "\n") 63 | 64 | print(f"Save vocab file to {pytorch_config_dump_path}") 65 | with open(pytorch_vocab_dump_path, "w", encoding="utf-8") as f: 66 | f.write(json.dumps(vocab, indent=2) + "\n") 67 | 68 | 69 | if __name__ == "__main__": 70 | parser = argparse.ArgumentParser() 71 | # Required parameters 72 | parser.add_argument( 73 | "--xlm_checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump." 74 | ) 75 | parser.add_argument( 76 | "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model." 77 | ) 78 | args = parser.parse_args() 79 | convert_xlm_checkpoint_to_pytorch(args.xlm_checkpoint_path, args.pytorch_dump_folder_path) 80 | -------------------------------------------------------------------------------- /mytransformers/models/xlm_prophetnet/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2020 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | from ...file_utils import is_sentencepiece_available, is_torch_available 20 | from .configuration_xlm_prophetnet import XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMProphetNetConfig 21 | 22 | 23 | if is_sentencepiece_available(): 24 | from .tokenization_xlm_prophetnet import XLMProphetNetTokenizer 25 | 26 | if is_torch_available(): 27 | from .modeling_xlm_prophetnet import ( 28 | XLM_PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST, 29 | XLMProphetNetDecoder, 30 | XLMProphetNetEncoder, 31 | XLMProphetNetForCausalLM, 32 | XLMProphetNetForConditionalGeneration, 33 | XLMProphetNetModel, 34 | ) 35 | -------------------------------------------------------------------------------- /mytransformers/models/xlm_prophetnet/configuration_xlm_prophetnet.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Microsoft Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ XLM-ProphetNet model configuration """ 16 | 17 | 18 | from ...utils import logging 19 | from ..prophetnet.configuration_prophetnet import ProphetNetConfig 20 | 21 | 22 | logger = logging.get_logger(__name__) 23 | 24 | XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP = { 25 | "microsoft/xprophetnet-large-wiki100-cased": "https://huggingface.co/microsoft/xprophetnet-large-wiki100-cased/resolve/main/config.json", 26 | } 27 | 28 | 29 | class XLMProphetNetConfig(ProphetNetConfig): 30 | """ 31 | This class overrides :class:`~transformers.ProphetNetConfig`. Please check the superclass for the appropriate 32 | documentation alongside usage examples. 33 | """ 34 | 35 | model_type = "xlm-prophetnet" 36 | -------------------------------------------------------------------------------- /mytransformers/models/xlm_roberta/configuration_xlm_roberta.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
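# As with XLMProphetNetConfig above, the configuration below is a thin subclass: it
# inherits every field and default from RobertaConfig (and hence BertConfig) and only
# overrides `model_type`, the key the Auto* classes use to route a checkpoint's
# config.json to the right model family. A minimal illustrative sketch; the vocab size is
# an arbitrary example value and the `mytransformers` import path is an assumption:
from mytransformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig

config = XLMRobertaConfig(vocab_size=250002)  # any RoBERTa/BERT config field is accepted
print(config.model_type)    # "xlm-roberta"
print(config.pad_token_id)  # 1, the RoBERTa default inherited from RobertaConfig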
16 | """ XLM-RoBERTa configuration """ 17 | 18 | from ...utils import logging 19 | from ..roberta.configuration_roberta import RobertaConfig 20 | 21 | 22 | logger = logging.get_logger(__name__) 23 | 24 | XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = { 25 | "xlm-roberta-base": "https://huggingface.co/xlm-roberta-base/resolve/main/config.json", 26 | "xlm-roberta-large": "https://huggingface.co/xlm-roberta-large/resolve/main/config.json", 27 | "xlm-roberta-large-finetuned-conll02-dutch": "https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/config.json", 28 | "xlm-roberta-large-finetuned-conll02-spanish": "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/config.json", 29 | "xlm-roberta-large-finetuned-conll03-english": "https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/config.json", 30 | "xlm-roberta-large-finetuned-conll03-german": "https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/config.json", 31 | } 32 | 33 | 34 | class XLMRobertaConfig(RobertaConfig): 35 | """ 36 | This class overrides :class:`~transformers.RobertaConfig`. Please check the superclass for the appropriate 37 | documentation alongside usage examples. 38 | """ 39 | 40 | model_type = "xlm-roberta" 41 | -------------------------------------------------------------------------------- /mytransformers/sagemaker/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2021 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | from .trainer_sm import SageMakerTrainer 20 | from .training_args_sm import SageMakerTrainingArguments, is_sagemaker_dp_enabled 21 | -------------------------------------------------------------------------------- /mytransformers/training_args_seq2seq.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import logging 16 | from dataclasses import dataclass, field 17 | 18 | from .file_utils import add_start_docstrings 19 | from .training_args import TrainingArguments 20 | 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | @dataclass 26 | @add_start_docstrings(TrainingArguments.__doc__) 27 | class Seq2SeqTrainingArguments(TrainingArguments): 28 | """ 29 | sortish_sampler (:obj:`bool`, `optional`, defaults to :obj:`False`): 30 | Whether to use a `sortish sampler` or not. Only possible if the underlying datasets are `Seq2SeqDataset` for 31 | now but will become generally available in the near future. 32 | 33 | It sorts the inputs according to lengths in order to minimize the padding size, with a bit of randomness for 34 | the training set. 35 | predict_with_generate (:obj:`bool`, `optional`, defaults to :obj:`False`): 36 | Whether to use generate to calculate generative metrics (ROUGE, BLEU). 37 | """ 38 | 39 | sortish_sampler: bool = field(default=False, metadata={"help": "Whether to use SortishSampler or not."}) 40 | predict_with_generate: bool = field( 41 | default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."} 42 | ) 43 | -------------------------------------------------------------------------------- /mytransformers/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2021 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | from packaging import version 18 | 19 | from .. import __hf_version__ 20 | 21 | 22 | def check_min_version(min_version): 23 | if version.parse(__hf_version__) < version.parse(min_version): 24 | error_message = ( 25 | f"This example requires a minimum underlying HuggingFace Transformers version of {min_version}," 26 | ) 27 | error_message += f" but the version found is {__hf_version__}.\n" 28 | raise ImportError(error_message) 29 | -------------------------------------------------------------------------------- /mytransformers/utils/dummy_vision_objects.py: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by the command `make fix-copies`, do not edit. 2 | from ..file_utils import requires_vision 3 | 4 | 5 | class ImageFeatureExtractionMixin: 6 | def __init__(self, *args, **kwargs): 7 | requires_vision(self) 8 | 9 | 10 | class ViTFeatureExtractor: 11 | def __init__(self, *args, **kwargs): 12 | requires_vision(self) 13 | -------------------------------------------------------------------------------- /mytransformers/utils/model_parallel_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Team. All rights reserved. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from math import ceil 17 | 18 | 19 | def assert_device_map(device_map, num_blocks): 20 | blocks = list(range(0, num_blocks)) 21 | 22 | device_map_blocks = [item for sublist in list(device_map.values()) for item in sublist] 23 | 24 | # Duplicate check 25 | duplicate_blocks = [] 26 | for i in device_map_blocks: 27 | if device_map_blocks.count(i) > 1 and i not in duplicate_blocks: 28 | duplicate_blocks.append(i) 29 | # Missing blocks 30 | missing_blocks = [i for i in blocks if i not in device_map_blocks] 31 | extra_blocks = [i for i in device_map_blocks if i not in blocks] 32 | 33 | assert len(duplicate_blocks) == 0, ( 34 | "Duplicate attention blocks specified in device_map. Attention blocks must be specified to one device. These " 35 | "attention blocks were specified more than once: " + str(duplicate_blocks) 36 | ) 37 | assert len(missing_blocks) == 0, ( 38 | "There are attention blocks for this model that are not specified in the device_map. Add these attention " 39 | "blocks to a device on the device_map: " + str(missing_blocks) 40 | ) 41 | assert ( 42 | len(extra_blocks) == 0 43 | ), "The device_map contains more attention blocks than this model has. Remove these from the device_map:" + str( 44 | extra_blocks 45 | ) 46 | 47 | 48 | def get_device_map(n_layers, devices): 49 | """Returns a dictionary of layers distributed evenly across all devices.""" 50 | layers = list(range(n_layers)) 51 | n_blocks = int(ceil(n_layers / len(devices))) 52 | layers_list = list(layers[i : i + n_blocks] for i in range(0, n_layers, n_blocks)) 53 | 54 | return dict(zip(devices, layers_list)) 55 | -------------------------------------------------------------------------------- /mytransformers/utils/modeling_auto_mapping.py: -------------------------------------------------------------------------------- 1 | # THIS FILE HAS BEEN AUTOGENERATED. To update: 2 | # 1. modify: models/auto/modeling_auto.py 3 | # 2. 
run: python utils/class_mapping_update.py 4 | from collections import OrderedDict 5 | 6 | 7 | MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( 8 | [ 9 | ("BigBirdConfig", "BigBirdForQuestionAnswering"), 10 | ("ConvBertConfig", "ConvBertForQuestionAnswering"), 11 | ("LEDConfig", "LEDForQuestionAnswering"), 12 | ("DistilBertConfig", "DistilBertForQuestionAnswering"), 13 | ("AlbertConfig", "AlbertForQuestionAnswering"), 14 | ("CamembertConfig", "CamembertForQuestionAnswering"), 15 | ("BartConfig", "BartForQuestionAnswering"), 16 | ("MBartConfig", "MBartForQuestionAnswering"), 17 | ("LongformerConfig", "LongformerForQuestionAnswering"), 18 | ("XLMRobertaConfig", "XLMRobertaForQuestionAnswering"), 19 | ("RobertaConfig", "RobertaForQuestionAnswering"), 20 | ("SqueezeBertConfig", "SqueezeBertForQuestionAnswering"), 21 | ("BertConfig", "BertForQuestionAnswering"), 22 | ("XLNetConfig", "XLNetForQuestionAnsweringSimple"), 23 | ("FlaubertConfig", "FlaubertForQuestionAnsweringSimple"), 24 | ("MobileBertConfig", "MobileBertForQuestionAnswering"), 25 | ("XLMConfig", "XLMForQuestionAnsweringSimple"), 26 | ("ElectraConfig", "ElectraForQuestionAnswering"), 27 | ("ReformerConfig", "ReformerForQuestionAnswering"), 28 | ("FunnelConfig", "FunnelForQuestionAnswering"), 29 | ("LxmertConfig", "LxmertForQuestionAnswering"), 30 | ("MPNetConfig", "MPNetForQuestionAnswering"), 31 | ("DebertaConfig", "DebertaForQuestionAnswering"), 32 | ("DebertaV2Config", "DebertaV2ForQuestionAnswering"), 33 | ("IBertConfig", "IBertForQuestionAnswering"), 34 | ] 35 | ) 36 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | from utils_lll import create_dataloader, QADataset 2 | from settings_lll import args, TASK_DICT 3 | import pickle as pkl 4 | import re 5 | import os 6 | import json 7 | from multiprocessing import Pool 8 | 9 | 10 | def serialize_data(redo=True): 11 | print("serializing data ...") 12 | for t in ["train", "eval", "test"]: 13 | for task in TASK_DICT.keys(): 14 | data_path = TASK_DICT[task][t] 15 | pkl_path = re.sub("json","pkl", data_path) 16 | if os.path.exists(pkl_path) and not redo: 17 | continue 18 | dataset = QADataset(data_path, t) 19 | with open(pkl_path, "wb") as f: 20 | pkl.dump(dataset,f) 21 | print("data serialized!") 22 | 23 | 24 | def dump_data_attrs(task): 25 | attrs = {task:{"train":{}, "eval":{}, "test":{}}} 26 | for t in ["train", "eval", "test"]: 27 | print(task,t) 28 | data_path = TASK_DICT[task][t] 29 | pkl_path = re.sub("json","pkl", data_path) 30 | with open(pkl_path, "rb") as f: 31 | dataset = pkl.load(f) 32 | attrs[task][t] = {"data_size": len(dataset), 33 | "max_a_len": dataset.max_a_len, 34 | } 35 | return attrs 36 | 37 | 38 | def parallel_dump_data_attrs(tasks=TASK_DICT.keys()): 39 | print("creating data_attrs.json ...") 40 | attr_dict = {} 41 | with Pool(args.n_workers) as pool: 42 | attrs = pool.map(dump_data_attrs, tasks) 43 | for a in attrs: 44 | attr_dict.update(a) 45 | with open("data_attrs.json","w") as f: 46 | json.dump(attr_dict,f) 47 | print("data_attrs.json created!") 48 | 49 | 50 | serialize_data() 51 | parallel_dump_data_attrs() 52 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | GPUtil==1.4.0 2 | quadprog==0.1.8 3 | sacrebleu==1.5.0 4 | nltk==3.5 5 | jsonlines==2.0.0 6 | adapter_transformers==2.0.0 7 | 
torch==1.9.0
--------------------------------------------------------------------------------
/scheduler.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Learning rate annealing schedules (linear warmup followed by linear, cosine, or constant decay)."""
16 | 
17 | import torch
18 | from torch.optim.lr_scheduler import _LRScheduler
19 | import math
20 | 
21 | class AnnealingLR(_LRScheduler):
22 |     """Anneals the learning rate from start to zero along a cosine curve."""
23 | 
24 |     DECAY_STYLES = ['linear', 'cos', 'exp', 'const', 'None']
25 | 
26 |     def __init__(self, optimizer, start_lr, warmup_iter, num_iters, decay_style=None, last_iter=-1):
27 |         self.optimizer = optimizer
28 |         self.start_lr = start_lr
29 |         self.warmup_iter = warmup_iter
30 |         self.num_iters = last_iter + 1
31 |         self.end_iter = num_iters
32 |         self.decay_style = decay_style.lower() if isinstance(decay_style, str) else None
33 |         self.step(self.num_iters)
34 | 
35 |     def get_lr(self):
36 |         # https://openreview.net/pdf?id=BJYwwY9ll pg. 4
37 |         if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter:
38 |             return float(self.start_lr) * self.num_iters / self.warmup_iter
39 |         else:
40 |             if self.decay_style == self.DECAY_STYLES[0]:
41 |                 return self.start_lr*((self.end_iter-(self.num_iters-self.warmup_iter))/self.end_iter)
42 |             elif self.decay_style == self.DECAY_STYLES[1]:
43 |                 return self.start_lr / 2.0 * (math.cos(math.pi * (self.num_iters - self.warmup_iter) / self.end_iter) + 1)
44 |             elif self.decay_style == self.DECAY_STYLES[2]:
45 |                 raise NotImplementedError("Exponential decay not yet implemented")
46 |             else:
47 |                 return self.start_lr
48 | 
49 |     def step(self, step_num=None):
50 |         if step_num is None:
51 |             step_num = self.num_iters + 1
52 |         self.num_iters = step_num
53 |         new_lr = self.get_lr()
54 |         for group in self.optimizer.param_groups:
55 |             group['lr'] = new_lr
56 | 
57 |     def state_dict(self):
58 |         sd = {
59 |             'start_lr': self.start_lr,
60 |             'warmup_iter': self.warmup_iter,
61 |             'num_iters': self.num_iters,
62 |             'decay_style': self.decay_style,
63 |             'end_iter': self.end_iter
64 |         }
65 |         return sd
66 | 
67 |     def load_state_dict(self, sd):
68 |         self.start_lr = sd['start_lr']
69 |         self.warmup_iter = sd['warmup_iter']
70 |         self.num_iters = sd['num_iters']
71 |         self.end_iter = sd['end_iter']
72 |         self.decay_style = sd['decay_style']
73 |         self.step(self.num_iters)
74 | 
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -i
2 | 
3 | source ./environment
4 | 
5 | python test.py \
6 |     --data_dir $DATA_DIR \
7 |     --model_dir_root $MODEL_ROOT_DIR \
8 |     "$@"
9 | 
--------------------------------------------------------------------------------
/test_myadaptor.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -i
2
| 3 | source ./environment 4 | 5 | python test_myadaptor.py \ 6 | --data_dir $DATA_DIR \ 7 | --model_dir_root $MODEL_ROOT_DIR \ 8 | "$@" 9 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -i 2 | 3 | source ./environment 4 | 5 | python train.py \ 6 | --data_dir $DATA_DIR \ 7 | --model_dir_root $MODEL_ROOT_DIR \ 8 | "$@" 9 | -------------------------------------------------------------------------------- /train_myadaptor.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -i 2 | 3 | source ./environment 4 | 5 | python train_myadaptor.py \ 6 | --data_dir $DATA_DIR \ 7 | --model_dir_root $MODEL_ROOT_DIR \ 8 | "$@" 9 | -------------------------------------------------------------------------------- /v2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SALT-NLP/Adaptive-Compositional-Modules/357aa2d6d1cd97ea03aeaddbd5372a1aeecbbe4c/v2.gif --------------------------------------------------------------------------------
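As a quick illustration of the helpers in mytransformers/utils/model_parallel_utils.py, the sketch below (not a file from this repository; the layer and device counts are made-up values) shows how get_device_map spreads transformer blocks evenly across GPUs and how assert_device_map validates a hand-written map:

from mytransformers.utils.model_parallel_utils import assert_device_map, get_device_map

# Spread 12 transformer blocks evenly over 4 devices: ceil(12 / 4) = 3 blocks each.
device_map = get_device_map(n_layers=12, devices=[0, 1, 2, 3])
# -> {0: [0, 1, 2], 1: [3, 4, 5], 2: [6, 7, 8], 3: [9, 10, 11]}

# Validate a device map (raises AssertionError on duplicate, missing, or extra blocks).
assert_device_map(device_map, num_blocks=12)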
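The AnnealingLR scheduler in scheduler.py is driven by explicit step() calls rather than epochs. Below is a minimal sketch of how it might be attached to an optimizer; the model, optimizer, and hyperparameter values are illustrative assumptions, not settings taken from the training scripts above:

import torch

from scheduler import AnnealingLR  # scheduler.py sits at the repository root

model = torch.nn.Linear(8, 2)  # stand-in module, purely illustrative
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

lr_scheduler = AnnealingLR(
    optimizer,
    start_lr=0.1,       # peak LR, reached at the end of warmup
    warmup_iter=100,    # linear warmup over the first 100 steps
    num_iters=1000,     # steps over which the LR is annealed after warmup
    decay_style="cos",  # 'linear', 'cos', or 'const' ('exp' raises NotImplementedError)
)

for step in range(1000):
    # ... forward/backward pass would go here ...
    optimizer.step()
    lr_scheduler.step()  # updates each param group's 'lr' in place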